From f5334fef2c61117e7bc9e7c21698471cfaf68ac2 Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Thu, 22 Jun 2023 18:37:09 -0400 Subject: [PATCH 1/5] Continuous Comparison and Tech Debt PT1 --- CONTRIBUTING.MD | 7 +- docs/sphinx/continuous_stat_funcs.rst | 7 + docs/sphinx/continuous_statistics.rst | 7 + docs/sphinx/schemas.rst | 1 + docs/sphinx/statistics.rst | 2 + src/gval/__init__.py | 2 + src/gval/accessors/gval_xarray.py | 17 +- src/gval/comparison/compute_comparison.py | 7 +- .../comparison/compute_continuous_metrics.py | 73 +---- src/gval/homogenize/rasterize.py | 10 +- src/gval/statistics/categorical_stat_funcs.py | 51 ++- src/gval/statistics/categorical_statistics.py | 6 - src/gval/statistics/continuous_stat_funcs.py | 153 +++------ src/gval/statistics/continuous_statistics.py | 292 ++++++++++++++++++ tests/cases_compute_continuous_metrics.py | 161 +++++++++- tests/cases_continuous_metrics.py | 8 +- tests/test_compute_continuous_metrics.py | 96 +++++- tests/test_continuous_metrics.py | 26 +- 18 files changed, 696 insertions(+), 230 deletions(-) create mode 100644 docs/sphinx/continuous_stat_funcs.rst create mode 100644 docs/sphinx/continuous_statistics.rst create mode 100644 src/gval/statistics/continuous_statistics.py diff --git a/CONTRIBUTING.MD b/CONTRIBUTING.MD index efa12306..daa2b374 100644 --- a/CONTRIBUTING.MD +++ b/CONTRIBUTING.MD @@ -55,10 +55,11 @@ cpu time performance loss. To run everything (in root project directory): up of other docs located in Markdown. Edit those directly if need be before running the preceding command.) To build sphinx documentation locally, change to the docs/sphinx folder and run `make clean && make html`. The html will be created in the _build/html folder. Open index.html in a browser to preview docs. -8. Commit your changes: `git commit -m 'feature message'` This will invoke pre-commit hooks mentioned on step 5 +9. Commit your changes: `git commit -m 'feature message'` This will invoke pre-commit hooks mentioned on step 5 that will lint the code. Make sure all of these checks pass, if not make changes and re-commit. -9. Push to the branch: `git push origin ` -10. Open a pull request (review checklist in PR template before requesting a review) +10. Push to the branch: `git push -u origin`, or if the branch is not pushed up yet: +`git push --set-upstream origin ` +11. Open a pull request (review checklist in PR template before requesting a review) ## Standards diff --git a/docs/sphinx/continuous_stat_funcs.rst b/docs/sphinx/continuous_stat_funcs.rst new file mode 100644 index 00000000..90d5e05d --- /dev/null +++ b/docs/sphinx/continuous_stat_funcs.rst @@ -0,0 +1,7 @@ +Continuous Statistics Functions +############################### + +:doc:`Return to Homepage <../index>` + +.. automodule:: gval.statistics.continuous_stat_funcs + :members: diff --git a/docs/sphinx/continuous_statistics.rst b/docs/sphinx/continuous_statistics.rst new file mode 100644 index 00000000..0c317a37 --- /dev/null +++ b/docs/sphinx/continuous_statistics.rst @@ -0,0 +1,7 @@ +Continuous Statistics +##################### + +:doc:`Return to Homepage <../index>` + +.. automodule:: gval.statistics.continuous_statistics + :members: diff --git a/docs/sphinx/schemas.rst b/docs/sphinx/schemas.rst index ed2efe39..987a6c47 100644 --- a/docs/sphinx/schemas.rst +++ b/docs/sphinx/schemas.rst @@ -5,3 +5,4 @@ Schemas .. 
automodule:: gval.utils.schemas :members: + :undoc-members: diff --git a/docs/sphinx/statistics.rst b/docs/sphinx/statistics.rst index e0560118..c6e06d06 100644 --- a/docs/sphinx/statistics.rst +++ b/docs/sphinx/statistics.rst @@ -9,3 +9,5 @@ Statistics categorical_stat_funcs categorical_statistics + continuous_stat_funcs + continuous_statistics diff --git a/src/gval/__init__.py b/src/gval/__init__.py index 1b3c5393..af880196 100644 --- a/src/gval/__init__.py +++ b/src/gval/__init__.py @@ -1,7 +1,9 @@ from gval.comparison.compute_comparison import ComparisonProcessing from gval.statistics.categorical_statistics import CategoricalStatistics +from gval.statistics.continuous_statistics import ContinuousStatistics Comparison = ComparisonProcessing() CatStats = CategoricalStatistics() +ContStats = ContinuousStatistics() from gval.accessors import gval_array, gval_dataset, gval_dataframe diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index a1d3583f..25976467 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -96,8 +96,9 @@ def categorical_compare( Number or list of numbers representing the values to consider as the positive condition. For average types "macro" and "weighted", this represents the categories to compute metrics for. comparison_function : Union[Callable, nb.np.ufunc.dufunc.DUFunc, np.ufunc, np.vectorize, str], default = 'szudzik' Comparison function. Created by decorating function with @nb.vectorize() or using np.ufunc(). Use of numba is preferred as it is faster. Strings with registered comparison_functions are also accepted. Possible options include "pairing_dict". If passing "pairing_dict" value, please see the description for the argument for more information on behaviour. + All available comparison functions can be found with gval.Comparison.available_functions(). metrics: Union[str, Iterable[str]], default = "all" - Statistics to return in metric table. + Statistics to return in metric table. All returns every default and registered metric. This can be seen with gval.CatStats.available_functions(). target_map: Optional[Union[xr.Dataset, str]], default = "benchmark" xarray object to match the CRS's and coordinates of candidates and benchmarks to or str with 'candidate' or 'benchmark' as accepted values. resampling : rasterio.enums.Resampling @@ -109,9 +110,9 @@ def categorical_compare( A pairing dictionary can be used by the user to note which values to allow and which to ignore for comparisons. It can also be used to decide how nans are handled for cases where either the candidate and benchmark maps have nans or both. allow_candidate_values : Optional[Iterable[Union[int,float]]], default = None - List of values in candidate to include in computation of agreement map. Remaining values are excluded. If "pairing_dict" is set selected for comparison_function and pairing_function is None, this argument is necessary to construct the dictionary. Otherwise, this argument is optional and by default this value is set to None and all values are considered. + List of values in candidate to include in computation of agreement map. Remaining values are excluded. If "pairing_dict" is provided for comparison_function and pairing_function is None, this argument is necessary to construct the dictionary. Otherwise, this argument is optional and by default this value is set to None and all values are considered. 
allow_benchmark_values : Optional[Iterable[Union[int,float]]], default = None - List of values in benchmark to include in computation of agreement map. Remaining values are excluded. If "pairing_dict" is set selected for comparison_function and pairing_function is None, this argument is necessary to construct the dictionary. Otherwise, this argument is optional and by default this value is set to None and all values are considered. + List of values in benchmark to include in computation of agreement map. Remaining values are excluded. If "pairing_dict" is provided for comparison_function and pairing_function is None, this argument is necessary to construct the dictionary. Otherwise, this argument is optional and by default this value is set to None and all values are considered. nodata : Optional[Number], default = None No data value to write to agreement map output. This will use `rxr.rio.write_nodata(nodata)`. encode_nodata : Optional[bool], default = False @@ -132,7 +133,8 @@ def categorical_compare( `positive_categories = [1, 2]; weights = [0.25, 0.75]` rasterize_attributes: Optional[list], default = None - Numerical attributes of a GeoDataFrame to rasterize + Numerical attributes of a Benchmark Map GeoDataFrame to rasterize. Only applicable if benchmark map is a vector file. + This cannot be none if the benchmark map is a vector file. Returns ------- @@ -204,7 +206,7 @@ def continuous_compare( benchmark_map : Union[gpd.GeoDataFrame, xr.DataArray, xr.Dataset] Benchmark map. metrics: Union[str, Iterable[str]], default = "all" - Statistics to return in metric table. + Statistics to return in metric table. This can be seen with gval.ContStats.available_functions(). target_map: Optional[Union[xr.Dataset, str]], default = "benchmark" xarray object to match the CRS's and coordinates of candidates and benchmarks to or str with 'candidate' or 'benchmark' as accepted values. 
resampling : rasterio.enums.Resampling @@ -235,7 +237,10 @@ def continuous_compare( ) metrics_df = _compute_continuous_metrics( - candidate_map=candidate, benchmark_map=benchmark, metrics=metrics + agreement_map=agreement_map, + candidate_map=candidate, + benchmark_map=benchmark, + metrics=metrics, ) del candidate, benchmark diff --git a/src/gval/comparison/compute_comparison.py b/src/gval/comparison/compute_comparison.py index 69d34365..04a1f7c5 100644 --- a/src/gval/comparison/compute_comparison.py +++ b/src/gval/comparison/compute_comparison.py @@ -17,12 +17,7 @@ class ComparisonProcessing: Attributes ---------- - _func_names : list (private) - Names of all functions from default categorical statistics class - _funcs : list (private) - List of all functions from default categorical statistics class - _signature_validation : dict (private) - Dictionary to validate all registered functions + registered_functions : dict Available statistical functions with names as keys and parameters as values """ diff --git a/src/gval/comparison/compute_continuous_metrics.py b/src/gval/comparison/compute_continuous_metrics.py index b27998ca..670d6b68 100644 --- a/src/gval/comparison/compute_continuous_metrics.py +++ b/src/gval/comparison/compute_continuous_metrics.py @@ -12,27 +12,15 @@ from pandera.typing import DataFrame import xarray as xr +from gval import ContStats from gval.utils.schemas import Metrics_df -from gval.statistics.continuous_stat_funcs import ( - mean_absolute_error, - mean_squared_error, - root_mean_squared_error, - mean_percentage_error, - mean_absolute_percentage_error, - coefficient_of_determination, - mean_normalized_mean_absolute_error, - range_normalized_mean_absolute_error, - mean_normalized_root_mean_squared_error, - range_normalized_root_mean_squared_error, -) - @pa.check_types def _compute_continuous_metrics( - agreement_map: Union[xr.DataArray, xr.Dataset] = None, - candidate_map: Union[xr.DataArray, xr.Dataset] = None, - benchmark_map: Union[xr.DataArray, xr.Dataset] = None, + agreement_map: Union[xr.DataArray, xr.Dataset], + candidate_map: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], metrics: Union[str, Iterable[str]] = "all", ) -> DataFrame[Metrics_df]: """ @@ -40,11 +28,11 @@ def _compute_continuous_metrics( Parameters ---------- - agreement_map : Union[xr.DataArray, xr.Dataset], default = None + agreement_map : Union[xr.DataArray, xr.Dataset] Agreement map, error based (candidate - benchmark). - candidate_map : Union[xr.DataArray, xr.Dataset], default = None + candidate_map : Union[xr.DataArray, xr.Dataset] Candidate map. - benchmark_map : Union[xr.DataArray, xr.Dataset], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. metrics : Union[str, Iterable[str]], default = "all" String or list of strings representing metrics to compute. @@ -66,48 +54,21 @@ def _compute_continuous_metrics( """ # compute error based metrics such as MAE, MSE, RMSE, etc. 
from agreement map and produce metrics_df
-
-    all_metric_funcs = [
-        mean_absolute_error,
-        mean_squared_error,
-        root_mean_squared_error,
-        mean_percentage_error,
-        mean_absolute_percentage_error,
-        coefficient_of_determination,
-        mean_normalized_mean_absolute_error,
-        range_normalized_mean_absolute_error,
-        mean_normalized_root_mean_squared_error,
-        range_normalized_root_mean_squared_error,
-    ]
-
-    all_metric_names = [
-        "mean_absolute_error",
-        "mean_squared_error",
-        "root_mean_squared_error",
-        "mean_percentage_error",
-        "mean_absolute_percentage_error",
-        "coefficient_of_determination",
-        "normalized_mean_absolute_error",
-        "normalized_root_mean_squared_error",
-    ]
-
-    # create dictionary of metric names and functions
-    all_metric_dict = dict(zip(all_metric_names, all_metric_funcs))
-
-    if metrics == "all":
-        metric_dict = all_metric_dict
-    else:
-        metrics_set = set(metrics)
-        metric_dict = {k: all_metric_dict[k] for k in metrics if k in metrics_set}
+    statistics, names = ContStats.process_statistics(
+        metrics,
+        error=agreement_map,
+        candidate_map=candidate_map,
+        benchmark_map=benchmark_map,
+    )
 
     # create metrics_df
     metric_df = dict()
-    for name, func in metric_dict.items():
-        metric_df[name] = func(agreement_map, candidate_map, benchmark_map)
+    for name, stat in zip(names, statistics):
+        metric_df[name] = stat
 
     def is_nested_dict(d):
-        if not isinstance(d, dict):
-            return False
         return any(isinstance(v, dict) for v in d.values())
 
     if is_nested_dict(metric_df):
diff --git a/src/gval/homogenize/rasterize.py b/src/gval/homogenize/rasterize.py
index b662cf54..c71abd39 100644
--- a/src/gval/homogenize/rasterize.py
+++ b/src/gval/homogenize/rasterize.py
@@ -41,8 +41,11 @@ def _rasterize_data(
     """
 
+    if rasterize_attributes is None or len(rasterize_attributes) == 0:
+        raise KeyError("Rasterize attributes cannot be None or empty")
+
     for attr in rasterize_attributes:
         if not issubclass(type(benchmark_map[attr][0]), Number):
             raise KeyError("Rasterize attribute needs to be of numeric type")
 
     rasterized_data = make_geocube(
         vector_data=benchmark_map,
diff --git a/src/gval/statistics/categorical_stat_funcs.py b/src/gval/statistics/categorical_stat_funcs.py
index 15c33523..f6f68a1c 100644
--- a/src/gval/statistics/categorical_stat_funcs.py
+++ b/src/gval/statistics/categorical_stat_funcs.py
@@ -19,7 +19,8 @@ def true_positive_rate(tp: Number, fn: Number) -> float:
 
     Returns
     -------
-    True positive rate from 0 to 1
+    float
+        True positive rate from 0 to 1
 
     References
     ----------
@@ -41,7 +42,8 @@ def true_negative_rate(tn: Number, fp: Number) -> float:
 
     Returns
     -------
-    True negative rate from 0 to 1
+    float
+        True negative rate from 0 to 1
 
     References
     ----------
@@ -63,7 +65,8 @@ def positive_predictive_value(tp: Number, fp: Number) -> float:
 
     Returns
     -------
-    Positive predictive value from 0 to 1
+    float
+        Positive predictive value from 0 to 1
 
     References
     ----------
@@ -85,7 +88,8 @@ def negative_predictive_value(tn: Number, fn: Number) -> float:
 
     Returns
     -------
-    Negative predictive value from 0 to 1
+    float
+        Negative predictive value from 0 to 1
 
     References
     ----------
@@ -107,7 +111,8 @@ def false_negative_rate(tp: Number, fn: Number) -> float:
 
     Returns
     -------
-    False negative rate from 0 to 1
+    float
+        False negative rate from 0 to 1
 
     References
     ----------
@@ -129,7 +134,8 @@ def false_positive_rate(tn: Number, fp: Number) -> float:
 
     Returns
     -------
-    False
positive rate from 0 to 1 + float + False positive rate from 0 to 1 References ---------- @@ -151,7 +157,8 @@ def false_discovery_rate(tp: Number, fp: Number) -> float: Returns ------- - False discovery rate from 0 to 1 + float + False discovery rate from 0 to 1 References ---------- @@ -173,7 +180,8 @@ def false_omission_rate(tn: Number, fn: Number) -> float: Returns ------- - False omission rate from 0 to 1 + float + False omission rate from 0 to 1 References ---------- @@ -199,7 +207,8 @@ def positive_likelihood_ratio(tp: Number, tn: Number, fp: Number, fn: Number) -> Returns ------- - Positive likelihood rate from 1 to infinity + float + Positive likelihood rate from 1 to infinity References ---------- @@ -225,7 +234,8 @@ def negative_likelihood_ratio(tp: Number, tn: Number, fp: Number, fn: Number) -> Returns ------- - Negative likelihood from 1 to infinity + float + Negative likelihood from 1 to infinity References ---------- @@ -251,7 +261,8 @@ def prevalence_threshold(tp: Number, tn: Number, fp: Number, fn: Number) -> floa Returns ------- - Prevalence threshold from 0 to 1 + float + Prevalence threshold from 0 to 1 References ---------- @@ -279,7 +290,8 @@ def critical_success_index(tp: Number, fp: Number, fn: Number) -> float: Returns ------- - Critical success index from 0 to 1 + float + Critical success index from 0 to 1 References ---------- @@ -305,7 +317,8 @@ def prevalence(tp: Number, tn: Number, fp: Number, fn: Number) -> float: Returns ------- - Prevalence from 0 to 1 + float + Prevalence from 0 to 1 References ---------- @@ -331,7 +344,8 @@ def accuracy(tp: Number, tn: Number, fp: Number, fn: Number) -> float: Returns ------- - Accuracy from 0 to 1 + float + Accuracy from 0 to 1 References ---------- @@ -355,7 +369,8 @@ def f_score(tp: Number, fp: Number, fn: Number) -> float: Returns ------- - F-score from 0 to 1 + float + F-score from 0 to 1 References ---------- @@ -384,7 +399,8 @@ def matthews_correlation_coefficient( Returns ------- - Correlation coefficient from -1 to 1 + float + Correlation coefficient from -1 to 1 References ---------- @@ -410,7 +426,8 @@ def fowlkes_mallows_index(tp: Number, fp: Number, fn: Number) -> float: Returns ------- - Correlation coefficient from -1 to 1 + float + Correlation coefficient from -1 to 1 References ---------- diff --git a/src/gval/statistics/categorical_statistics.py b/src/gval/statistics/categorical_statistics.py index 8b5a9a88..0d95efa0 100644 --- a/src/gval/statistics/categorical_statistics.py +++ b/src/gval/statistics/categorical_statistics.py @@ -19,12 +19,6 @@ class CategoricalStatistics(BaseStatistics): Attributes ---------- - _func_names : list (private) - Names of all functions from default categorical statistics class - _funcs : list (private) - List of all functions from default categorical statistics class - _signature_validation : dict (private) - Dictionary to validate all registered functions registered_functions : dict Available statistical functions with names as keys and parameters as values """ diff --git a/src/gval/statistics/continuous_stat_funcs.py b/src/gval/statistics/continuous_stat_funcs.py index 5ae9ae7f..c6a262b7 100644 --- a/src/gval/statistics/continuous_stat_funcs.py +++ b/src/gval/statistics/continuous_stat_funcs.py @@ -3,21 +3,16 @@ """ from numbers import Number -from typing import Union, Optional +from typing import Union import xarray as xr import numpy as np -from gval.statistics.continuous_stat_utils import convert_output, compute_error_if_none +from gval.statistics.continuous_stat_utils 
import convert_output -@compute_error_if_none @convert_output -def mean_absolute_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, -) -> Number: +def mean_absolute_error(error: Union[xr.DataArray, xr.Dataset]) -> Number: """ Compute mean absolute error (MAE). @@ -25,12 +20,8 @@ def mean_absolute_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Benchmark map. Returns ------- @@ -44,13 +35,8 @@ def mean_absolute_error( return np.abs(error).mean() -@compute_error_if_none @convert_output -def mean_squared_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, -) -> Number: +def mean_squared_error(error: Union[xr.DataArray, xr.Dataset]) -> Number: """ Compute mean squared error (MSE). @@ -58,12 +44,8 @@ def mean_squared_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Benchmark map. Returns ------- @@ -78,13 +60,8 @@ def mean_squared_error( return (error**2).mean() -@compute_error_if_none @convert_output -def root_mean_squared_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, -) -> Number: +def root_mean_squared_error(error: Union[xr.DataArray, xr.Dataset]) -> Number: """ Compute root mean squared error (RMSE). @@ -92,12 +69,8 @@ def root_mean_squared_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Benchmark map. Returns ------- @@ -112,13 +85,8 @@ def root_mean_squared_error( return np.sqrt((error**2).mean()) -@compute_error_if_none @convert_output -def mean_signed_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, -) -> Number: +def mean_signed_error(error: Union[xr.DataArray, xr.Dataset]) -> Number: """ Compute mean signed error (MSiE). @@ -126,12 +94,8 @@ def mean_signed_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Benchmark map. 
Returns ------- @@ -146,12 +110,10 @@ def mean_signed_error( return error.mean() -@compute_error_if_none @convert_output def mean_percentage_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute mean percentage error (MPE). @@ -160,11 +122,9 @@ def mean_percentage_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -179,12 +139,10 @@ def mean_percentage_error( return (error / benchmark_map.mean()).mean() -@compute_error_if_none @convert_output def mean_absolute_percentage_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute mean absolute percentage error (MAPE). @@ -193,11 +151,9 @@ def mean_absolute_percentage_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -213,12 +169,10 @@ def mean_absolute_percentage_error( return np.abs(error / benchmark_map).mean() -@compute_error_if_none @convert_output def mean_normalized_root_mean_squared_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute mean normalized root mean squared error (NRMSE). @@ -227,11 +181,9 @@ def mean_normalized_root_mean_squared_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -247,12 +199,10 @@ def mean_normalized_root_mean_squared_error( return np.sqrt((error**2).mean()) / benchmark_map.mean() -@compute_error_if_none @convert_output def range_normalized_root_mean_squared_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute range normalized root mean squared error (RNRMSE). 
@@ -261,11 +211,9 @@ def range_normalized_root_mean_squared_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -280,12 +228,10 @@ def range_normalized_root_mean_squared_error( return np.sqrt((error**2).mean()) / (benchmark_map.max() - benchmark_map.min()) -@compute_error_if_none @convert_output def mean_normalized_mean_absolute_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute mean normalized mean absolute error (NMAE). @@ -294,11 +240,9 @@ def mean_normalized_mean_absolute_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -313,12 +257,10 @@ def mean_normalized_mean_absolute_error( return np.abs(error).mean() / benchmark_map.mean() -@compute_error_if_none @convert_output def range_normalized_mean_absolute_error( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute range normalized mean absolute error (RNMAE). @@ -327,11 +269,9 @@ def range_normalized_mean_absolute_error( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. - benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None + benchmark_map : Union[xr.DataArray, xr.Dataset] Benchmark map. Returns @@ -346,12 +286,10 @@ def range_normalized_mean_absolute_error( return np.abs(error).mean() / (benchmark_map.max() - benchmark_map.min()) -@compute_error_if_none @convert_output def coefficient_of_determination( - error: Optional[Union[xr.DataArray, xr.Dataset]] = None, - candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, - benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None, + error: Union[xr.DataArray, xr.Dataset], + benchmark_map: Union[xr.DataArray, xr.Dataset], ) -> Number: """ Compute coefficient of determination (R2). @@ -360,11 +298,9 @@ def coefficient_of_determination( Parameters ---------- - error : Optional[Union[xr.DataArray, xr.Dataset]], default = None + error : Union[xr.DataArray, xr.Dataset] Candidate minus benchmark error. - candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None - Candidate map. 
-    benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None
+    benchmark_map : Union[xr.DataArray, xr.Dataset]
         Benchmark map.
 
     Returns
@@ -380,12 +316,11 @@
     return 1 - (error**2).sum() / ((benchmark_map - benchmark_map.mean()) ** 2).sum()
 
 
-@compute_error_if_none
 @convert_output
 def symmetric_mean_absolute_percentage_error(
-    error: Optional[Union[xr.DataArray, xr.Dataset]] = None,
-    candidate_map: Optional[Union[xr.DataArray, xr.Dataset]] = None,
-    benchmark_map: Optional[Union[xr.DataArray, xr.Dataset]] = None,
+    error: Union[xr.DataArray, xr.Dataset],
+    candidate_map: Union[xr.DataArray, xr.Dataset],
+    benchmark_map: Union[xr.DataArray, xr.Dataset],
 ) -> Number:
     """
     Compute symmetric mean absolute percentage error (sMAPE).
 
     Parameters
     ----------
-    error : Optional[Union[xr.DataArray, xr.Dataset]], default = None
+    error : Union[xr.DataArray, xr.Dataset]
         Candidate minus benchmark error.
-    candidate_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None
+    candidate_map : Union[xr.DataArray, xr.Dataset]
         Candidate map.
-    benchmark_map : Optional[Union[xr.DataArray, xr.Dataset]], default = None
+    benchmark_map : Union[xr.DataArray, xr.Dataset]
         Benchmark map.
 
     Returns
diff --git a/src/gval/statistics/continuous_statistics.py b/src/gval/statistics/continuous_statistics.py
new file mode 100644
index 00000000..42ace93d
--- /dev/null
+++ b/src/gval/statistics/continuous_statistics.py
@@ -0,0 +1,292 @@
+"""
+Continuous Statistics Class
+"""
+
+from numbers import Number
+from typing import Union, Tuple
+from functools import wraps
+import inspect
+
+import numpy as np
+
+from gval.statistics.base_statistics import BaseStatistics
+import gval.statistics.continuous_stat_funcs as cs
+
+
+class ContinuousStatistics(BaseStatistics):
+    """
+    Class for Running Continuous Statistics on Agreement Maps
+
+    Attributes
+    ----------
+    registered_functions : dict
+        Available statistical functions with names as keys and parameters as values
+    """
+
+    def __init__(self):
+        # Automatically populates all functions found in continuous_stat_funcs.py
+
+        self.required_param = 1
+        self.optional_param = 0
+
+        self._func_names = [
+            fn
+            for fn in dir(cs)
+            if len(fn) > 5
+            and "__" not in fn
+            and "Number" not in fn
+            and "convert_output" not in fn
+        ]
+        self._funcs = [getattr(cs, name) for name in self._func_names]
+
+        for name, func in zip(self._func_names, self._funcs):
+            setattr(self, name, func)
+
+        self._signature_validation = {
+            "names": {
+                "error": self.required_param,
+                "candidate_map": self.optional_param,
+                "benchmark_map": self.optional_param,
+            },
+            "required": [
+                self.optional_param,
+                self.optional_param,
+                self.optional_param,
+            ],
+            "param_types": [
+                "xarray.core.dataset.Dataset",
+                "xarray.core.dataarray.DataArray",
+                "Union[xarray.core.dataarray.DataArray, xarray.core.dataset.Dataset]",
+                "Union[xarray.core.dataset.Dataset, xarray.core.dataarray.DataArray]",
+            ],
+            "return_type": [Number],
+            "no_of_args": [1, 2, 3],
+        }
+
+        self.registered_functions = {
+            name: {"params": [param for param in inspect.signature(func).parameters]}
+            for name, func in zip(self._func_names, self._funcs)
+        }
+
+    def available_functions(self) -> list:
+        """
+        Lists all available functions
+
+        Returns
+        -------
+        List of available functions
+        """
+        return list(self.registered_functions.keys())
+
+    def get_all_parameters(self):
+        """
+        Get all the possible arguments
+
+        Returns
+        -------
+        List of all possible arguments for functions
+        """
+
+        return list(self._signature_validation["names"].keys())
+
+    def register_function(self, name: str):
+        """
+        Register decorator function in statistics class
+
+        Parameters
+        ----------
+        name: str
+            Name of function to register in statistics class
+
+        Returns
+        -------
+        Decorator function
+        """
+
+        def decorator(func):
+            self.function_signature_check(func)
+
+            if name not in self.registered_functions:
+                self.registered_functions[name] = {
+                    "params": [
+                        param
+                        for param in inspect.signature(func).parameters
+                        if param != "self"
+                    ]
+                }
+
+                setattr(self, name, func)
+            else:
+                raise KeyError("This function name already exists")
+
+            @wraps(func)
+            def wrapper(*args, **kwargs):  # pragma: no cover
+                result = func(*args, **kwargs)
+
+                return result
+
+            return wrapper
+
+        return decorator
+
+    def register_function_class(self):
+        """
+        Register decorator function for an entire class
+        """
+
+        def decorator(dec_self: object):
+            """
+            Decorator for wrapper
+
+            Parameters
+            ----------
+            dec_self: object
+                Class to register stat functions
+            """
+
+            for name, func in inspect.getmembers(dec_self, inspect.isfunction):
+                if name not in self.registered_functions:
+                    self.function_signature_check(func)
+                    self.registered_functions[name] = {
+                        "params": [
+                            param
+                            for param in inspect.signature(func).parameters
+                            if param != "self"
+                        ]
+                    }
+
+                    setattr(self, name, func)
+                else:
+                    raise KeyError("This function name already exists")
+
+        return decorator
+
+    def function_signature_check(self, func):
+        """
+        Validates signature of registered function
+
+        Parameters
+        ----------
+        func: function
+            Function to check the signature of
+        """
+        signature = inspect.signature(func)
+        names = self._signature_validation["names"]
+        param_types = self._signature_validation["param_types"]
+        return_type = self._signature_validation["return_type"]
+        no_of_args = self._signature_validation["no_of_args"]
+
+        # Checks if param names, type, and return type are in valid list
+        # Considered no validation if either are empty
+        for key, val in signature.parameters.items():
+            if (key not in names and len(names) > 0) or (
+                not str(val).split(": ")[-1] in param_types and len(param_types) > 0
+            ):
+                raise TypeError(
+                    "Wrong parameters in function: \n"
+                    f"Valid Names: {names} \n"
+                    f"Valid Types: {param_types} \n"
+                )
+
+        if len(no_of_args) > 0 and len(signature.parameters) not in no_of_args:
+            raise TypeError(
+                "Wrong number of parameters: \n"
+                f"Valid number of parameters: {no_of_args}"
+            )
+
+        if signature.return_annotation not in return_type and len(return_type) > 0:
+            raise TypeError("Wrong return type \n" f"Valid return type {return_type}")
+
+    def get_parameters(self, func_name: str) -> list:
+        """
+        Get parameters of registered function
+
+        Parameters
+        ----------
+        func_name: str
+            Name of the registered function
+
+        Returns
+        -------
+        List of parameter names for the associated function
+        """
+
+        if func_name in self.registered_functions:
+            return self.registered_functions[func_name]["params"]
+        else:
+            raise KeyError("Statistic not found in registered functions")
+
+    def process_statistics(
+        self, func_names: Union[str, list], **kwargs
+    ) -> Tuple[list, list]:
+        """
+        Run the named statistics with the provided keyword arguments
+
+        Parameters
+        ----------
+        func_names: Union[str, list]
+            Name of registered function to run
+        **kwargs: dict or keyword arguments
+            Dictionary or keyword arguments to pass to metric functions.
+
+        Returns
+        -------
+        Tuple[list, list]
+            Tuple with metric values and metric names.
+        """
+
+        func_names = (
+            list(self.registered_functions.keys())
+            if func_names == "all"
+            else func_names
+        )
+        func_list = [func_names] if isinstance(func_names, str) else func_names
+
+        return_stats, return_funcs = [], []
+        for name in func_list:
+            if name in self.registered_functions:
+                params = self.get_parameters(name)
+                required = self._signature_validation["required"]
+
+                func = getattr(self, name)
+
+                # Necessary for numba functions which cannot accept keyword arguments
+                func_args, skip_function = [], False
+                for param, req in zip(params, required):
+                    if param in kwargs and kwargs[param] is not None:
+                        func_args.append(kwargs[param])
+                    elif not self._signature_validation["names"][param]:
+                        skip_function = True
+                        break
+                    else:
+                        raise ValueError("Parameter missing from kwargs")
+
+                if skip_function:
+                    continue
+
+                stat_val = func(*func_args)
+
+                def check_value(stat_name: str, stat: Number):
+                    if np.isnan(stat) or np.isinf(stat):
+                        raise ValueError(
+                            f"Invalid value calculated for {stat_name}:", stat
+                        )
+
+                if isinstance(stat_val, dict):
+                    for st_name, val in stat_val.items():
+                        check_value(st_name, val)
+                else:
+                    check_value(name, stat_val)
+
+                return_stats.append(stat_val)
+                return_funcs.append(name)
+
+            else:
+                raise KeyError(f"Statistic, {name}, not found in registered functions")
+
+        return return_stats, return_funcs
diff --git a/tests/cases_compute_continuous_metrics.py b/tests/cases_compute_continuous_metrics.py
index c680e119..76c0cf19 100644
--- a/tests/cases_compute_continuous_metrics.py
+++ b/tests/cases_compute_continuous_metrics.py
@@ -5,9 +5,13 @@
 # __all__ = ['*']
 __author__ = "Fernando Aristizabal"
 
+from numbers import Number
+from typing import Union
+
+import numpy as np
 from pytest_cases import parametrize
 import pandas as pd
-
+import xarray as xr
 from tests.conftest import _load_xarray
 
 
@@ -15,14 +19,18 @@
     pd.DataFrame(
         {
             "band": {0: "1"},
+            "coefficient_of_determination": -0.06615996360778809,
             "mean_absolute_error": 0.3173885941505432,
+            "mean_absolute_percentage_error": 0.15956786274909973,
+            "mean_normalized_mean_absolute_error": 0.16163060069084167,
+            "mean_normalized_root_mean_squared_error": 0.2802138924598694,
+            "mean_percentage_error": -0.015025136061012745,
+            "mean_signed_error": -0.02950434572994709,
             "mean_squared_error": 0.30277130007743835,
+            "range_normalized_mean_absolute_error": 0.27127230167388916,
+            "range_normalized_root_mean_squared_error": 0.4702962636947632,
             "root_mean_squared_error": 0.5502465963363647,
-            "mean_percentage_error": 0.015025136061012745,
-            "mean_absolute_percentage_error": 0.15956786274909973,
-            "coefficient_of_determination": -0.06615996360778809,
-            "normalized_mean_absolute_error": 0.16163060069084167,
-            "normalized_root_mean_squared_error": 0.27127230167388916,
+            "symmetric_mean_absolute_percentage_error": 0.1628540771054295,
         }
     ),
     pd.DataFrame(
@@ -37,6 +45,12 @@
 candidate_maps = ["candidate_continuous_0.tif", "candidate_continuous_1.tif"]
 benchmark_maps = ["benchmark_continuous_0.tif", "benchmark_continuous_1.tif"]
 
+candidate_maps_fail = [
+    "candidate_continuous_0_fail.tif",
+    "candidate_continuous_0_fail.tif",
+    "candidate_continuous_0_fail.tif",
+    None,
+]
 
 all_load_options = [
     {"mask_and_scale": True},
@@ -70,3 +84,138 @@ def case_compute_continuous_metrics_success(
         metrics,
         expected_df,
     )
+
+
+stat_names = ["all", "all", "non_existent_function", "all"]
+exceptions = [ValueError, ValueError, KeyError, ValueError]
+
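Before the parametrized cases pick these fixtures up, it may help to see the registration flow they exercise. Below is a minimal sketch, assuming `gval` is importable; the statistic name `mean_error_ratio` is hypothetical, and the signature must use the parameter names and xarray annotations accepted by the signature validation above:

```python
from numbers import Number
from typing import Union

import xarray as xr

from gval import ContStats


# Hypothetical statistic; parameter names and annotations must come from the
# validated set (error, candidate_map, benchmark_map as DataArray/Dataset).
@ContStats.register_function(name="mean_error_ratio")
def mean_error_ratio(
    error: Union[xr.DataArray, xr.Dataset],
    benchmark_map: Union[xr.DataArray, xr.Dataset],
) -> Number:
    # Mean candidate-minus-benchmark error scaled by the mean benchmark value
    return error.mean() / benchmark_map.mean()


# Once registered, the statistic is listed alongside the defaults
assert "mean_error_ratio" in ContStats.available_functions()
```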
+@parametrize(
+    "names, error_map, exception",
+    list(zip(stat_names, candidate_maps_fail, exceptions)),
+)
+def case_compute_continuous_statistics_fail(names, error_map, exception):
+    test_map = _load_xarray(error_map) if error_map is not None else error_map
+    return names, test_map, exception
+
+
+stat_args = [{"name": "test_func"}, {"name": "test_func2"}]
+
+
+def pass1(error: Union[xr.DataArray, xr.Dataset]) -> Number:
+    return (error + 1).mean()
+
+
+def pass2(
+    error: Union[xr.Dataset, xr.DataArray],
+    benchmark_map: Union[xr.Dataset, xr.DataArray],
+) -> Number:
+    return ((error + benchmark_map) / 1000).sum()
+
+
+stat_funcs = [pass1, pass2]
+
+
+@parametrize("args, func", list(zip(stat_args, stat_funcs)))
+def case_register_continuous_function(args, func):
+    return args, func
+
+
+stat_args = [
+    {"name": "mean_percentage_error"},
+    {"name": "test2"},
+    {"name": "test2"},
+    {"name": "test2"},
+    {"name": "test2"},
+]
+
+
+def fail1(error: Union[xr.Dataset, xr.DataArray]) -> Number:
+    return error.mean()
+
+
+def fail2(arb: Union[xr.Dataset, xr.DataArray]) -> Number:
+    return arb.mean()
+
+
+def fail3(error: np.array) -> Number:
+    return error.mean()
+
+
+def fail4(error: Union[xr.Dataset, xr.DataArray]) -> str:
+    return error.mean()
+
+
+def fail5() -> Number:
+    return 8.0
+
+
+stat_fail_funcs = [fail1, fail2, fail3, fail4, fail5]
+exceptions = [KeyError, TypeError, TypeError, TypeError, TypeError]
+
+
+@parametrize("args, func, exception", list(zip(stat_args, stat_fail_funcs, exceptions)))
+def case_register_continuous_function_fail(args, func, exception):
+    return args, func, exception
+
+
+class Tester:
+    @staticmethod
+    def pass5(
+        error: Union[xr.Dataset, xr.DataArray],
+        benchmark_map: Union[xr.Dataset, xr.DataArray],
+    ) -> Number:
+        return error / benchmark_map
+
+    @staticmethod
+    def pass6(error: Union[xr.Dataset, xr.DataArray]) -> Number:
+        return error * 1.01
+
+
+stat_names = [["pass5", "pass6"]]
+stat_args = [{}]
+stat_class = [Tester]
+
+
+@parametrize("names, args, cls", list(zip(stat_names, stat_args, stat_class)))
+def case_register_class_continuous(names, args, cls):
+    return names, args, cls
+
+
+class TesterFail1:
+    @staticmethod
+    def fail6(rp: int, fn: int) -> float:
+        return rp + fn
+
+
+class TesterFail2:
+    @staticmethod
+    def mean_absolute_error(error: Union[xr.DataArray, xr.Dataset]) -> Number:
+        return error + 0.01
+
+
+stat_args = [{}, {}]
+stat_class = [TesterFail1, TesterFail2]
+exceptions = [TypeError, KeyError]
+
+
+@parametrize("args, cls, exception", list(zip(stat_args, stat_class, exceptions)))
+def case_register_class_continuous_fail(args, cls, exception):
+    return args, cls, exception
+
+
+stat_funcs = ["mean_absolute_error", "symmetric_mean_absolute_percentage_error"]
+stat_params = [["error"], ["error", "candidate_map", "benchmark_map"]]
+
+
+@parametrize("name, params", list(zip(stat_funcs, stat_params)))
+def case_get_param_continuous(name, params):
+    return name, params
+
+
+stat_funcs = ["arbitrary"]
+
+
+@parametrize("name", stat_funcs)
+def case_get_param_continuous_fail(name):
+    return name
diff --git a/tests/cases_continuous_metrics.py b/tests/cases_continuous_metrics.py
index 4e313f87..1f9c689b 100644
--- a/tests/cases_continuous_metrics.py
+++ b/tests/cases_continuous_metrics.py
@@ -81,8 +81,8 @@ def case_root_mean_squared_error(
 
 
 ev_mean_signed_error = [
-    0.02950434572994709,
-    {"band_1": -0.3584839105606079, "band_2": 0.3584839105606079},
+    -0.02950434572994709,
+    {"band_1": 0.3584839105606079, "band_2": -0.3584839105606079},
 ]
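These sign flips, here and in the mean_percentage_error hunk that follows, fall out of the error convention now used throughout the continuous statistics: the agreement map is candidate minus benchmark. A small sketch with made-up values, showing the convention rather than any gval API:

```python
import numpy as np
import xarray as xr

# Made-up 2x2 maps purely to illustrate the sign convention
candidate = xr.DataArray(np.array([[1.0, 2.0], [3.0, 5.0]]))
benchmark = xr.DataArray(np.array([[1.5, 2.5], [2.5, 3.5]]))

# Error-based agreement map: candidate minus benchmark
error = candidate - benchmark

mean_signed_error = float(error.mean())  # 0.25
# mean_percentage_error per continuous_stat_funcs: (error / benchmark.mean()).mean()
mean_percentage_error = float((error / benchmark.mean()).mean())  # 0.1

# Reversing the subtraction order flips both signs, which is exactly the
# change reflected in the expected values above.
```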
@@ -99,8 +99,8 @@ def case_mean_signed_error(candidate_map, benchmark_map, load_options, expected_ ev_mean_percentage_error = [ - 0.015025136061012745, - {"band_1": -0.1585608273744583, "band_2": 0.1368602067232132}, + -0.015025136061012745, + {"band_1": 0.1585608273744583, "band_2": -0.1368602067232132}, ] diff --git a/tests/test_compute_continuous_metrics.py b/tests/test_compute_continuous_metrics.py index e22e172e..703f31d8 100644 --- a/tests/test_compute_continuous_metrics.py +++ b/tests/test_compute_continuous_metrics.py @@ -8,7 +8,10 @@ import pandas as pd from pytest_cases import parametrize_with_cases +from pytest import raises +import numpy as np +from gval import ContStats as con_stat from gval.comparison.compute_continuous_metrics import _compute_continuous_metrics @@ -23,9 +26,100 @@ def test_compute_continuous_metrics_success( # compute continuous metrics metrics_df = _compute_continuous_metrics( - candidate_map=candidate_map, benchmark_map=benchmark_map, metrics=metrics + agreement_map=candidate_map - benchmark_map, + candidate_map=candidate_map, + benchmark_map=benchmark_map, + metrics=metrics, ) pd.testing.assert_frame_equal( metrics_df, expected_df, check_dtype=False ), "Compute statistics did not return expected values" + + +@parametrize_with_cases( + "names, error_map, exception", glob="compute_continuous_statistics_fail" +) +def test_compute_continuous_statistics_fail(names, error_map, exception): + """tests compute statistics fail function""" + + args = { + "agreement_map": error_map, + "metrics": names, + "candidate_map": None, + "benchmark_map": None, + } + + with np.errstate(divide="ignore"): + with raises(exception): + # NOTE: Removed bc this should be handled within process_statistics. + # stat_names = cat_stat.available_functions() if names == "all" else names + _compute_continuous_metrics(**args) + + +@parametrize_with_cases("args, func", glob="register_continuous_function") +def test_register_continuous_function(args, func): + """tests register func function""" + + con_stat.register_function(**args)(func) + + +@parametrize_with_cases( + "args, func, exception", glob="register_continuous_function_fail" +) +def test_register_continuous_function_fail(args, func, exception): + """tests register func fail function""" + + with raises(exception): + con_stat.register_function(**args)(func) + + +@parametrize_with_cases("names, args, cls", glob="register_class_continuous") +def test_register_class_continuous(names, args, cls): + """tests register class continuous function""" + + con_stat.register_function_class(**args)(cls) + + if [name in con_stat.registered_functions for name in names] != [True] * len(names): + assert False, "Unable to register all class functions" + + +@parametrize_with_cases("args, cls, exception", glob="register_class_continuous_fail") +def test_register_class_continuous_fail(args, cls, exception): + """tests register class continuous fail function""" + + with raises(exception): + con_stat.register_function_class(**args)(cls) + + +@parametrize_with_cases("name, params", glob="get_param_continuous") +def test_get_param_continuous(name, params): + """tests get param continuous function""" + + _params = con_stat.get_parameters(name) + assert _params == params + + +@parametrize_with_cases("name", glob="get_param_continuous_fail") +def test_get_param_continuous_fail(name): + """tests get param continuous fail function""" + + with raises(KeyError): + _ = con_stat.get_parameters(name) + + +def test_get_all_param_continuous(): + """tests get all params 
function""" + + try: + con_stat.get_all_parameters() + except KeyError: + assert False, "Signature dict not present or keys changed" + + +def test_available_functions_continuous(): + """tests get available functions""" + + a_funcs = con_stat.available_functions() + + assert isinstance(a_funcs, list) diff --git a/tests/test_continuous_metrics.py b/tests/test_continuous_metrics.py index 6a1ef66e..5f58360a 100644 --- a/tests/test_continuous_metrics.py +++ b/tests/test_continuous_metrics.py @@ -34,7 +34,7 @@ def wrapper(candidate_map, benchmark_map, expected_value): ) @assert_logic_decorator def test_mean_absolute_error(candidate_map, benchmark_map, expected_value): - return continuous_stat_funcs.mean_absolute_error(None, candidate_map, benchmark_map) + return continuous_stat_funcs.mean_absolute_error(candidate_map - benchmark_map) @parametrize_with_cases( @@ -43,7 +43,7 @@ def test_mean_absolute_error(candidate_map, benchmark_map, expected_value): ) @assert_logic_decorator def test_mean_squared_error(candidate_map, benchmark_map, expected_value): - return continuous_stat_funcs.mean_squared_error(None, candidate_map, benchmark_map) + return continuous_stat_funcs.mean_squared_error(candidate_map - benchmark_map) @parametrize_with_cases( @@ -52,9 +52,7 @@ def test_mean_squared_error(candidate_map, benchmark_map, expected_value): ) @assert_logic_decorator def test_root_mean_squared_error(candidate_map, benchmark_map, expected_value): - return continuous_stat_funcs.root_mean_squared_error( - None, candidate_map, benchmark_map - ) + return continuous_stat_funcs.root_mean_squared_error(candidate_map - benchmark_map) @parametrize_with_cases( @@ -63,7 +61,7 @@ def test_root_mean_squared_error(candidate_map, benchmark_map, expected_value): ) @assert_logic_decorator def test_mean_signed_error(candidate_map, benchmark_map, expected_value): - return continuous_stat_funcs.mean_signed_error(None, candidate_map, benchmark_map) + return continuous_stat_funcs.mean_signed_error(candidate_map - benchmark_map) @parametrize_with_cases( @@ -73,7 +71,7 @@ def test_mean_signed_error(candidate_map, benchmark_map, expected_value): @assert_logic_decorator def test_mean_percentage_error(candidate_map, benchmark_map, expected_value): return continuous_stat_funcs.mean_percentage_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -84,7 +82,7 @@ def test_mean_percentage_error(candidate_map, benchmark_map, expected_value): @assert_logic_decorator def test_mean_absolute_percentage_error(candidate_map, benchmark_map, expected_value): return continuous_stat_funcs.mean_absolute_percentage_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -97,7 +95,7 @@ def test_mean_normalized_root_mean_squared_error( candidate_map, benchmark_map, expected_value ): return continuous_stat_funcs.mean_normalized_root_mean_squared_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -110,7 +108,7 @@ def test_range_normalized_root_mean_squared_error( candidate_map, benchmark_map, expected_value ): return continuous_stat_funcs.range_normalized_root_mean_squared_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -123,7 +121,7 @@ def test_mean_normalized_mean_absolute_error( candidate_map, benchmark_map, expected_value ): return continuous_stat_funcs.mean_normalized_mean_absolute_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ 
-136,7 +134,7 @@ def test_range_normalized_mean_absolute_error( candidate_map, benchmark_map, expected_value ): return continuous_stat_funcs.range_normalized_mean_absolute_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -147,7 +145,7 @@ def test_range_normalized_mean_absolute_error( @assert_logic_decorator def test_coefficient_of_determination(candidate_map, benchmark_map, expected_value): return continuous_stat_funcs.coefficient_of_determination( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, benchmark_map ) @@ -160,5 +158,5 @@ def test_symmetric_mean_absolute_percentage_error( candidate_map, benchmark_map, expected_value ): return continuous_stat_funcs.symmetric_mean_absolute_percentage_error( - None, candidate_map, benchmark_map + candidate_map - benchmark_map, candidate_map, benchmark_map ) From d2e0e63020e71cfb3f83dc32f38d622980d97665 Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Fri, 23 Jun 2023 20:34:03 -0400 Subject: [PATCH 2/5] Multi plotting and tech debt part 2 --- CONTRIBUTING.MD | 8 +- ....py => compile_readme_and_arrange_docs.py} | 0 docs/sphinx/SPHINX_CONTRIBUTING.MD | 15 +- docs/sphinx/SphinxTutorial.ipynb | 82 ++++----- notebooks/Tutorial.ipynb | 82 ++++----- src/gval/accessors/gval_array.py | 38 ---- src/gval/accessors/gval_dataframe.py | 2 +- src/gval/accessors/gval_xarray.py | 81 ++++++++- src/gval/statistics/categorical_stat_funcs.py | 4 +- src/gval/utils/visualize.py | 166 +++++++++++++----- tests/cases_accessors.py | 61 ++++++- tests/test_accessors.py | 34 +++- 12 files changed, 379 insertions(+), 194 deletions(-) rename docs/{compile_readme.py => compile_readme_and_arrange_docs.py} (100%) diff --git a/CONTRIBUTING.MD b/CONTRIBUTING.MD index daa2b374..ddcf86e4 100644 --- a/CONTRIBUTING.MD +++ b/CONTRIBUTING.MD @@ -51,9 +51,9 @@ automatically with flake8-black but in some cases manual changes will be necessa in code coverage and increase in memory. Also run local_benchmark_test to make sure there is no significant cpu time performance loss. To run everything (in root project directory): `pytest --memray --cov=gval --cov-report term-missing && python ./tests/local_benchmark.py`. -7. If any changes were made to the documentation, make sure to run `python docs/compile_readme.py`. (The README is made -up of other docs located in Markdown. Edit those directly if need be before running the preceding command.) -To build sphinx documentation locally, change to the docs/sphinx folder and run `make clean && make html`. +7. The README is made up of other docs located in /docs/markdown. To make changes to the README edit them directly +then run the following script: `python docs/compile_readme_and_arrange_docs.py`. +8. To build sphinx documentation locally, change to the docs/sphinx folder and run `make clean && make html`. The html will be created in the _build/html folder. Open index.html in a browser to preview docs. 9. Commit your changes: `git commit -m 'feature message'` This will invoke pre-commit hooks mentioned on step 5 that will lint the code. Make sure all of these checks pass, if not make changes and re-commit. @@ -86,7 +86,7 @@ that will lint the code. Make sure all of these checks pass, if not make changes * [Jupyter](https://pypi.org/project/jupyter/) -## Development Installation +## Docker Use (In this case, the image name, "gval-image", and container name, "gval-python" can be changed to whatever name is more suitable. 
Script, "test.py", does not exist and is an arbitrary placeholder for diff --git a/docs/compile_readme.py b/docs/compile_readme_and_arrange_docs.py similarity index 100% rename from docs/compile_readme.py rename to docs/compile_readme_and_arrange_docs.py diff --git a/docs/sphinx/SPHINX_CONTRIBUTING.MD b/docs/sphinx/SPHINX_CONTRIBUTING.MD index efa12306..ddcf86e4 100644 --- a/docs/sphinx/SPHINX_CONTRIBUTING.MD +++ b/docs/sphinx/SPHINX_CONTRIBUTING.MD @@ -51,14 +51,15 @@ automatically with flake8-black but in some cases manual changes will be necessa in code coverage and increase in memory. Also run local_benchmark_test to make sure there is no significant cpu time performance loss. To run everything (in root project directory): `pytest --memray --cov=gval --cov-report term-missing && python ./tests/local_benchmark.py`. -7. If any changes were made to the documentation, make sure to run `python docs/compile_readme.py`. (The README is made -up of other docs located in Markdown. Edit those directly if need be before running the preceding command.) -To build sphinx documentation locally, change to the docs/sphinx folder and run `make clean && make html`. +7. The README is made up of other docs located in /docs/markdown. To make changes to the README edit them directly +then run the following script: `python docs/compile_readme_and_arrange_docs.py`. +8. To build sphinx documentation locally, change to the docs/sphinx folder and run `make clean && make html`. The html will be created in the _build/html folder. Open index.html in a browser to preview docs. -8. Commit your changes: `git commit -m 'feature message'` This will invoke pre-commit hooks mentioned on step 5 +9. Commit your changes: `git commit -m 'feature message'` This will invoke pre-commit hooks mentioned on step 5 that will lint the code. Make sure all of these checks pass, if not make changes and re-commit. -9. Push to the branch: `git push origin ` -10. Open a pull request (review checklist in PR template before requesting a review) +10. Push to the branch: `git push -u origin`, or if the branch is not pushed up yet: +`git push --set-upstream origin ` +11. Open a pull request (review checklist in PR template before requesting a review) ## Standards @@ -85,7 +86,7 @@ that will lint the code. Make sure all of these checks pass, if not make changes * [Jupyter](https://pypi.org/project/jupyter/) -## Development Installation +## Docker Use (In this case, the image name, "gval-image", and container name, "gval-python" can be changed to whatever name is more suitable. 
Script, "test.py", does not exist and is an arbitrary placeholder for diff --git a/docs/sphinx/SphinxTutorial.ipynb b/docs/sphinx/SphinxTutorial.ipynb index 5fe7e8bd..20d65591 100644 --- a/docs/sphinx/SphinxTutorial.ipynb +++ b/docs/sphinx/SphinxTutorial.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "a9fa8470", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "f91c0b8c", "metadata": {}, "outputs": [], @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "id": "541857a7", "metadata": {}, "outputs": [], @@ -132,17 +132,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "id": "b1ef13a0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "id": "fdc9df2b", "metadata": {}, "outputs": [ @@ -256,7 +256,7 @@ "3 1 2.0 2.0 24.0 2473405.0" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -283,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "id": "16cb3626", "metadata": {}, "outputs": [ @@ -386,7 +386,7 @@ "[1 rows x 22 columns]" ] }, - "execution_count": 11, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -433,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "id": "7264ffc9", "metadata": {}, "outputs": [], @@ -452,7 +452,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "id": "e3917e34", "metadata": {}, "outputs": [], @@ -488,17 +488,17 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "id": "c6e3c35c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, @@ -534,17 +534,17 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "id": "a2310a98", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, @@ -591,17 +591,17 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "id": "f6567376", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, @@ -643,17 +643,17 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "id": "972f07aa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 17, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, @@ -708,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "id": "18b9c315", "metadata": {}, "outputs": [ @@ -767,7 +767,7 @@ "1 1 2.0 2.0 4.0 2624301.0" ] }, - "execution_count": 18, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -799,7 +799,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "id": "2ba3fc06", "metadata": {}, "outputs": [ @@ -856,16 +856,16 @@ "0 0.213711 " ] }, - "execution_count": 19, + "execution_count": 14, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "metrics = []\n", - "metric_table_select = crosstab_table.gval.compute_metrics(negative_categories= [0, 1],\n", - " positive_categories = [2],\n", - " metrics=['true_positive_rate', 'prevalence'])\n", + "metric_table_select = crosstab_table.gval.compute_categorical_metrics(negative_categories= [0, 1],\n", + " positive_categories = [2],\n", + " metrics=['true_positive_rate', 'prevalence'])\n", "metric_table_select" ] }, @@ -879,7 +879,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "id": "67938408", "metadata": {}, "outputs": [], @@ -901,7 +901,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "id": "1e8eeb59", "metadata": {}, "outputs": [], @@ -928,19 +928,21 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "id": "6a41eee3", "metadata": {}, "outputs": [], "source": [ - "metric_table_register = crosstab_table.gval.compute_metrics(negative_categories= None,\n", - " positive_categories = [2],\n", - " metrics=['error_balance', 'arbitrary1', 'arbitrary2'])" + "metric_table_register = crosstab_table.gval.compute_categorical_metrics(negative_categories= None,\n", + " positive_categories = [2],\n", + " metrics=['error_balance', \n", + " 'arbitrary1', \n", + " 'arbitrary2'])" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 18, "id": "6ab884b7", "metadata": {}, "outputs": [ @@ -992,7 +994,7 @@ "0 1 639227.0 512277.0 NaN 2473405.0 0.801401" ] }, - "execution_count": 23, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1019,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 19, "id": "899a1da9", "metadata": {}, "outputs": [], diff --git a/notebooks/Tutorial.ipynb b/notebooks/Tutorial.ipynb index 5fe7e8bd..20d65591 100644 --- a/notebooks/Tutorial.ipynb +++ b/notebooks/Tutorial.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "a9fa8470", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "f91c0b8c", "metadata": {}, "outputs": [], @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "id": "541857a7", "metadata": {}, "outputs": [], @@ -132,17 +132,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "id": "b1ef13a0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "id": "fdc9df2b", "metadata": {}, "outputs": [ @@ -256,7 +256,7 @@ "3 1 2.0 2.0 24.0 2473405.0" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -283,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "id": "16cb3626", "metadata": {}, "outputs": [ @@ -386,7 +386,7 @@ "[1 rows x 22 columns]" ] }, - "execution_count": 11, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -433,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "id": "7264ffc9", "metadata": {}, "outputs": [], @@ -452,7 +452,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "id": "e3917e34", "metadata": {}, "outputs": [], @@ -488,17 +488,17 @@ }, { "cell_type": 
"code", - "execution_count": 14, + "execution_count": 9, "id": "c6e3c35c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, @@ -534,17 +534,17 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "id": "a2310a98", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, @@ -591,17 +591,17 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "id": "f6567376", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, @@ -643,17 +643,17 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "id": "972f07aa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 17, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, @@ -708,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "id": "18b9c315", "metadata": {}, "outputs": [ @@ -767,7 +767,7 @@ "1 1 2.0 2.0 4.0 2624301.0" ] }, - "execution_count": 18, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -799,7 +799,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "id": "2ba3fc06", "metadata": {}, "outputs": [ @@ -856,16 +856,16 @@ "0 0.213711 " ] }, - "execution_count": 19, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = []\n", - "metric_table_select = crosstab_table.gval.compute_metrics(negative_categories= [0, 1],\n", - " positive_categories = [2],\n", - " metrics=['true_positive_rate', 'prevalence'])\n", + "metric_table_select = crosstab_table.gval.compute_categorical_metrics(negative_categories= [0, 1],\n", + " positive_categories = [2],\n", + " metrics=['true_positive_rate', 'prevalence'])\n", "metric_table_select" ] }, @@ -879,7 +879,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "id": "67938408", "metadata": {}, "outputs": [], @@ -901,7 +901,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "id": "1e8eeb59", "metadata": {}, "outputs": [], @@ -928,19 +928,21 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "id": "6a41eee3", "metadata": {}, "outputs": [], "source": [ - "metric_table_register = crosstab_table.gval.compute_metrics(negative_categories= None,\n", - " positive_categories = [2],\n", - " metrics=['error_balance', 'arbitrary1', 'arbitrary2'])" + "metric_table_register = crosstab_table.gval.compute_categorical_metrics(negative_categories= None,\n", + " positive_categories = [2],\n", + " metrics=['error_balance', \n", + " 'arbitrary1', \n", + " 'arbitrary2'])" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 18, "id": "6ab884b7", "metadata": {}, "outputs": [ @@ -992,7 +994,7 @@ "0 1 639227.0 512277.0 NaN 2473405.0 0.801401" ] }, - "execution_count": 23, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1019,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 19, "id": "899a1da9", "metadata": {}, "outputs": [], diff --git a/src/gval/accessors/gval_array.py b/src/gval/accessors/gval_array.py index fe2b79b6..476595c4 100644 --- 
a/src/gval/accessors/gval_array.py +++ b/src/gval/accessors/gval_array.py @@ -1,9 +1,6 @@ -from typing import Tuple - import xarray as xr from gval.accessors.gval_xarray import GVALXarray -from gval.utils.visualize import categorical_plot @xr.register_dataarray_accessor("gval") @@ -21,38 +18,3 @@ class GVALArray(GVALXarray): def __init__(self, xarray_obj: xr.DataArray): super().__init__(xarray_obj) - - def cat_plot( - self, - title: str = "Categorical Map", - colormap: str = "viridis", - figsize: Tuple[int, int] = (6, 4), - legend_labels: list = None, - ): - """ - Plots categorical Map for xarray dataset - - Parameters - __________ - title : str - Title of map, default = "Categorical Map" - colormap : str, default = "viridis" - Colormap of data - figsize : tuple[int, int], default=(6, 4) - Size of the plot - legend_labels : list, default = None - Override labels in legend - - References - ---------- - .. [1] [Matplotlib figure](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html) - .. [2] [Matplotlib legend](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html) - """ - - return categorical_plot( - self._obj, - title=title, - colormap=colormap, - figsize=figsize, - legend_labels=legend_labels, - ) diff --git a/src/gval/accessors/gval_dataframe.py b/src/gval/accessors/gval_dataframe.py index ca251f8e..8550de96 100644 --- a/src/gval/accessors/gval_dataframe.py +++ b/src/gval/accessors/gval_dataframe.py @@ -22,7 +22,7 @@ class GVALDataFrame: def __init__(self, pandas_obj): self._obj = pandas_obj - def compute_metrics( + def compute_categorical_metrics( self, positive_categories: Union[Number, Iterable[Number]], negative_categories: Union[Number, Iterable[Number]], diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index 25976467..64857a06 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -17,6 +17,7 @@ from gval.comparison.compute_categorical_metrics import _compute_categorical_metrics from gval.comparison.compute_continuous_metrics import _compute_continuous_metrics from gval.utils.schemas import Crosstab_df, Metrics_df +from gval.utils.visualize import _map_plot from gval.comparison.pairing_functions import difference @@ -93,7 +94,7 @@ def categorical_compare( benchmark_map: Union[gpd.GeoDataFrame, xr.Dataset, xr.DataArray] Benchmark map. positive_categories : Optional[Union[Number, Iterable[Number]]] - Number or list of numbers representing the values to consider as the positive condition. For average types "macro" and "weighted", this represents the categories to compute metrics for. + Number or list of numbers representing the values to consider as the positive condition. When the average argument is either "macro" or "weighted", this represents the categories to compute metrics for. comparison_function : Union[Callable, nb.np.ufunc.dufunc.DUFunc, np.ufunc, np.vectorize, str], default = 'szudzik' Comparison function. Created by decorating function with @nb.vectorize() or using np.ufunc(). Use of numba is preferred as it is faster. Strings with registered comparison_functions are also accepted. Possible options include "pairing_dict". If passing "pairing_dict" value, please see the description for the argument for more information on behaviour. All available comparison functions can be found with gval.Comparison.available_functions(). @@ -127,7 +128,7 @@ def categorical_compare( Macro weighing computes the metrics for each category then averages them. 
Weighted average computes the metrics for each category then averages them, weighted by the elements of the weights argument corresponding to each category. weights : Optional[Iterable[Number]], default = None - Weights to use when computing weighted average. Elements correspond to positive categories in order. + Weights to use when computing weighted average, specifically when the average argument is "weighted". Elements correspond to positive categories in order. Example: @@ -419,3 +420,79 @@ def compute_crosstab( exclude_value, comparison_function, ) + + def cat_plot( + self, + title: str = "Categorical Map", + colormap: str = "viridis", + figsize: Tuple[int, int] = None, + legend_labels: list = None, + plot_bands: Union[str, list] = "all", + ): + """ + Plots categorical Map for xarray object + + Parameters + __________ + title : str + Title of map, default = "Categorical Map" + colormap : str, default = "viridis" + Colormap of data + figsize : tuple[int, int], default=None + Size of the plot + legend_labels : list, default = None + Override labels in legend + plot_bands: Union[str, list], default='all' + What bands to plot + + References + ---------- + .. [1] [Matplotlib figure](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html) + .. [2] [Matplotlib legend](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html) + """ + + return _map_plot( + self._obj, + title=title, + colormap=colormap, + figsize=figsize, + legend_labels=legend_labels, + plot_type="categorical", + plot_bands=plot_bands, + ) + + def cont_plot( + self, + title: str = "Continuous Map", + colormap: str = "viridis", + figsize: Tuple[int, int] = None, + plot_bands: Union[str, list] = "all", + ): + """ + Plots continuous Map for xarray object + + Parameters + __________ + title : str + Title of map, default = "Continuous Map" + colormap : str, default = "viridis" + Colormap of data + figsize : tuple[int, int], default=None + Size of the plot + plot_bands: Union[str, list], default='all' + What bands to plot + + References + ---------- + .. [1] [Matplotlib figure](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html) + .. [2] [Matplotlib legend](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html) + """ + + return _map_plot( + self._obj, + title=title, + colormap=colormap, + figsize=figsize, + plot_type="continuous", + plot_bands=plot_bands, + ) diff --git a/src/gval/statistics/categorical_stat_funcs.py b/src/gval/statistics/categorical_stat_funcs.py index f6f68a1c..526c393d 100644 --- a/src/gval/statistics/categorical_stat_funcs.py +++ b/src/gval/statistics/categorical_stat_funcs.py @@ -266,7 +266,7 @@ def prevalence_threshold(tp: Number, tn: Number, fp: Number, fn: Number) -> floa References ---------- - .. [1] [Prevalence Threshold](https://en.wikipedia.org/wiki/Sensitivity_and_specificity#Prevalence_threshold) + .. [1] [Prevalence Threshold](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7540853/) """ return math.sqrt(fp / (fp + tn)) / ( math.sqrt(tp / (tp + fn)) + math.sqrt(fp / (fp + tn)) @@ -295,7 +295,7 @@ def critical_success_index(tp: Number, fp: Number, fn: Number) -> float: References ---------- - .. [1] [Critical Success Index](https://www.weather.gov/media/erh/ta2004-03.pdf) + .. 
[1] [Critical Success Index](https://www.swpc.noaa.gov/sites/default/files/images/u30/Forecast%20Verification%20Glossary.pdf#page=4) """ return tp / (tp + fn + fp) diff --git a/src/gval/utils/visualize.py b/src/gval/utils/visualize.py index 71d20c1f..087d1867 100644 --- a/src/gval/utils/visualize.py +++ b/src/gval/utils/visualize.py @@ -1,5 +1,5 @@ import warnings -from typing import Tuple +from typing import Tuple, Union import numpy as np import matplotlib @@ -8,15 +8,17 @@ import xarray as xr -def categorical_plot( - ds: xr.DataArray, +def _map_plot( + ds: Union[xr.DataArray, xr.Dataset], title: str = "Categorical Map", colormap: str = "viridis", - figsize: Tuple[int, int] = (6, 4), + figsize: Tuple[int, int] = None, legend_labels: list = None, + plot_bands: Union[str, list] = "all", + plot_type: str = "categorical", ): """ - Plots categorical Map for xarray dataset + Plots categorical or continuous Map for xarray object Parameters __________ @@ -30,10 +32,15 @@ Size of the plot legend_labels : list, default = None Override labels in legend + plot_bands : Union[str, list], default = 'all' + Which bands to plot if multiple. Default is all bands. + plot_type : str, default = 'categorical' + Whether to plot the map as categorical or continuous Returns ------- - QuadMesh Matplotlib object + QuadMesh + Matplotlib QuadMesh plot object Raises ------ @@ -46,50 +53,123 @@ .. [2] [Matplotlib legend](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html) """ + categorical = True if plot_type == "categorical" else False + with warnings.catch_warnings(): warnings.simplefilter("ignore") - # Setup figure, axis, and plot - fig, ax1 = plt.subplots(figsize=figsize) - plot = ds.plot(ax=ax1, cmap=colormap) - - # Print CRS if not large WKT String - crs = ds.rio.crs if len(ds.rio.crs) < 25 else "" - ax1.set_title(f"{title} ({crs})") - ax1.set_xlabel("Longitude") - ax1.set_ylabel("Latitude") - - # Get colormap values for each unique value - cmap = matplotlib.colormaps[colormap] - unique_vals = np.unique(ds) + # Get datasets + if isinstance(ds, xr.Dataset): + ds_list = [ + ds[x] + for x in ds.data_vars + if plot_bands == "all" or int(x.split("_")[1]) in plot_bands + ] + bands = ( + [x + 1 for x in range(len(ds.data_vars))] + if plot_bands == "all" + else plot_bands + ) + + else: + if len(ds.shape) == 3: + ds_list = [ + ds.sel({"band": x + 1}) + for x in range(ds.shape[0]) + if plot_bands == "all" or x in plot_bands + ] + bands = ( + [x + 1 for x in range(ds.shape[0])] + if plot_bands == "all" + else plot_bands + ) + elif len(ds.shape) > 3 or len(ds.shape) < 2: + raise ValueError("Needs to be 2 or 3 dimensional xarray object") + else: + ds_list = [ds] + bands = [1] + + if len(ds_list) > 8: + raise ValueError("Cannot plot more than 8 DataArrays at a time") + + cols = 2 if len(ds_list) > 1 else 1 + rows = ((len(ds_list) - 1) // 2) + 1 + + if figsize is None: + figsize = (5 * cols, 4 * rows) - if len(unique_vals) > 25: - raise ValueError("Too many values present in dataset for categorical plot.") - - if legend_labels is not None and len(legend_labels) != len(unique_vals): - raise ValueError("Need as many labels as unique values.") - - # Setup normalized color ramp - norm = matplotlib.colors.Normalize( - vmin=np.nanmin(unique_vals), vmax=np.nanmax(unique_vals) + # Setup figure, axis, and plot + fig, axs = plt.subplots(rows, cols, figsize=figsize) + axes = axs.ravel() if "ravel" in dir(axs) else [axs] + + for i, ax in enumerate(axes): 
if i >= len(ds_list): + continue + + ds_c = ds_list[i] + plot = ds_c.plot(ax=ax, cmap=colormap) + + # Print CRS if not large WKT String + crs = ds_c.rio.crs if len(ds_c.rio.crs) < 25 else "" + ax.set_title(f"Band {bands[i]}") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + + if categorical: + # Get colormap values for each unique value + cmap = matplotlib.colormaps[colormap] + unique_vals = np.unique(ds_c) + + if len(unique_vals) > 25: + raise ValueError( + "Too many values present in dataset for categorical plot." + ) + + if legend_labels is not None and len(legend_labels) != len(unique_vals): + raise ValueError("Need as many labels as unique values.") + + # Setup normalized color ramp + norm = matplotlib.colors.Normalize( + vmin=np.nanmin(unique_vals), vmax=np.nanmax(unique_vals) + ) + + # Create legend + labels = unique_vals if legend_labels is None else legend_labels + legend_elements = [ + Patch(color=cmap(norm(val)), label=str(label)) + for val, label in zip(unique_vals, labels) + if not np.isnan(val) + ] + ax.legend( + title="Encodings", + handles=legend_elements, + bbox_to_anchor=(1.05, 1.0), + loc="upper left", + ) + + fig.suptitle(f"{title} ({crs})", fontsize=15) + + if len(ds_list) <= 2: + top = 0.85 + elif len(ds_list) <= 4: + top = 0.9 + elif len(ds_list) <= 6: + top = 0.925 + else: + top = 0.95 + + plt.subplots_adjust( + top=top, bottom=0.1, left=0.125, right=0.9, hspace=0.6, wspace=0.5 ) - # Create legend - labels = unique_vals if legend_labels is None else legend_labels - legend_elements = [ - Patch(color=cmap(norm(val)), label=str(label)) - for val, label in zip(unique_vals, labels) - if not np.isnan(val) - ] - ax1.legend( - title="Encodings", - handles=legend_elements, - bbox_to_anchor=(1.05, 1.0), - loc="upper left", - ) + if categorical: + # Erase color bar and autoformat x labels to not overlap + while len(fig.axes) > len(ds_list): + fig.delaxes(fig.axes[len(ds_list)]) + # Erase extra axis if present + elif len(ds_list) % 2 == 1 and len(ds_list) > 1: + fig.delaxes(fig.axes[len(ds_list)]) - # Erase color bar and autoformat x labels to not overlap - fig.delaxes(fig.axes[1]) fig.autofmt_xdate() fig.show() diff --git a/tests/cases_accessors.py b/tests/cases_accessors.py index d648b22a..d7f7274a 100644 --- a/tests/cases_accessors.py +++ b/tests/cases_accessors.py @@ -56,6 +56,21 @@ chunks="auto", ), ] +plot_maps = [ + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True), + _load_xarray("categorical_multiband_6.tif", mask_and_scale=True), + _load_xarray( + "categorical_multiband_8.tif", mask_and_scale=True, band_as_variable=True + ).drop_vars("band_8"), + _load_xarray("categorical_multiband_10.tif", mask_and_scale=True), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( + {"band": [1, 2, 3]} + ), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( + {"band": 1, "x": -169463.7041} + ), +] + positive_cat = np.array([2, 2, 2, 2, 2, 2]) negative_cat = np.array([[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]) @@ -274,10 +289,16 @@ def case_data_frame_accessor_compute_metrics( @parametrize( - "candidate_map, crs", + "candidate_map, crs, entries", list( zip( - [candidate_maps[0]] * 2, + [ + candidate_maps[0], + candidate_maps[1], + plot_maps[0], + plot_maps[1], + plot_maps[2], + ], [ "EPSG:5070", """PROJCS["NAD27 / California zone II", @@ -291,24 +312,48 @@ def case_data_frame_accessor_compute_metrics( PARAMETER["central_meridian",-122], PARAMETER["false_easting",2000000], 
PARAMETER["false_northing",0],UNIT["Foot_US",0.30480060960121924]]""", + "EPSG:5070", + "EPSG:5070", + "EPSG:5070", ], + [2, 2, 3, 3, 4], ) ), ) -def case_data_array_accessor_categorical_plot_success(candidate_map, crs): - return candidate_map, crs +def case_categorical_plot_success(candidate_map, crs, entries): + return candidate_map, crs, entries @parametrize( "candidate_map, legend_labels, num_classes", - list(zip([candidate_maps[0]] * 2, [None, ["a", "b", "c"]], [30, 2])), + list( + zip( + [candidate_maps[0], candidate_maps[0], plot_maps[3]], + [None, ["a", "b", "c"], ["a", "b"]], + [30, 2, 2], + ) + ), ) -def case_data_array_accessor_categorical_plot_fail( - candidate_map, legend_labels, num_classes -): +def case_categorical_plot_fail(candidate_map, legend_labels, num_classes): return candidate_map, legend_labels, num_classes +@parametrize( + "candidate_map, axes", + list(zip([candidate_maps[0], plot_maps[4]], [2, 6])), +) +def case_continuous_plot_success(candidate_map, axes): + return candidate_map, axes + + +@parametrize( + "candidate_map", + [plot_maps[3], plot_maps[5]], +) +def case_continuous_plot_fail(candidate_map): + return candidate_map + + candidate_maps = ["candidate_continuous_0.tif", "candidate_continuous_1.tif"] benchmark_maps = ["benchmark_continuous_0.tif", "benchmark_continuous_1.tif"] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 1924af23..05e0c13a 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -228,7 +228,7 @@ def test_data_set_accessor_crosstab_table_fail(candidate_map, benchmark_map): def test_data_frame_accessor_compute_metrics( crosstab_df, positive_categories, negative_categories ): - data = crosstab_df.gval.compute_metrics( + data = crosstab_df.gval.compute_categorical_metrics( positive_categories=positive_categories, negative_categories=negative_categories ) @@ -236,27 +236,43 @@ def test_data_frame_accessor_compute_metrics( @parametrize_with_cases( - "candidate_map, crs", - glob="data_array_accessor_categorical_plot_success", + "candidate_map, crs, entries", + glob="categorical_plot_success", ) -def test_data_array_accessor_categorical_plot_success(candidate_map, crs): +def test_categorical_plot_success(candidate_map, crs, entries): candidate_map.rio.set_crs(crs) viz_object = candidate_map.gval.cat_plot() - assert len(viz_object.axes.get_legend().texts) == 2 + assert len(viz_object.axes.get_legend().texts) == entries @parametrize_with_cases( "candidate_map, legend_labels, num_classes", - glob="data_array_accessor_categorical_plot_fail", + glob="categorical_plot_fail", ) -def test_data_array_accessor_categorical_plot_fail( - candidate_map, legend_labels, num_classes -): +def test_categorical_plot_fail(candidate_map, legend_labels, num_classes): candidate_map.data = np.random.choice(np.arange(num_classes), candidate_map.shape) with raises(ValueError): _ = candidate_map.gval.cat_plot(legend_labels=legend_labels) +@parametrize_with_cases( + "candidate_map, axes", + glob="continuous_plot_success", +) +def test_continuous_plot_success(candidate_map, axes): + viz_object = candidate_map.gval.cont_plot() + assert len(viz_object.figure.axes) == axes + + +@parametrize_with_cases( + "candidate_map", + glob="continuous_plot_fail", +) +def test_continuous_plot_fail(candidate_map): + with raises(ValueError): + _ = candidate_map.gval.cont_plot() + + @parametrize_with_cases( "candidate_map, benchmark_map", glob="data_array_accessor_continuous", From da42247d52a4bc3fd0d682af78e1ddb43e0e4f73 Mon Sep 17 00:00:00 2001 From: Gregory 
Petrochenkov Date: Fri, 23 Jun 2023 21:09:47 -0400 Subject: [PATCH 3/5] Adjust size of datasets for memory --- tests/cases_accessors.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/cases_accessors.py b/tests/cases_accessors.py index d7f7274a..d67b773a 100644 --- a/tests/cases_accessors.py +++ b/tests/cases_accessors.py @@ -57,15 +57,23 @@ ), ] plot_maps = [ - _load_xarray("categorical_multiband_4.tif", mask_and_scale=True), - _load_xarray("categorical_multiband_6.tif", mask_and_scale=True), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).rio.reproject( + dst_crs="EPSG:5070", resolution=50 + ), + _load_xarray("categorical_multiband_6.tif", mask_and_scale=True).rio.reproject( + dst_crs="EPSG:5070", resolution=50 + ), _load_xarray( "categorical_multiband_8.tif", mask_and_scale=True, band_as_variable=True - ).drop_vars("band_8"), - _load_xarray("categorical_multiband_10.tif", mask_and_scale=True), - _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( - {"band": [1, 2, 3]} + ) + .drop_vars("band_8") + .rio.reproject(dst_crs="EPSG:5070", resolution=10), + _load_xarray("categorical_multiband_10.tif", mask_and_scale=True).rio.reproject( + dst_crs="EPSG:5070", resolution=50 ), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True) + .sel({"band": [1, 2, 3]}) + .rio.reproject(dst_crs="EPSG:5070", resolution=50), _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( {"band": 1, "x": -169463.7041} ), From 134d9784bfce82536617e80017c2c2944d484fda Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Fri, 23 Jun 2023 21:22:53 -0400 Subject: [PATCH 4/5] Reupload to s3 --- tests/cases_accessors.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/cases_accessors.py b/tests/cases_accessors.py index d67b773a..d7f7274a 100644 --- a/tests/cases_accessors.py +++ b/tests/cases_accessors.py @@ -57,23 +57,15 @@ ), ] plot_maps = [ - _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).rio.reproject( - dst_crs="EPSG:5070", resolution=50 - ), - _load_xarray("categorical_multiband_6.tif", mask_and_scale=True).rio.reproject( - dst_crs="EPSG:5070", resolution=50 - ), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True), + _load_xarray("categorical_multiband_6.tif", mask_and_scale=True), _load_xarray( "categorical_multiband_8.tif", mask_and_scale=True, band_as_variable=True - ) - .drop_vars("band_8") - .rio.reproject(dst_crs="EPSG:5070", resolution=10), - _load_xarray("categorical_multiband_10.tif", mask_and_scale=True).rio.reproject( - dst_crs="EPSG:5070", resolution=50 + ).drop_vars("band_8"), + _load_xarray("categorical_multiband_10.tif", mask_and_scale=True), + _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( + {"band": [1, 2, 3]} ), - _load_xarray("categorical_multiband_4.tif", mask_and_scale=True) - .sel({"band": [1, 2, 3]}) - .rio.reproject(dst_crs="EPSG:5070", resolution=50), _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( {"band": 1, "x": -169463.7041} ), From 93efdde79019e1e6fc5fa776e0dc8abf7cb88f48 Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Fri, 23 Jun 2023 21:45:13 -0400 Subject: [PATCH 5/5] Decrease size of geotiffs --- tests/cases_accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cases_accessors.py b/tests/cases_accessors.py index d7f7274a..cd77d6f5 100644 --- a/tests/cases_accessors.py +++ b/tests/cases_accessors.py @@ -67,7 
+67,7 @@ {"band": [1, 2, 3]} ), _load_xarray("categorical_multiband_4.tif", mask_and_scale=True).sel( - {"band": 1, "x": -169463.7041} + {"band": 1, "x": -169443.7041}, method="nearest" ), ]
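Taken together, the patches above rename the pandas accessor's `compute_metrics` entry point to `compute_categorical_metrics` and replace the single-band `categorical_plot` helper with the multiband-aware `_map_plot`, exposed through the `cat_plot` and `cont_plot` xarray accessors. The sketch below illustrates the resulting accessor surface; it is not part of the patch series, and the synthetic two-band raster (its shape, CRS, and class values) is an arbitrary stand-in for a real candidate map loaded with `rioxarray.open_rasterio(..., mask_and_scale=True)`.

```python
# Illustrative sketch only: exercises the accessors touched by this patch series
# against a synthetic raster. Assumes gval, rioxarray, and matplotlib are installed.
import numpy as np
import rioxarray  # noqa: F401  (registers the .rio accessor used below)
import xarray as xr

import gval  # noqa: F401  (importing gval registers the .gval accessors)

# Hypothetical two-band categorical raster with classes {0, 1, 2}.
rng = np.random.default_rng(0)
data = rng.integers(0, 3, size=(2, 64, 64)).astype(float)
candidate = xr.DataArray(
    data,
    dims=("band", "y", "x"),
    coords={"band": [1, 2], "y": np.arange(64), "x": np.arange(64)},
).rio.write_crs("EPSG:5070")

# Multiband plotting: plot_bands defaults to "all" (at most 8 panels).
# cat_plot builds a per-value legend; cont_plot keeps the continuous color bar.
candidate.gval.cat_plot(title="Categorical Map")
candidate.gval.cont_plot(title="Continuous Map")

# Renamed DataFrame accessor: given a crosstab table produced earlier in a gval
# workflow (see the tutorial notebooks above), metrics are now computed with:
#
#   crosstab_table.gval.compute_categorical_metrics(
#       positive_categories=[2],
#       negative_categories=[0, 1],
#       metrics=["true_positive_rate", "prevalence"],
#   )
```

Under the patched `_map_plot`, each band lands on its own subplot and the CRS moves from the per-axis title to the figure's suptitle, which is why the accessor tests above assert on the number of `viz_object.figure.axes` rather than on a single axis.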