diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index da09e37..8f60f0c 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -8,11 +8,12 @@ on:
 jobs:
   test:
-    name: ${{ matrix.python-version }}-build
+    name: ${{ matrix.python-version }}-pydantic${{ matrix.pydantic-version }}-build
     runs-on: ubuntu-latest
     strategy:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
+        pydantic-version: ["<2", ">=2"]

     steps:
     - uses: actions/checkout@v3
@@ -26,8 +27,9 @@ jobs:
       uses: actions/cache@v3.3.1
       with:
         path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/dev-requirements.txt') }}
+        key: ${{ runner.os }}-pip-${{ matrix.python-version }}-pydantic${{ matrix.pydantic-version }}-${{ hashFiles('**/dev-requirements.txt') }}
         restore-keys: |
+          ${{ runner.os }}-pip-${{ matrix.python-version }}-pydantic${{ matrix.pydantic-version }}
           ${{ runner.os }}-pip-${{ matrix.python-version }}
           ${{ runner.os }}-pip
           ${{ runner.os }}-pip-dev
@@ -36,6 +38,7 @@ jobs:
       run: |
         python -m pip install -r dev-requirements.txt
         python -m pip install --no-deps -e .
+        python -m pip install "pydantic${{ matrix.pydantic-version }}"
         python -m pip list

     - name: Running Tests
@@ -44,7 +47,6 @@ jobs:
     - name: Upload coverage to Codecov
       uses: codecov/codecov-action@v3.1.4
-      if: ${{ matrix.python-version }} == 3.9
       with:
         file: ./coverage.xml
         fail_ci_if_error: false
diff --git a/.readthedocs.yml b/.readthedocs.yml
index bdc1c07..10e10e2 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -20,6 +20,6 @@ build:
 # Optionally set the version of Python and requirements required to build your docs
 python:
   install:
-    - requirements: docs/requirements.txt
     - method: pip
       path: .
+    - requirements: docs/requirements.txt
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 3da5902..f177fde 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,11 +1,20 @@
+cachey
+dask
+fastapi>=0.78.0
 fsspec
 httpx
 importlib-metadata
 netcdf4
+numcodecs
+numpy
+pluggy
 pooch
 pytest
 pytest-mock
 pytest-sugar
 pytest-cov
 requests
--r requirements.txt
+toolz
+uvicorn
+xarray
+zarr
diff --git a/docs/requirements.txt b/docs/requirements.txt
index f15cf61..d664298 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,5 @@
 sphinx>=3.1
+pydantic<2.0
 sphinx-autosummary-accessors
 pydata-sphinx-theme
 sphinx-autodoc-typehints
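With the matrix above, each CI job installs its pydantic pin after the package itself, so both major versions stay covered. Code that needs to branch on whichever version a job installed can read the distribution metadata; a minimal sketch (the `PYDANTIC_V2` name is illustrative, not part of xpublish):

import importlib.metadata

# True when the job installed pydantic 2.x, False for 1.x
PYDANTIC_V2 = importlib.metadata.version('pydantic').startswith('2')
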
diff --git a/docs/source/getting-started/tutorial/dataset-provider-plugin.py b/docs/source/getting-started/tutorial/dataset-provider-plugin.py
index 1b3f9d0..338ab24 100644
--- a/docs/source/getting-started/tutorial/dataset-provider-plugin.py
+++ b/docs/source/getting-started/tutorial/dataset-provider-plugin.py
@@ -5,7 +5,7 @@


 class TutorialDataset(Plugin):
-    name = 'xarray-tutorial-dataset'
+    name: str = 'xarray-tutorial-dataset'

     @hookimpl
     def get_datasets(self):
diff --git a/docs/source/getting-started/tutorial/dataset-router-plugin.py b/docs/source/getting-started/tutorial/dataset-router-plugin.py
index fcaf134..693788c 100644
--- a/docs/source/getting-started/tutorial/dataset-router-plugin.py
+++ b/docs/source/getting-started/tutorial/dataset-router-plugin.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
 import xarray as xr
 from fastapi import APIRouter, Depends, HTTPException

@@ -5,10 +7,10 @@


 class MeanPlugin(Plugin):
-    name = 'mean'
+    name: str = 'mean'

-    dataset_router_prefix = ''
-    dataset_router_tags = ['mean']
+    dataset_router_prefix: str = ''
+    dataset_router_tags: Sequence[str] = ['mean']

     @hookimpl
     def dataset_router(self, deps: Dependencies):
diff --git a/docs/source/user-guide/plugins.md b/docs/source/user-guide/plugins.md
index a2afcf4..2496699 100644
--- a/docs/source/user-guide/plugins.md
+++ b/docs/source/user-guide/plugins.md
@@ -47,10 +47,10 @@ from xpublish import Plugin


 class HelloWorldPlugin(Plugin):
-    name = "hello_world"
+    name: str = "hello_world"
 ```

-At the minimum, a plugin needs to specify a `name` attribute.
+At the minimum, a plugin needs to specify a `name` attribute with a type annotation. For example, `name: str = "my_plugin_name"`.

 ### Marking implementation methods
@@ -66,7 +66,7 @@ from xpublish import Plugin, hookimpl
 from fastapi import APIRouter

 class HelloWorldPlugin(Plugin):
-    name = "hello_world"
+    name: str = "hello_world"

     @hookimpl
     def app_router(self):
@@ -188,7 +188,7 @@ from xpublish import Plugin, Dependencies, hookimpl


 class DatasetAttrs(Plugin):
-    name = "dataset-attrs"
+    name: str = "dataset-attrs"

     @hookimpl
     def dataset_router(self, deps: Dependencies):
@@ -224,7 +224,7 @@ from xpublish import Plugin, Dependencies, hookimpl


 class DatasetInfoPlugin(Plugin):
-    name = "dataset-info"
+    name: str = "dataset-info"

     dataset_router_prefix = "/info"
     dataset_router_tags = ["info"]
@@ -257,7 +257,7 @@ from xpublish import Plugin, Dependencies, hookimpl


 class PluginInfo(Plugin):
-    name = "plugin_info"
+    name: str = "plugin_info"

     app_router_prefix = "/info"
     app_router_tags = ["info"]
@@ -298,7 +298,7 @@ from xpublish import Plugin, hookimpl


 class TutorialDataset(Plugin):
-    name = "xarray-tutorial-dataset"
+    name: str = "xarray-tutorial-dataset"

     @hookimpl
     def get_datasets(self):
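The annotation requirement the docs now spell out comes from pydantic itself: `Plugin` subclasses `pydantic.BaseModel`, and pydantic 2 refuses bare class-level assignments that look like fields, while pydantic 1 silently accepted them. A standalone sketch of the difference (toy models, not xpublish classes):

from pydantic import BaseModel

class Annotated(BaseModel):
    name: str = 'demo'  # a real model field under pydantic 1 and 2

class Unannotated(BaseModel):
    name = 'demo'  # pydantic 2 raises PydanticUserError at class definition
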
diff --git a/tests/test_rest_api.py b/tests/test_rest_api.py
index 37b6a31..601b345 100644
--- a/tests/test_rest_api.py
+++ b/tests/test_rest_api.py
@@ -64,7 +64,7 @@ def get_dims(dataset: xr.Dataset = Depends(get_dataset)):
 @pytest.fixture(scope='function')
 def dataset_plugin(airtemp_ds):
     class AirtempPlugin(Plugin):
-        name = 'airtemp'
+        name: str = 'airtemp'

         @hookimpl
         def get_dataset(self, dataset_id: str):
@@ -86,7 +86,7 @@ def hello(self):
             pass

     class HookSpecPlugin(Plugin):
-        name = 'hook_spec'
+        name: str = 'hook_spec'

         @hookimpl
         def register_hookspec(self):
@@ -98,7 +98,7 @@ def register_hookspec(self):
 @pytest.fixture(scope='function')
 def hook_implementation_plugin():
     class HookImplementationPlugin(Plugin):
-        name = 'hook_implementation'
+        name: str = 'hook_implementation'

         @hookimpl
         def hello(self):
diff --git a/xpublish/dependencies.py b/xpublish/dependencies.py
index fba7e54..5932602 100644
--- a/xpublish/dependencies.py
+++ b/xpublish/dependencies.py
@@ -9,7 +9,7 @@
 from fastapi import Depends

 from .utils.api import DATASET_ID_ATTR_KEY
-from .utils.zarr import create_zmetadata, create_zvariables, zarr_metadata_key
+from .utils.zarr import ZARR_METADATA_KEY, create_zmetadata, create_zvariables

 if TYPE_CHECKING:
     from .plugins import Plugin  # pragma: no cover
@@ -90,7 +90,7 @@ def get_zmetadata(
 ):
     """FastAPI dependency that returns a consolidated zmetadata dictionary."""

-    cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + zarr_metadata_key
+    cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + ZARR_METADATA_KEY

     zmeta = cache.get(cache_key)
     if zmeta is None:
diff --git a/xpublish/plugins/hooks.py b/xpublish/plugins/hooks.py
index e59d1aa..d65e3d7 100644
--- a/xpublish/plugins/hooks.py
+++ b/xpublish/plugins/hooks.py
@@ -26,19 +26,24 @@ class Dependencies(BaseModel):
     """

     dataset_ids: Callable[..., List[str]] = Field(
-        get_dataset_ids, description='Returns a list of all valid dataset ids'
+        get_dataset_ids,
+        description='Returns a list of all valid dataset ids',
     )
     dataset: Callable[[str], xr.Dataset] = Field(
-        get_dataset, description='Returns a dataset using ``//`` in the path.'
+        get_dataset,
+        description='Returns a dataset using ``//`` in the path.',
     )
     cache: Callable[..., cachey.Cache] = Field(
-        get_cache, description='Provide access to :py:class:`cachey.Cache`'
+        get_cache,
+        description='Provide access to :py:class:`cachey.Cache`',
     )
     plugins: Callable[..., Dict[str, 'Plugin']] = Field(
-        get_plugins, description='A dictionary of plugins allowing direct access'
+        get_plugins,
+        description='A dictionary of plugins allowing direct access',
     )
     plugin_manager: Callable[..., pluggy.PluginManager] = Field(
-        get_plugin_manager, description='The plugin manager itself, allowing for maximum creativity'
+        get_plugin_manager,
+        description='The plugin manager itself, allowing for maximum creativity',
     )

     def __hash__(self):
@@ -64,7 +69,13 @@ def __hash__(self):
         """Make sure that the plugin is hashable to load with pluggy"""
         things_to_hash = []

-        for e in self.dict():
+        # try/except is for pydantic backwards compatibility
+        try:
+            model_dict = self.model_dump()
+        except AttributeError:
+            model_dict = self.dict()
+
+        for e in model_dict:
             if isinstance(e, list):
                 things_to_hash.append(tuple(e))  # pragma: no cover
             else:
@@ -115,7 +126,8 @@ def get_datasets(self) -> Iterable[str]:  # type: ignore
         """Return an iterable of dataset ids that the plugin can provide"""

     @hookspec(firstresult=True)
-    def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]:  # type: ignore
+    # type: ignore
+    def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]:
         """Return a dataset by requested dataset_id.

         If the plugin does not have the dataset, return None
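The `try`/`except AttributeError` in `__hash__` above handles pydantic 2 renaming `.dict()` to `.model_dump()`. If the same dance is needed in several places, it can be factored into a helper; a sketch (the `model_as_dict` name is hypothetical, not part of xpublish):

from pydantic import BaseModel

def model_as_dict(model: BaseModel) -> dict:
    """Dump a model to a plain dict on pydantic 1 and 2."""
    try:
        return model.model_dump()  # pydantic 2
    except AttributeError:
        return model.dict()  # pydantic 1
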
diff --git a/xpublish/plugins/included/dataset_info.py b/xpublish/plugins/included/dataset_info.py
index b62813f..b08a879 100644
--- a/xpublish/plugins/included/dataset_info.py
+++ b/xpublish/plugins/included/dataset_info.py
@@ -14,19 +14,22 @@ class DatasetInfoPlugin(Plugin):
     """Dataset metadata"""

-    name = 'dataset_info'
+    name: str = 'dataset_info'

     dataset_router_prefix: str = ''
     dataset_router_tags: Sequence[str] = ['dataset_info']

     @hookimpl
-    def dataset_router(self, deps: Dependencies):
-        router = APIRouter(prefix=self.dataset_router_prefix, tags=list(self.dataset_router_tags))
+    def dataset_router(self, deps: Dependencies) -> APIRouter:
+        router = APIRouter(
+            prefix=self.dataset_router_prefix,
+            tags=list(self.dataset_router_tags),
+        )

         @router.get('/')
         def html_representation(
             dataset=Depends(deps.dataset),
-        ):
+        ) -> HTMLResponse:
             """Returns the xarray HTML representation of the dataset."""
             with xr.set_options(display_style='html'):
@@ -43,7 +46,7 @@ def list_keys(
         @router.get('/dict')
         def to_dict(
             dataset=Depends(deps.dataset),
-        ):
+        ) -> dict:
             """The full dataset as a dictionary"""
             return JSONResponse(dataset.to_dict(data=False))
@@ -51,7 +54,7 @@ def info(
         def info(
             dataset=Depends(deps.dataset),
             cache=Depends(deps.cache),
-        ):
+        ) -> dict:
             """Dataset schema (close to the NCO-JSON schema)."""

             zvariables = get_zvariables(dataset, cache)
@@ -66,6 +69,7 @@ def info(
             for name, var in zvariables.items():
                 attrs = meta[f'{name}/{attrs_key}'].copy()
                 attrs.pop('_ARRAY_DIMENSIONS')
+
                 info['variables'][name] = {
                     'type': var.data.dtype.name,
                     'dimensions': list(var.dims),
diff --git a/xpublish/plugins/included/module_version.py b/xpublish/plugins/included/module_version.py
index 77f1964..afb911e 100644
--- a/xpublish/plugins/included/module_version.py
+++ b/xpublish/plugins/included/module_version.py
@@ -3,7 +3,7 @@
 """
 import importlib
 import sys
-from typing import List
+from typing import Sequence

 from fastapi import APIRouter

@@ -14,17 +14,20 @@ class ModuleVersionPlugin(Plugin):
     """Share the currently loaded versions of key libraries"""

-    name = 'module_version'
+    name: str = 'module_version'

     app_router_prefix: str = ''
-    app_router_tags: List[str] = ['module_version']
+    app_router_tags: Sequence[str] = ['module_version']

     @hookimpl
-    def app_router(self):
-        router = APIRouter(prefix=self.app_router_prefix, tags=self.app_router_tags)
+    def app_router(self) -> APIRouter:
+        router = APIRouter(
+            prefix=self.app_router_prefix,
+            tags=self.app_router_tags,
+        )

         @router.get('/versions')
-        def get_versions():
+        def get_versions() -> dict:
             """Currently loaded versions of key libraries"""
             versions = dict(get_sys_info() + netcdf_and_hdf5_versions())
             modules = [
diff --git a/xpublish/plugins/included/plugin_info.py b/xpublish/plugins/included/plugin_info.py
index 36968ac..546bc3b 100644
--- a/xpublish/plugins/included/plugin_info.py
+++ b/xpublish/plugins/included/plugin_info.py
@@ -12,20 +12,23 @@

 class PluginInfo(BaseModel):
     path: str
-    version: Optional[str]
+    version: Optional[str] = None


 class PluginInfoPlugin(Plugin):
     """Expose plugin source and version"""

-    name = 'plugin_info'
+    name: str = 'plugin_info'

     app_router_prefix: str = ''
     app_router_tags: Sequence[str] = ['plugin_info']

     @hookimpl
-    def app_router(self, deps: Dependencies):
-        router = APIRouter(prefix=self.app_router_prefix, tags=list(self.app_router_tags))
+    def app_router(self, deps: Dependencies) -> APIRouter:
+        router = APIRouter(
+            prefix=self.app_router_prefix,
+            tags=list(self.app_router_tags),
+        )

         @router.get('/plugins')
         def get_plugins(
@@ -44,7 +47,8 @@ def get_plugins(
                 version = None  # pragma: no cover

             plugin_info[name] = PluginInfo(
-                path=f'{plugin_type.__module__}.{plugin.__repr_name__()}', version=version
+                path=f'{plugin_type.__module__}.{plugin.__repr_name__()}',
+                version=version,
             )

         return plugin_info
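Two smaller pydantic 2 changes surface in these plugins: an `Optional[str]` field is now required unless it gets an explicit `= None` default, and the `Sequence[str]` tag fields are passed through `list(...)` because `APIRouter` wants a concrete list. A standalone sketch of both patterns (the `RouterSettings` name is illustrative):

from typing import Optional, Sequence

from fastapi import APIRouter
from pydantic import BaseModel

class RouterSettings(BaseModel):
    prefix: str = ''
    tags: Sequence[str] = ['demo']
    note: Optional[str] = None  # pydantic 2 no longer implies a None default

settings = RouterSettings()
# coerce the Sequence field before handing it to FastAPI
router = APIRouter(prefix=settings.prefix, tags=list(settings.tags))
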
diff --git a/xpublish/plugins/included/zarr.py b/xpublish/plugins/included/zarr.py
index 26ef012..1f4f4e2 100644
--- a/xpublish/plugins/included/zarr.py
+++ b/xpublish/plugins/included/zarr.py
@@ -12,7 +12,12 @@
 from ...dependencies import get_zmetadata, get_zvariables
 from ...utils.api import DATASET_ID_ATTR_KEY
 from ...utils.cache import CostTimer
-from ...utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key
+from ...utils.zarr import (
+    ZARR_METADATA_KEY,
+    encode_chunk,
+    get_data_chunk,
+    jsonify_zmetadata,
+)
 from .. import Dependencies, Plugin, hookimpl

 logger = logging.getLogger('zarr_api')
@@ -21,20 +26,23 @@
 class ZarrPlugin(Plugin):
     """Adds Zarr-like accessing endpoints for datasets"""

-    name = 'zarr'
+    name: str = 'zarr'

     dataset_router_prefix: str = '/zarr'
     dataset_router_tags: Sequence[str] = ['zarr']

     @hookimpl
-    def dataset_router(self, deps: Dependencies):
-        router = APIRouter(prefix=self.dataset_router_prefix, tags=list(self.dataset_router_tags))
+    def dataset_router(self, deps: Dependencies) -> APIRouter:
+        router = APIRouter(
+            prefix=self.dataset_router_prefix,
+            tags=list(self.dataset_router_tags),
+        )

-        @router.get(f'/{zarr_metadata_key}')
+        @router.get(f'/{ZARR_METADATA_KEY}')
         def get_zarr_metadata(
             dataset=Depends(deps.dataset),
             cache=Depends(deps.cache),
-        ):
+        ) -> dict:
             """Consolidated Zarr metadata"""
             zvariables = get_zvariables(dataset, cache)
             zmetadata = get_zmetadata(dataset, cache, zvariables)
@@ -47,7 +55,7 @@ def get_zarr_metadata(
         def get_zarr_group(
             dataset=Depends(deps.dataset),
             cache=Depends(deps.cache),
-        ):
+        ) -> dict:
             """Zarr group data"""
             zvariables = get_zvariables(dataset, cache)
             zmetadata = get_zmetadata(dataset, cache, zvariables)
@@ -58,7 +66,7 @@ def get_zarr_group(
         def get_zarr_attrs(
             dataset=Depends(deps.dataset),
             cache=Depends(deps.cache),
-        ):
+        ) -> dict:
             """Zarr attributes"""
             zvariables = get_zvariables(dataset, cache)
             zmetadata = get_zmetadata(dataset, cache, zvariables)
@@ -99,7 +107,11 @@ def get_variable_chunk(
             arr_meta = zmetadata['metadata'][f'{var}/{array_meta_key}']
             da = zvariables[var].data

-            data_chunk = get_data_chunk(da, chunk, out_shape=arr_meta['chunks'])
+            data_chunk = get_data_chunk(
+                da,
+                chunk,
+                out_shape=arr_meta['chunks'],
+            )

             echunk = encode_chunk(
                 data_chunk.tobytes(),
@@ -107,7 +119,10 @@ def get_variable_chunk(
                 compressor=arr_meta['compressor'],
             )

-            response = Response(echunk, media_type='application/octet-stream')
+            response = Response(
+                echunk,
+                media_type='application/octet-stream',
+            )

             cache.put(cache_key, response, ct.time, len(echunk))
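With the renamed constant, the consolidated-metadata route still resolves to `/zarr/.zmetadata` under each dataset. Assuming a server on the default port and a dataset registered as `air` (both illustrative, as is the `/datasets/{dataset_id}` prefix), the endpoint can be checked with httpx, which is already in dev-requirements.txt:

import httpx

# ZARR_METADATA_KEY is '.zmetadata', so the full path is /datasets/air/zarr/.zmetadata
zmeta = httpx.get('http://localhost:9000/datasets/air/zarr/.zmetadata').json()
print(zmeta['zarr_consolidated_format'])
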
diff --git a/xpublish/plugins/manage.py b/xpublish/plugins/manage.py
index be04248..c4d1c09 100644
--- a/xpublish/plugins/manage.py
+++ b/xpublish/plugins/manage.py
@@ -25,7 +25,9 @@ def find_default_plugins(
     return plugins


-def load_default_plugins(exclude_plugins: Optional[Iterable[str]] = None) -> Dict[str, Plugin]:
+def load_default_plugins(
+    exclude_plugins: Optional[Iterable[str]] = None,
+) -> Dict[str, Plugin]:
     """Find and initialize plugins from entry point group `xpublish.plugin`"""
     initialized_plugins: Dict[str, Plugin] = {}

@@ -36,7 +38,8 @@ def load_default_plugins(exclude_plugins: Optional[Iterable[str]] = None) -> Dic


 def configure_plugins(
-    plugins: Dict[str, Type[Plugin]], plugin_configs: Optional[Dict] = None
+    plugins: Dict[str, Type[Plugin]],
+    plugin_configs: Optional[Dict] = None,
 ) -> Dict[str, Plugin]:
     """Initialize and configure plugins with given dictionary of configurations"""
     initialized_plugins: Dict[str, Plugin] = {}
diff --git a/xpublish/rest.py b/xpublish/rest.py
index 2dcd6be..5f74579 100644
--- a/xpublish/rest.py
+++ b/xpublish/rest.py
@@ -82,7 +82,10 @@ def __init__(
         self.setup_datasets(datasets or {})
         self.setup_plugins(plugins)

-        routers = normalize_app_routers(routers or [], self._dataset_route_prefix)
+        routers = normalize_app_routers(
+            routers or [],
+            self._dataset_route_prefix,
+        )
         check_route_conflicts(routers)
         self._routers = routers

@@ -119,7 +122,8 @@ def get_datasets_from_plugins(self) -> List[str]:
         return dataset_ids

     def get_dataset_from_plugins(
-        self, dataset_id: str = Path(description='Unique ID of dataset')
+        self,
+        dataset_id: str = Path(description='Unique ID of dataset'),
     ) -> xr.Dataset:
         """Attempt to load dataset from plugins,
         otherwise return dataset from passed in dictionary of datasets
@@ -144,7 +148,10 @@ def get_dataset_from_plugins(

         return self._datasets[dataset_id]

-    def setup_plugins(self, plugins: Optional[Dict[str, Plugin]] = None):
+    def setup_plugins(
+        self,
+        plugins: Optional[Dict[str, Plugin]] = None,
+    ) -> None:
         """Initialize and load plugins from entry_points unless explicitly provided

         Parameters:
@@ -168,8 +175,11 @@ def setup_plugins(self, plugins: Optional[Dict[str, Plugin]] = None):
             self.pm.add_hookspecs(hookspec)

     def register_plugin(
-        self, plugin: Plugin, plugin_name: Optional[str] = None, overwrite: bool = False
-    ):
+        self,
+        plugin: Plugin,
+        plugin_name: Optional[str] = None,
+        overwrite: bool = False,
+    ) -> None:
         """
         Register a plugin with the xpublish system

@@ -206,14 +216,14 @@ def register_plugin(
         )():
             self.pm.add_hookspecs(hookspec)

-    def init_cache_kwargs(self, cache_kws):
+    def init_cache_kwargs(self, cache_kws: dict) -> None:
         """Set up cache kwargs"""
         self._cache = None
         self._cache_kws = {'available_bytes': 1e6}
         if cache_kws is not None:
             self._cache_kws.update(cache_kws)

-    def init_app_kwargs(self, app_kws):
+    def init_app_kwargs(self, app_kws: dict) -> None:
         """Set up FastAPI application kwargs"""
         self._app = None
         self._app_kws = {}
@@ -233,7 +243,7 @@ def plugins(self) -> Dict[str, Plugin]:
         """Returns the loaded plugins"""
         return dict(self.pm.list_name_plugin())

-    def _init_routers(self, dataset_routers: Optional[APIRouter]):
+    def _init_routers(self, dataset_routers: Optional[APIRouter]) -> None:
         """Setup plugin and dataset routers. Needs to run after dataset and plugin setup"""
         app_routers, plugin_dataset_routers = self.plugin_routers()

@@ -282,7 +292,7 @@ def dependencies(self) -> Dependencies:

         return deps

-    def _init_dependencies(self):
+    def _init_dependencies(self) -> None:
         """Initialize dependencies"""
         deps = self.dependencies()

@@ -292,7 +302,7 @@ def _init_dependencies(self):
         self._app.dependency_overrides[get_plugins] = deps.plugins
         self._app.dependency_overrides[get_plugin_manager] = deps.plugin_manager

-    def _init_app(self):
+    def _init_app(self) -> FastAPI:
         """Initiate the FastAPI application."""

         self._app = FastAPI(**self._app_kws)
@@ -318,7 +328,13 @@ def app(self) -> FastAPI:
             self._app = self._init_app()
         return self._app

-    def serve(self, host: str = '0.0.0.0', port: int = 9000, log_level: str = 'debug', **kwargs):
+    def serve(
+        self,
+        host: str = '0.0.0.0',
+        port: int = 9000,
+        log_level: str = 'debug',
+        **kwargs,
+    ) -> None:
         """Serve this FastAPI application via :func:`uvicorn.run`.

         Parameters
         ----------
@@ -338,7 +354,13 @@ def serve(self, host: str = '0.0.0.0', port: int = 9000, log_level: str = 'debug

         This method is blocking and does not return.
         """
-        uvicorn.run(self.app, host=host, port=port, log_level=log_level, **kwargs)
+        uvicorn.run(
+            self.app,
+            host=host,
+            port=port,
+            log_level=log_level,
+            **kwargs,
+        )


 class SingleDatasetRest(Rest):
@@ -363,7 +385,7 @@ def __init__(

         super().__init__({}, routers, cache_kws, app_kws, plugins)

-    def setup_datasets(self, datasets):
+    def setup_datasets(self, datasets) -> str:
         """Modifies the dataset loading to instead connect to the single dataset"""
         self._dataset_route_prefix = ''

@@ -373,7 +395,7 @@ def setup_datasets(self, datasets):

         return self._dataset_route_prefix

-    def _init_app(self):
+    def _init_app(self) -> FastAPI:
         self._app = super()._init_app()

         self._app.openapi = SingleDatasetOpenAPIOverrider(self._app).openapi
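Because `serve` hands off to the blocking `uvicorn.run`, it is typically the last line of a launch script. A minimal launch sketch (the tutorial dataset and the `air` id are illustrative):

import xarray as xr
from xpublish import Rest

ds = xr.tutorial.open_dataset('air_temperature')  # small sample dataset, fetched via pooch
rest = Rest({'air': ds})
rest.serve(host='0.0.0.0', port=9000)  # blocks until the server is stopped
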
""" - uvicorn.run(self.app, host=host, port=port, log_level=log_level, **kwargs) + uvicorn.run( + self.app, + host=host, + port=port, + log_level=log_level, + **kwargs, + ) class SingleDatasetRest(Rest): @@ -363,7 +385,7 @@ def __init__( super().__init__({}, routers, cache_kws, app_kws, plugins) - def setup_datasets(self, datasets): + def setup_datasets(self, datasets) -> str: """Modifies the dataset loading to instead connect to the single dataset""" self._dataset_route_prefix = '' @@ -373,7 +395,7 @@ def setup_datasets(self, datasets): return self._dataset_route_prefix - def _init_app(self): + def _init_app(self) -> FastAPI: self._app = super()._init_app() self._app.openapi = SingleDatasetOpenAPIOverrider(self._app).openapi diff --git a/xpublish/utils/api.py b/xpublish/utils/api.py index b00041a..ef0562c 100644 --- a/xpublish/utils/api.py +++ b/xpublish/utils/api.py @@ -32,7 +32,10 @@ def normalize_datasets(datasets) -> Dict[str, xr.Dataset]: raise TypeError(error_msg) -def normalize_app_routers(routers: list, prefix: str) -> List[Tuple[APIRouter, Dict]]: +def normalize_app_routers( + routers: list, + prefix: str, +) -> List[Tuple[APIRouter, Dict]]: """Normalise the given list of (dataset-specific) API routers. Add or prepend ``prefix`` to all routers. @@ -56,7 +59,7 @@ def normalize_app_routers(routers: list, prefix: str) -> List[Tuple[APIRouter, D return new_routers -def check_route_conflicts(routers): +def check_route_conflicts(routers) -> None: paths = [] for router, kws in routers: @@ -90,10 +93,10 @@ class SingleDatasetOpenAPIOverrider: """ - def __init__(self, app): + def __init__(self, app) -> None: self._app = app - def openapi(self): + def openapi(self) -> dict: if self._app.openapi_schema: return self._app.openapi_schema @@ -123,7 +126,7 @@ def openapi(self): class JSONResponse(StarletteJSONResponse): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self._render_kwargs = { 'ensure_ascii': True, 'allow_nan': True, diff --git a/xpublish/utils/info.py b/xpublish/utils/info.py index a8cc923..7e18313 100644 --- a/xpublish/utils/info.py +++ b/xpublish/utils/info.py @@ -8,9 +8,10 @@ import struct import subprocess import sys +from typing import Union -def get_sys_info(): +def get_sys_info() -> list: 'Returns system information as a dict' blob = [] @@ -57,7 +58,7 @@ def get_sys_info(): return blob -def netcdf_and_hdf5_versions(): +def netcdf_and_hdf5_versions() -> list[tuple[str, Union[str, None]]]: libhdf5_version = None libnetcdf_version = None try: diff --git a/xpublish/utils/zarr.py b/xpublish/utils/zarr.py index 67cc98d..b4dd2b7 100644 --- a/xpublish/utils/zarr.py +++ b/xpublish/utils/zarr.py @@ -1,9 +1,14 @@ import copy import logging +from typing import ( + Any, + Optional, +) import dask.array import numpy as np import xarray as xr +from numcodecs.abc import Codec from numcodecs.compat import ensure_ndarray from xarray.backends.zarr import ( DIMENSION_KEY, @@ -12,20 +17,25 @@ extract_zarr_variable_encoding, ) from zarr.meta import encode_fill_value -from zarr.storage import array_meta_key, attrs_key, default_compressor, group_meta_key +from zarr.storage import ( + array_meta_key, + attrs_key, + default_compressor, + group_meta_key, +) from zarr.util import normalize_shape from .api import DATASET_ID_ATTR_KEY -dask_array_type = (dask.array.Array,) -zarr_format = 2 -zarr_consolidated_format = 1 -zarr_metadata_key = '.zmetadata' +DaskArrayType = (dask.array.Array,) +ZARR_FORMAT = 2 +ZARR_CONSOLIDATED_FORMAT = 1 +ZARR_METADATA_KEY = 
diff --git a/xpublish/utils/zarr.py b/xpublish/utils/zarr.py
index 67cc98d..b4dd2b7 100644
--- a/xpublish/utils/zarr.py
+++ b/xpublish/utils/zarr.py
@@ -1,9 +1,14 @@
 import copy
 import logging
+from typing import (
+    Any,
+    Optional,
+)

 import dask.array
 import numpy as np
 import xarray as xr
+from numcodecs.abc import Codec
 from numcodecs.compat import ensure_ndarray
 from xarray.backends.zarr import (
     DIMENSION_KEY,
@@ -12,20 +17,25 @@
     extract_zarr_variable_encoding,
 )
 from zarr.meta import encode_fill_value
-from zarr.storage import array_meta_key, attrs_key, default_compressor, group_meta_key
+from zarr.storage import (
+    array_meta_key,
+    attrs_key,
+    default_compressor,
+    group_meta_key,
+)
 from zarr.util import normalize_shape

 from .api import DATASET_ID_ATTR_KEY

-dask_array_type = (dask.array.Array,)
-zarr_format = 2
-zarr_consolidated_format = 1
-zarr_metadata_key = '.zmetadata'
+DaskArrayType = (dask.array.Array,)
+ZARR_FORMAT = 2
+ZARR_CONSOLIDATED_FORMAT = 1
+ZARR_METADATA_KEY = '.zmetadata'

 logger = logging.getLogger('api')


-def _extract_dataset_zattrs(dataset: xr.Dataset):
+def _extract_dataset_zattrs(dataset: xr.Dataset) -> dict:
     """helper function to create zattrs dictionary from Dataset global attrs"""
     zattrs = {}
     for k, v in dataset.attrs.items():
@@ -37,7 +47,7 @@ def _extract_dataset_zattrs(dataset: xr.Dataset):
     return zattrs


-def _extract_dataarray_zattrs(da):
+def _extract_dataarray_zattrs(da: xr.DataArray) -> dict:
     """helper function to extract zattrs dictionary from DataArray"""
     zattrs = {}
     for k, v in da.attrs.items():
@@ -51,7 +61,10 @@ def _extract_dataarray_zattrs(da):
     return zattrs


-def _extract_dataarray_coords(da, zattrs):
+def _extract_dataarray_coords(
+    da: xr.DataArray,
+    zattrs: dict,
+) -> dict:
     '''helper function to extract coords from DataArray into a dictionary'''
     if da.coords:
         # Coordinates are only encoded if there are non-dimension coordinates
@@ -63,13 +76,20 @@ def _extract_dataarray_coords(da, zattrs):
     return zattrs


-def _extract_fill_value(da, dtype):
+def _extract_fill_value(
+    da: xr.DataArray,
+    dtype: np.dtype,
+) -> Any:
     """helper function to extract fill value from DataArray."""
     fill_value = da.attrs.pop('_FillValue', None)
     return encode_fill_value(fill_value, dtype)


-def _extract_zarray(da, encoding, dtype):
+def _extract_zarray(
+    da: xr.DataArray,
+    encoding: dict,
+    dtype: np.dtype,
+) -> dict:
     """helper function to extract zarr array metadata."""
     meta = {
         'compressor': encoding.get('compressor', da.encoding.get('compressor', default_compressor)),
@@ -79,14 +99,14 @@ def _extract_zarray(da, encoding, dtype):
         'fill_value': _extract_fill_value(da, dtype),
         'order': 'C',
         'shape': list(normalize_shape(da.shape)),
-        'zarr_format': zarr_format,
+        'zarr_format': ZARR_FORMAT,
     }

     if meta['chunks'] is None:
         meta['chunks'] = da.shape

     # validate chunks
-    if isinstance(da.data, dask_array_type):
+    if isinstance(da.data, DaskArrayType):
         var_chunks = tuple([c[0] for c in da.data.chunks])
     else:
         var_chunks = da.shape
@@ -98,7 +118,7 @@ def _extract_zarray(da, encoding, dtype):
     return meta


-def create_zvariables(dataset):
+def create_zvariables(dataset: xr.Dataset) -> dict:
     """Helper function to create a dictionary of zarr encoded variables."""
     zvariables = {}

@@ -109,11 +129,14 @@ def create_zvariables(dataset):
     return zvariables


-def create_zmetadata(dataset):
+def create_zmetadata(dataset: xr.Dataset) -> dict:
     """Helper function to create a consolidated zmetadata dictionary."""
-    zmeta = {'zarr_consolidated_format': zarr_consolidated_format, 'metadata': {}}
-    zmeta['metadata'][group_meta_key] = {'zarr_format': zarr_format}
+    zmeta = {
+        'zarr_consolidated_format': ZARR_CONSOLIDATED_FORMAT,
+        'metadata': {},
+    }
+    zmeta['metadata'][group_meta_key] = {'zarr_format': ZARR_FORMAT}
     zmeta['metadata'][attrs_key] = _extract_dataset_zattrs(dataset)

     for key, dvar in dataset.variables.items():
@@ -124,13 +147,18 @@ def create_zmetadata(dataset):
         zattrs = _extract_dataarray_coords(da, zattrs)
         zmeta['metadata'][f'{key}/{attrs_key}'] = zattrs
         zmeta['metadata'][f'{key}/{array_meta_key}'] = _extract_zarray(
-            encoded_da, encoding, encoded_da.dtype
+            encoded_da,
+            encoding,
+            encoded_da.dtype,
         )

     return zmeta


-def jsonify_zmetadata(dataset: xr.Dataset, zmetadata: dict) -> dict:
+def jsonify_zmetadata(
+    dataset: xr.Dataset,
+    zmetadata: dict,
+) -> dict:
     """Helper function to convert zmetadata dictionary to a json compatible dictionary.
@@ -149,7 +177,11 @@ def jsonify_zmetadata(dataset: xr.Dataset, zmetadata: dict) -> dict:
     return zjson


-def encode_chunk(chunk, filters=None, compressor=None):
+def encode_chunk(
+    chunk: np.typing.ArrayLike,
+    filters: Optional[list[Codec]] = None,
+    compressor: Optional[Codec] = None,
+) -> np.typing.ArrayLike:
     """helper function largely copied from zarr.Array"""
     # apply filters
     if filters:
@@ -169,13 +201,17 @@ def encode_chunk(chunk, filters=None, compressor=None):
     return cdata


-def get_data_chunk(da, chunk_id, out_shape):
+def get_data_chunk(
+    da: xr.DataArray,
+    chunk_id: str,
+    out_shape: tuple,
+) -> np.typing.ArrayLike:
     """Get one chunk of data from this DataArray (da).

     If this is an incomplete edge chunk, pad the returned array to match out_shape.
     """
     ikeys = tuple(map(int, chunk_id.split('.')))
-    if isinstance(da, dask_array_type):
+    if isinstance(da, DaskArrayType):
         chunk_data = da.blocks[ikeys]
     else:
         if da.ndim > 0 and ikeys != ((0,) * da.ndim):
@@ -187,7 +223,7 @@ def get_data_chunk(da, chunk_id, out_shape):
     logger.debug('checking chunk output size, %s == %s' % (chunk_data.shape, out_shape))

-    if isinstance(chunk_data, dask_array_type):
+    if isinstance(chunk_data, DaskArrayType):
         chunk_data = chunk_data.compute()

     # zarr expects full edge chunks, contents out of bounds for the array are undefined
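The upper-cased names are module-level API, so downstream imports should switch to them. A small sketch of building consolidated metadata directly (the toy dataset is illustrative):

import xarray as xr
from xpublish.utils.zarr import ZARR_METADATA_KEY, create_zmetadata

ds = xr.Dataset({'t': ('x', [1.0, 2.0, 3.0])})
zmeta = create_zmetadata(ds)
print(ZARR_METADATA_KEY, zmeta['zarr_consolidated_format'])  # .zmetadata 1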