Skip to content

Commit

Permalink
Parse string attributes in constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryPetrochenkov-NOAA committed Oct 12, 2023
1 parent 0e11fc1 commit 79a2c91
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
requires-python = ">=3.8"
keywords = ["geospatial", "evaluations"]
license = {text = "MIT"}
version = "0.2.2.post1"
version = "0.2.3"
dynamic = ["readme", "dependencies"]


Expand Down
36 changes: 2 additions & 34 deletions src/gval/accessors/gval_xarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List
from numbers import Number
from functools import partial
import ast

import numpy as np
import numba as nb
Expand All @@ -22,6 +21,7 @@
from gval.comparison.compute_categorical_metrics import _compute_categorical_metrics
from gval.comparison.compute_continuous_metrics import _compute_continuous_metrics
from gval.attributes.attributes import _attribute_tracking_xarray
from gval.utils.loading_datasets import _parse_string_attributes
from gval.utils.schemas import Crosstab_df, Metrics_df, AttributeTrackingDf
from gval.utils.visualize import _map_plot
from gval.comparison.pairing_functions import difference
Expand All @@ -41,7 +41,7 @@ class GVALXarray:
"""

def __init__(self, xarray_obj):
    """
    Stores the xarray object this accessor is attached to

    Parameters
    ----------
    xarray_obj
        Object the accessor wraps; its type is recorded in `data_type`
    """
    # String attributes are parsed once at construction so the stored object
    # is assigned a single time, already in its parsed form
    self._obj = _parse_string_attributes(xarray_obj)
    self.data_type = type(xarray_obj)
    self.agreement_map_format = "raster"

Expand Down Expand Up @@ -97,38 +97,6 @@ def __handle_attribute_tracking(

return results

def parse_string_attributes(self):
    """
    Parses string attributes stored in rasters

    A `pairing_dictionary` attribute held as a string is converted in place
    into a dictionary with (float, float) tuple keys and float values. When
    the wrapped object is a Dataset the parsed dictionary is also assigned to
    every data variable. Nothing happens if the attribute is missing or is
    not a string.
    """

    attrs = self._obj.attrs
    raw_pairing = attrs.get("pairing_dictionary")
    if not isinstance(raw_pairing, str):
        return

    # Bare nan tokens are quoted so literal_eval accepts them; float("nan")
    # below turns them back into NaN values
    parsed_pairing = ast.literal_eval(raw_pairing.replace("nan", '"nan"'))
    attrs["pairing_dictionary"] = {
        (float(k[0]), float(k[1])): float(v) for k, v in parsed_pairing.items()
    }

    if isinstance(self._obj, xr.Dataset):
        for var in self._obj.data_vars:
            self._obj[var].attrs["pairing_dictionary"] = attrs["pairing_dictionary"]

def attributes_to_string(self):  # pragma: no cover
    """
    Converts attributes to string to mimic a raster loaded from disk

    Only a `pairing_dictionary` attribute that is currently a dictionary is
    replaced, in place, by its `str()` representation.
    """
    attrs = self._obj.attrs
    pairing = attrs.get("pairing_dictionary")
    if isinstance(pairing, dict):
        attrs["pairing_dictionary"] = str(pairing)

@Comparison.comparison_function_from_string
def categorical_compare(
self,
Expand Down
2 changes: 0 additions & 2 deletions src/gval/comparison/tabulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def not_nan(number):
return not np.isnan(number)

# Handle pairing dictionary attribute
agreement_map.gval.parse_string_attributes()
pairing_dict = agreement_map.attrs["pairing_dictionary"]

rev_dict = {}
Expand Down Expand Up @@ -240,7 +239,6 @@ def _crosstab_Datasets(agreement_map: xr.DataArray) -> DataFrame[Crosstab_df]:

# gets variable names
agreement_variable_names = list(agreement_map.data_vars)
agreement_map.gval.parse_string_attributes()

# loop variables
previous_crosstab_df = None # initializing to avoid having unset
Expand Down
68 changes: 61 additions & 7 deletions src/gval/utils/loading_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from typing import Union, Optional, Tuple, Dict, Any
import os
import ast

import rioxarray as rxr
import xarray as xr
Expand Down Expand Up @@ -192,8 +193,13 @@ def _handle_xarray_memory(
if band_as_var
else data_obj.encoding["source"]
)
new_obj = rxr.open_rasterio(
file_name, mask_and_scale=True, band_as_variable=band_as_var, cache=cache
new_obj = _parse_string_attributes(
rxr.open_rasterio(
file_name,
mask_and_scale=True,
band_as_variable=band_as_var,
cache=cache,
)
)
del data_obj
return new_obj
Expand All @@ -207,11 +213,13 @@ def _handle_xarray_memory(
data_obj.rio.to_raster(in_file.name, tiled=True, windowed=True)
del data_obj
cog_translate(in_file.name, out_file.name, dst_profile, in_memory=True)
return rxr.open_rasterio(
out_file.name,
mask_and_scale=True,
band_as_variable=band_as_var,
cache=cache,
return _parse_string_attributes(
rxr.open_rasterio(
out_file.name,
mask_and_scale=True,
band_as_variable=band_as_var,
cache=cache,
)
)


Expand All @@ -236,3 +244,49 @@ def _check_dask_array(original_map: Union[xr.DataArray, xr.Dataset]) -> bool:
else original_map.chunks
)
return chunks is not None


def _pairing_dictionary_from_string(pairing_string: str) -> dict:
    """
    Converts the string form of a pairing dictionary into a dictionary

    Parameters
    ----------
    pairing_string: str
        String representation of a pairing dictionary

    Returns
    -------
    dict
        Dictionary with (float, float) tuple keys and float values
    """

    # Bare nan tokens are quoted so literal_eval accepts them; float("nan")
    # below turns them back into NaN values
    parsed = ast.literal_eval(pairing_string.replace("nan", '"nan"'))
    return {(float(k[0]), float(k[1])): float(v) for k, v in parsed.items()}


def _parse_string_attributes(
    obj: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Parses string attributes stored in rasters

    The object is modified in place and also returned.

    Parameters
    ----------
    obj: Union[xr.DataArray, xr.Dataset]
        Xarray object with possible string attributes

    Returns
    -------
    Union[xr.DataArray, xr.Dataset]
        Object returned with parsed attributes
    """

    is_dataset = isinstance(obj, xr.Dataset)

    pairing = obj.attrs.get("pairing_dictionary")
    if isinstance(pairing, str):
        obj.attrs["pairing_dictionary"] = _pairing_dictionary_from_string(pairing)

        # A Dataset-level dictionary is shared with every data variable
        if is_dataset:
            for var in obj.data_vars:
                obj[var].attrs["pairing_dictionary"] = obj.attrs["pairing_dictionary"]

    # Variables may still carry their own string-valued dictionaries
    if is_dataset:
        for var in obj.data_vars:
            var_pairing = obj[var].attrs.get("pairing_dictionary")
            if isinstance(var_pairing, str):
                obj[var].attrs["pairing_dictionary"] = _pairing_dictionary_from_string(
                    var_pairing
                )

    return obj
