
Commit

Merge pull request #169 from NOAA-OWP/prob_compare
Prob compare
GregoryPetrochenkov-NOAA committed Dec 15, 2023
2 parents c5964b6 + 3c09fcf commit 02db25c
Showing 23 changed files with 1,975 additions and 217 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -36,7 +36,8 @@ dev = ["pytest==7.2.2",
"sqlalchemy==2.0.17",
"colorama==0.4.6",
"build==0.10.0",
"twine==4.0.2"
"twine==4.0.2",
"deepdiff==6.7.1"
]

[project.urls]
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -11,6 +11,7 @@ rio-cogeo==4.0.0
matplotlib==3.7.1
contextily==1.3.0
flox==0.7.2
xskillscore==0.0.24


###################################
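
The xskillscore pin added above supplies the probabilistic metric implementations that the accessor changes further down delegate to. For orientation only, a minimal standalone call with randomly generated data, using xskillscore's crps_ensemble and its member_dim keyword:

import numpy as np
import xarray as xr
import xskillscore as xs

# Hypothetical observations (y, x) and a 5-member forecast ensemble (member, y, x)
obs = xr.DataArray(np.random.rand(10, 10), dims=("y", "x"))
fcst = xr.DataArray(np.random.rand(5, 10, 10), dims=("member", "y", "x"))

# CRPS is averaged over all remaining dimensions by default, yielding a scalar DataArray
crps = xs.crps_ensemble(obs, fcst, member_dim="member")
print(float(crps))
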
3 changes: 2 additions & 1 deletion setup.cfg
@@ -2,7 +2,8 @@
max-line-length = 100
extend-ignore = E203, E501
ban-relative-imports = parents
per-file-ignores = __init__.py:F401,F403,E402
per-file-ignores =
__init__.py:F401,F403,E402
;exclude = *.py

[black]
2 changes: 1 addition & 1 deletion src/gval/__init__.py
@@ -7,4 +7,4 @@
ContStats = ContinuousStatistics()

from gval.accessors import gval_array, gval_dataset, gval_dataframe
from gval.catalogs.catalogs import catalog_compare
from gval.catalogs.catalogs import catalog_compare
4 changes: 2 additions & 2 deletions src/gval/accessors/gval_dataframe.py
@@ -184,10 +184,10 @@ def create_subsampling_df(
if "geometry" not in geo_df.columns:
geo_df.set_geometry(
geometries if geometries is not None else self._obj["geometry"],
inplace=True
inplace=True,
)
geo_df.crs = crs

if subsampling_type:
geo_df["subsample_type"] = subsampling_type

114 changes: 108 additions & 6 deletions src/gval/accessors/gval_xarray.py
@@ -1,4 +1,8 @@
from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List
"""
Defines gval accessor for xarray objects.
"""

from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List, Any
from numbers import Number

import numpy as np
@@ -19,17 +23,23 @@
from gval.comparison.tabulation import _crosstab_Datasets, _crosstab_DataArrays
from gval.comparison.compute_categorical_metrics import _compute_categorical_metrics
from gval.comparison.compute_continuous_metrics import _compute_continuous_metrics
from gval.comparison.compute_probabilistic_metrics import _compute_probabilistic_metrics
from gval.attributes.attributes import _attribute_tracking_xarray
from gval.utils.loading_datasets import _parse_string_attributes
from gval.utils.schemas import Crosstab_df, Metrics_df, AttributeTrackingDf
from gval.utils.visualize import _map_plot
from gval.comparison.pairing_functions import difference
from gval.subsampling.subsampling import subsample
from gval.utils.schemas import (
Crosstab_df,
Metrics_df,
AttributeTrackingDf,
Prob_metrics_df,
)


class GVALXarray:
"""
Class for extending xarray functionality
Class for extending xarray functionality.
Attributes
----------
@@ -385,11 +395,103 @@ def continuous_compare(
else:
attributes_df = results

del candidate_map, benchmark_map

return agreement_map, metrics_df, attributes_df

del candidate_map, benchmark_map
return agreement_map, metrics_df

def probabilistic_compare(
self,
benchmark_map: Union[gpd.GeoDataFrame, xr.Dataset, xr.DataArray],
metric_kwargs: dict,
return_on_error: Optional[Any] = None,
target_map: Optional[Union[xr.Dataset, str]] = "benchmark",
resampling: Optional[Resampling] = Resampling.nearest,
rasterize_attributes: Optional[list] = None,
attribute_tracking: bool = False,
attribute_tracking_kwargs: Optional[Dict] = None,
) -> DataFrame[Prob_metrics_df]:
"""
Computes probabilistic metrics from candidate and benchmark maps.
Parameters
----------
benchmark_map : xr.DataArray or xr.Dataset
Benchmark map.
metric_kwargs : dict
Dictionary of keyword arguments to metric functions. Keys must be metrics. Values are keyword arguments to metric functions. Don't pass keys or values for 'observations' or 'forecasts' as these are handled internally with `benchmark_map` and `candidate_map`, respectively. Available keyword arguments by metric are available in `DEFAULT_METRIC_KWARGS`. If values are None or empty dictionary, default values in `DEFAULT_METRIC_KWARGS` are used.
return_on_error : Optional[Any], default = None
Value to return within metrics dataframe if an error occurs when computing a metric. If None, the metric is not computed and None is returned. If 'error', the raised error is returned.
target_map: Optional[xr.Dataset or str], default = "benchmark"
xarray object to match the CRS's and coordinates of candidates and benchmarks to or str with 'candidate' or 'benchmark' as accepted values.
resampling : rasterio.enums.Resampling
See :func:`rasterio.warp.reproject` for more details.
nodata : Optional[Number], default = None
No data value to write to agreement map output. This will use `rxr.rio.write_nodata(nodata)`.
encode_nodata : Optional[bool], default = False
Encoded no data value to write to agreement map output. A nodata argument must be passed. This will use `rxr.rio.write_nodata(nodata, encode=encode_nodata)`.
rasterize_attributes: Optional[list], default = None
Numerical attributes of a GeoDataFrame to rasterize.
attribute_tracking: bool, default = False
Whether to return a dataframe with the attributes of the candidate and benchmark maps.
attribute_tracking_kwargs: Optional[Dict], default = None
Keyword arguments to pass to `gval.attribute_tracking()`. This is only used if `attribute_tracking` is True. By default, agreement maps are used for attribute tracking but this can be set to None within this argument to override. See `gval.attribute_tracking` for more information.
Returns
-------
DataFrame[Prob_metrics_df]
Probabilistic metrics Pandas DataFrame with computed xarray's per metric and sample.
Raises
------
ValueError
If keyword argument is required for metric but not passed.
If keyword argument is not available for metric but passed.
If metric is not available.
Warns
-----
UserWarning
Warns if a metric cannot be computed. `return_on_error` determines whether the metric is not computed and None is returned or if the raised error is returned.
References
----------
.. [1] `Scoring rule <https://en.wikipedia.org/wiki/Scoring_rule#Examples_of_proper_scoring_rules>`_
.. [2] `Strictly Proper Scoring Rules, Prediction, and Estimation <https://sites.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf>`_
.. [3] `Properscoring/properscoring <https://github.com/properscoring/properscoring>`_
.. [4] `xskillscore/xskillscore <https://xskillscore.readthedocs.io/en/stable/api.html#probabilistic-metrics>`_
.. [5] `What is a proper scoring rule? <https://statisticaloddsandends.wordpress.com/2021/03/27/what-is-a-proper-scoring-rule/>`_
"""

# agreement_map temporarily None
agreement_map = None

# using homogenize accessor to avoid code reuse
candidate, benchmark = self._obj.gval.homogenize(
benchmark_map, target_map, resampling, rasterize_attributes
)

metrics_df = _compute_probabilistic_metrics(
candidate_map=candidate,
benchmark_map=benchmark,
metric_kwargs=metric_kwargs,
return_on_error=return_on_error,
)

if attribute_tracking:
results = self.__handle_attribute_tracking(
candidate_map=candidate,
benchmark_map=benchmark,
agreement_map=agreement_map,
attribute_tracking_kwargs=attribute_tracking_kwargs,
)

if len(results) == 2:
# attributes_df, agreement_map = results
attributes_df, _ = results
else:
attributes_df = results

return agreement_map, metrics_df, attributes_df

return agreement_map, metrics_df

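A minimal usage sketch of the new accessor method, for orientation only. The file paths are placeholders, and the choice of "crps_ensemble" with a "member_dim" keyword is an assumption based on xskillscore's API rather than something taken from this commit's tests:

import rioxarray as rxr
import gval  # noqa: F401 -- registers the .gval accessor on xarray objects

# Hypothetical inputs: a multi-band raster treated as ensemble members and a
# single-band benchmark of observations
candidate = rxr.open_rasterio("candidate_ensemble.tif", masked=True)
benchmark = rxr.open_rasterio("benchmark_obs.tif", masked=True).squeeze("band", drop=True)

# Keys of metric_kwargs are metric names; values are keyword arguments forwarded
# to the metric function. 'observations' and 'forecasts' are supplied internally
# from benchmark_map and candidate_map, so they are not passed here.
agreement_map, metrics_df = candidate.gval.probabilistic_compare(
    benchmark_map=benchmark,
    metric_kwargs={"crps_ensemble": {"member_dim": "band"}},
)

# Per this commit, agreement_map is returned as None for probabilistic comparisons.
print(metrics_df)  # one row per metric with the computed result
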
2 changes: 1 addition & 1 deletion src/gval/comparison/compute_continuous_metrics.py
@@ -112,7 +112,7 @@ def _compute_continuous_metrics(
subsampling_df: gpd.GeoDataFrame = None,
) -> DataFrame[Metrics_df]:
"""
Computes categorical metrics from a crosstab df.
Computes continuous metrics.
Parameters
----------
(Diffs for the remaining 16 changed files were not loaded.)
