Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/xyz routes #88

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ venv.bak/
.spyderproject
.spyproject

# Pycharm project settings
.idea

# Rope project settings
.ropeproject

Expand Down
5 changes: 5 additions & 0 deletions xpublish/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .base import base_router
from .common import common_router, dataset_collection_router
from .zarr import zarr_router

try:
from .xyz import xyz_router
except ImportError:
pass
79 changes: 79 additions & 0 deletions xpublish/routers/xyz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import xarray as xr
import cachey
from fastapi import APIRouter, Depends, Response, Query, Path
from typing import Optional

from xpublish.utils.cache import CostTimer
from xpublish.utils.api import DATASET_ID_ATTR_KEY
from xpublish.dependencies import get_dataset, get_cache
from xpublish.utils.ows import (
get_image_datashader,
get_bounds,
LayerOptionsMixin,
get_tiles,
)


class XYZRouter(APIRouter, LayerOptionsMixin):
pass


xyz_router = XYZRouter()


def query_builder(time, xleft, xright, ybottom, ytop, xlab, ylab):
query = {}
query.update({xlab: slice(xleft, xright), ylab: slice(ytop, ybottom)})
if time:
query["time"] = time
return query


@xyz_router.get("/tiles/{var}/{z}/{x}/{y}")
@xyz_router.get("/tiles/{var}/{z}/{x}/{y}.{format}")
async def tiles(
var: str = Path(
..., description="Dataset's variable. It defines the map's data layer"
),
z: int = Path(..., description="Tiles' zoom level"),
x: int = Path(..., description="Tiles' column"),
y: int = Path(..., description="Tiles' row"),
format: str = Query("PNG", description="Image format. Default to PNG"),
time: str = Query(
None,
description="Filter by time in time-varying datasets. String time format should match dataset's time format",
),
xlab: str = Query("x", description="Dataset x coordinate label"),
ylab: str = Query("y", description="Dataset y coordinate label"),
cache: cachey.Cache = Depends(get_cache),
dataset: xr.Dataset = Depends(get_dataset),
):

# color mapping settings
datashader_settings = getattr(xyz_router, "datashader_settings")

TMS = getattr(xyz_router, "TMS")

xleft, xright, ybottom, ytop = get_bounds(TMS, z, x, y)

query = query_builder(time, xleft, xright, ybottom, ytop, xlab, ylab)

cache_key = (
dataset.attrs.get(DATASET_ID_ATTR_KEY, "")
+ "/"
+ f"/tiles/{var}/{z}/{x}/{y}.{format}?{time}"
)
response = cache.get(cache_key)

if response is None:
with CostTimer() as ct:

tile = get_tiles(var, dataset, query)

byte_image = get_image_datashader(tile, datashader_settings, format)

response = Response(content=byte_image, media_type=f"image/{format}")

cache.put(cache_key, response, ct.time, len(byte_image))

return response
89 changes: 89 additions & 0 deletions xpublish/utils/ows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from datashader import transfer_functions as tf
import datashader as ds
import xarray as xr
from fastapi import HTTPException
import morecantile


# From Morecantile, morecantile.tms.list()
WEB_CRS = {
3857: "WebMercatorQuad",
32631: "UTM31WGS84Quad",
3978: "CanadianNAD83_LCC",
5482: "LINZAntarticaMapTilegrid",
4326: "WorldCRS84Quad",
5041: "UPSAntarcticWGS84Quad",
3035: "EuropeanETRS89_LAEAQuad",
3395: "WorldMercatorWGS84Quad",
2193: "NZTM2000",
}


class DataValidationError(KeyError):
pass


class LayerOptionsMixin:
def set_options(self, crs_epsg: int = 3857, color_mapping: dict = {}) -> None:

self.datashader_settings = color_mapping.get("datashader_settings")
self.matplotlib_settings = color_mapping.get("matplotlib_settings")

if crs_epsg not in WEB_CRS.keys():
raise DataValidationError(f"User input {crs_epsg} not supported")

self.TMS = morecantile.tms.get(WEB_CRS[crs_epsg])


def get_bounds(TMS, zoom, x, y):

bbx = TMS.xy_bounds(morecantile.Tile(int(x), int(y), int(zoom)))

return bbx.left, bbx.right, bbx.bottom, bbx.top


def get_tiles(var, dataset, query) -> xr.DataArray:

if query.get("time"):
tile = dataset[var].sel(query) # noqa
else:
tile = dataset[var].sel(query) # noqa

if 0 in tile.sizes.values():
raise HTTPException(status_code=406, detail=f"Map outside dataset domain")

return tile


def get_image_datashader(tile, datashader_settings, format):

raster_param = datashader_settings.get("raster", {})
shade_param = datashader_settings.get("shade", {"cmap": ["blue", "red"]})

cvs = ds.Canvas(plot_width=256, plot_height=256)

agg = cvs.raster(tile, **raster_param)

img = tf.shade(agg, **shade_param)

img_io = img.to_bytesio(format)

return img_io.read()


def get_legend():
pass


def validate_dataset(dataset):
dims = dataset.dims
if "x" not in dims or "y" not in dims:
raise DataValidationError(
f" Expected spatial dimension names 'x' and 'y', found: {dims}"
)
if "time" not in dims and len(dims) >= 3:
raise DataValidationError(
f" Expected time dimension name 'time', found: {dims}"
)
if len(dims) > 4:
raise DataValidationError(f" Not implemented for dimensions > 4")