18 changes: 16 additions & 2 deletions tests/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ def tmp_path_apply(row):
# check that the values are the same
pd.testing.assert_frame_equal(agreement_catalog, expected)

def attributes_to_string(obj):  # pragma: no cover
    """
    Converts attributes to string to mimic a raster loaded from disk

    A `pairing_dictionary` attribute that is a dictionary is replaced in place
    by its `str()` representation; the same object is returned either way.
    """
    pairing = obj.attrs.get("pairing_dictionary")
    if isinstance(pairing, dict):
        obj.attrs["pairing_dictionary"] = str(pairing)

    return obj

# load agreement maps and check metadata
if agreement_map_field is not None:
# load agreement maps with apply and check that they are the same
Expand All @@ -105,8 +116,11 @@ def load_agreement_and_check(row, counter=[0]):
expected_agreement_map_xr = rxr.open_rasterio(
expected_agreement_map[counter[0]], **open_kwargs
)
agreement_map.gval.attributes_to_string()
xr.testing.assert_identical(agreement_map, expected_agreement_map_xr)

xr.testing.assert_identical(
attributes_to_string(agreement_map),
attributes_to_string(expected_agreement_map_xr),
)

# increment counter
counter[0] += 1
Expand Down
5 changes: 5 additions & 0 deletions tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ def test_difference(c, b, expected_value):
@parametrize_with_cases("agreement_map, expected_df", glob="crosstab_2d_DataArrays")
def test_crosstab_2d_DataArrays(agreement_map, expected_df):
    """Crosstab a 2D agreement DataArray and compare with the expected frame"""
    # Touch the accessor first so gval is initialized for attributes
    _ = agreement_map.gval
    pd.testing.assert_frame_equal(
        _crosstab_2d_DataArrays(agreement_map), expected_df, check_dtype=False
    )


@parametrize_with_cases("agreement_map, expected_df", glob="crosstab_3d_DataArrays")
def test_crosstab_3d_DataArrays(agreement_map, expected_df):
    """Crosstab a 3D agreement DataArray and compare with the expected frame"""
    # Touch the accessor first so gval is initialized for attributes
    _ = agreement_map.gval
    pd.testing.assert_frame_equal(
        _crosstab_3d_DataArrays(agreement_map), expected_df, check_dtype=False
    )

Expand All @@ -146,20 +148,23 @@ def test_crosstab_3d_DataArrays(agreement_map, expected_df):
)
def test_crosstab_DataArrays_success(agreement_map, expected_df):
    """Crosstab an agreement DataArray and compare with the expected frame"""
    # Touch the accessor first so gval is initialized for attributes
    _ = agreement_map.gval
    pd.testing.assert_frame_equal(
        _crosstab_DataArrays(agreement_map), expected_df, check_dtype=False
    )


@parametrize_with_cases("agreement_map", glob="crosstab_DataArrays_fail")
def test_crosstab_DataArrays_fail(agreement_map):
    """Crosstabbing these agreement DataArrays must raise a ValueError"""
    # Touch the accessor outside the raises block so gval is initialized for
    # attributes before the call under test
    _ = agreement_map.gval
    with raises(ValueError):
        _crosstab_DataArrays(agreement_map)


@parametrize_with_cases("agreement_map, expected_df", glob="crosstab_Datasets")
def test_crosstab_Datasets(agreement_map, expected_df):
"""Test crosstabbing agreement datasets"""
agreement_map.gval # Initialize gval for attributes
crosstab_df = _crosstab_Datasets(agreement_map)
# takes band_# pattern to just #
crosstab_df["band"] = crosstab_df["band"].apply(lambda x: x.split("_")[-1])
Expand Down

0 comments on commit 79a2c91

Please sign in to comment.