From 79a2c91944f96d17467d7482f88df928a3abce9a Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Wed, 11 Oct 2023 21:40:40 -0400 Subject: [PATCH] Parse string attributes in constructor --- pyproject.toml | 2 +- src/gval/accessors/gval_xarray.py | 36 +--------------- src/gval/comparison/tabulation.py | 2 - src/gval/utils/loading_datasets.py | 68 +++++++++++++++++++++++++++--- tests/test_catalogs.py | 18 +++++++- tests/test_compare.py | 5 +++ 6 files changed, 85 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ee26875b..f6a07cd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.8" keywords = ["geospatial", "evaluations"] license = {text = "MIT"} -version = "0.2.2.post1" +version = "0.2.3" dynamic = ["readme", "dependencies"] diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index 09ab2849..fbe9ff9f 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -1,7 +1,6 @@ from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List from numbers import Number from functools import partial -import ast import numpy as np import numba as nb @@ -22,6 +21,7 @@ from gval.comparison.compute_categorical_metrics import _compute_categorical_metrics from gval.comparison.compute_continuous_metrics import _compute_continuous_metrics from gval.attributes.attributes import _attribute_tracking_xarray +from gval.utils.loading_datasets import _parse_string_attributes from gval.utils.schemas import Crosstab_df, Metrics_df, AttributeTrackingDf from gval.utils.visualize import _map_plot from gval.comparison.pairing_functions import difference @@ -41,7 +41,7 @@ class GVALXarray: """ def __init__(self, xarray_obj): - self._obj = xarray_obj + self._obj = _parse_string_attributes(xarray_obj) self.data_type = type(xarray_obj) self.agreement_map_format = "raster" @@ -97,38 +97,6 @@ def __handle_attribute_tracking( return results - def parse_string_attributes(self): - """ - Parses string attributes stored in rasters - """ - - if "pairing_dictionary" in self._obj.attrs and isinstance( - self._obj.attrs["pairing_dictionary"], str - ): - eval = ast.literal_eval( - self._obj.attrs["pairing_dictionary"].replace("nan", '"nan"') - ) - self._obj.attrs["pairing_dictionary"] = { - (float(k[0]), float(k[1])): float(v) for k, v in eval.items() - } - - if isinstance(self._obj, xr.Dataset): - for var in self._obj.data_vars: - self._obj[var].attrs["pairing_dictionary"] = self._obj.attrs[ - "pairing_dictionary" - ] - - def attributes_to_string(self): # pragma: no cover - """ - Converts attributes to string to mimic a raster loaded from disk - """ - if "pairing_dictionary" in self._obj.attrs and isinstance( - self._obj.attrs["pairing_dictionary"], dict - ): - self._obj.attrs["pairing_dictionary"] = str( - self._obj.attrs["pairing_dictionary"] - ) - @Comparison.comparison_function_from_string def categorical_compare( self, diff --git a/src/gval/comparison/tabulation.py b/src/gval/comparison/tabulation.py index 7dd16cbe..df6cd0c1 100644 --- a/src/gval/comparison/tabulation.py +++ b/src/gval/comparison/tabulation.py @@ -106,7 +106,6 @@ def not_nan(number): return not np.isnan(number) # Handle pairing dictionary attribute - agreement_map.gval.parse_string_attributes() pairing_dict = agreement_map.attrs["pairing_dictionary"] rev_dict = {} @@ -240,7 +239,6 @@ def _crosstab_Datasets(agreement_map: xr.DataArray) -> DataFrame[Crosstab_df]: # gets variable names agreement_variable_names = list(agreement_map.data_vars) - agreement_map.gval.parse_string_attributes() # loop variables previous_crosstab_df = None # initializing to avoid having unset diff --git a/src/gval/utils/loading_datasets.py b/src/gval/utils/loading_datasets.py index 556a8dc0..bb61dae9 100644 --- a/src/gval/utils/loading_datasets.py +++ b/src/gval/utils/loading_datasets.py @@ -7,6 +7,7 @@ from typing import Union, Optional, Tuple, Dict, Any import os +import ast import rioxarray as rxr import xarray as xr @@ -192,8 +193,13 @@ def _handle_xarray_memory( if band_as_var else data_obj.encoding["source"] ) - new_obj = rxr.open_rasterio( - file_name, mask_and_scale=True, band_as_variable=band_as_var, cache=cache + new_obj = _parse_string_attributes( + rxr.open_rasterio( + file_name, + mask_and_scale=True, + band_as_variable=band_as_var, + cache=cache, + ) ) del data_obj return new_obj @@ -207,11 +213,13 @@ def _handle_xarray_memory( data_obj.rio.to_raster(in_file.name, tiled=True, windowed=True) del data_obj cog_translate(in_file.name, out_file.name, dst_profile, in_memory=True) - return rxr.open_rasterio( - out_file.name, - mask_and_scale=True, - band_as_variable=band_as_var, - cache=cache, + return _parse_string_attributes( + rxr.open_rasterio( + out_file.name, + mask_and_scale=True, + band_as_variable=band_as_var, + cache=cache, + ) ) @@ -236,3 +244,49 @@ def _check_dask_array(original_map: Union[xr.DataArray, xr.Dataset]) -> bool: else original_map.chunks ) return chunks is not None + + +def _parse_string_attributes( + obj: Union[xr.DataArray, xr.Dataset] +) -> Union[xr.DataArray, xr.Dataset]: + """ + Parses string attributes stored in rasters + + Parameters + ---------- + obj: Union[xr.DataArray, xr.Dataset] + Xarray object with possible string attributes + + Returns + ------- + Union[xr.DataArray, xr.Dataset] + Object returned with parsed attributes + """ + + if "pairing_dictionary" in obj.attrs and isinstance( + obj.attrs["pairing_dictionary"], str + ): + eval_str = ast.literal_eval( + obj.attrs["pairing_dictionary"].replace("nan", '"nan"') + ) + obj.attrs["pairing_dictionary"] = { + (float(k[0]), float(k[1])): float(v) for k, v in eval_str.items() + } + + if isinstance(obj, xr.Dataset): + for var in obj.data_vars: + obj[var].attrs["pairing_dictionary"] = obj.attrs["pairing_dictionary"] + + if isinstance(obj, xr.Dataset): + for var in obj.data_vars: + if "pairing_dictionary" in obj[var].attrs and isinstance( + obj[var].attrs["pairing_dictionary"], str + ): + eval_str = ast.literal_eval( + obj[var].attrs["pairing_dictionary"].replace("nan", '"nan"') + ) + obj[var].attrs["pairing_dictionary"] = { + (float(k[0]), float(k[1])): float(v) for k, v in eval_str.items() + } + + return obj diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 4577af29..8ffc8820 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -93,6 +93,17 @@ def tmp_path_apply(row): # check that the values are the same pd.testing.assert_frame_equal(agreement_catalog, expected) + def attributes_to_string(obj): # pragma: no cover + """ + Converts attributes to string to mimic a raster loaded from disk + """ + if "pairing_dictionary" in obj.attrs and isinstance( + obj.attrs["pairing_dictionary"], dict + ): + obj.attrs["pairing_dictionary"] = str(obj.attrs["pairing_dictionary"]) + + return obj + # load agreement maps and check metadata if agreement_map_field is not None: # load agreement maps with apply and check that they are the same @@ -105,8 +116,11 @@ def load_agreement_and_check(row, counter=[0]): expected_agreement_map_xr = rxr.open_rasterio( expected_agreement_map[counter[0]], **open_kwargs ) - agreement_map.gval.attributes_to_string() - xr.testing.assert_identical(agreement_map, expected_agreement_map_xr) + + xr.testing.assert_identical( + attributes_to_string(agreement_map), + attributes_to_string(expected_agreement_map_xr), + ) # increment counter counter[0] += 1 diff --git a/tests/test_compare.py b/tests/test_compare.py index 4c034d1b..54a2e8ca 100644 --- a/tests/test_compare.py +++ b/tests/test_compare.py @@ -130,6 +130,7 @@ def test_difference(c, b, expected_value): @parametrize_with_cases("agreement_map, expected_df", glob="crosstab_2d_DataArrays") def test_crosstab_2d_DataArrays(agreement_map, expected_df): """Test crosstabbing agreement DataArrays""" + agreement_map.gval # Initialize gval for attributes crosstab_df = _crosstab_2d_DataArrays(agreement_map) pd.testing.assert_frame_equal(crosstab_df, expected_df, check_dtype=False) @@ -137,6 +138,7 @@ def test_crosstab_2d_DataArrays(agreement_map, expected_df): @parametrize_with_cases("agreement_map, expected_df", glob="crosstab_3d_DataArrays") def test_crosstab_3d_DataArrays(agreement_map, expected_df): """Test crosstabbing agreement DataArrays""" + agreement_map.gval # Initialize gval for attributes crosstab_df = _crosstab_3d_DataArrays(agreement_map) pd.testing.assert_frame_equal(crosstab_df, expected_df, check_dtype=False) @@ -146,6 +148,7 @@ def test_crosstab_3d_DataArrays(agreement_map, expected_df): ) def test_crosstab_DataArrays_success(agreement_map, expected_df): """Test crosstabbing agreement DataArrays""" + agreement_map.gval # Initialize gval for attributes crosstab_df = _crosstab_DataArrays(agreement_map) pd.testing.assert_frame_equal(crosstab_df, expected_df, check_dtype=False) @@ -153,6 +156,7 @@ def test_crosstab_DataArrays_success(agreement_map, expected_df): @parametrize_with_cases("agreement_map", glob="crosstab_DataArrays_fail") def test_crosstab_DataArrays_fail(agreement_map): """Test crosstabbing agreement DataArrays""" + agreement_map.gval # Initialize gval for attributes with raises(ValueError): _crosstab_DataArrays(agreement_map) @@ -160,6 +164,7 @@ def test_crosstab_DataArrays_fail(agreement_map): @parametrize_with_cases("agreement_map, expected_df", glob="crosstab_Datasets") def test_crosstab_Datasets(agreement_map, expected_df): """Test crosstabbing agreement datasets""" + agreement_map.gval # Initialize gval for attributes crosstab_df = _crosstab_Datasets(agreement_map) # takes band_# pattern to just # crosstab_df["band"] = crosstab_df["band"].apply(lambda x: x.split("_")[-1])