STAC Catalog Functionality #184

Merged: 12 commits, Mar 27, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ authors = [
 requires-python = ">=3.8"
 keywords = ["geospatial", "evaluations"]
 license = {text = "MIT"}
-version = "0.2.5"
+version = "0.2.6"
 dynamic = ["readme", "dependencies"]
9 changes: 6 additions & 3 deletions src/gval/accessors/gval_xarray.py
@@ -93,15 +93,18 @@ def __handle_attribute_tracking(
         else:
             del attribute_tracking_kwargs["agreement_map"]
 
-        results = candidate_map.gval.attribute_tracking_xarray(
+        results = _attribute_tracking_xarray(
+            candidate_map=candidate_map,
             benchmark_map=benchmark_map,
             agreement_map=agreement_map,
             **attribute_tracking_kwargs,
         )
 
     else:
-        results = candidate_map.gval.attribute_tracking_xarray(
-            benchmark_map=benchmark_map, agreement_map=agreement_map
+        results = _attribute_tracking_xarray(
+            candidate_map=candidate_map,
+            benchmark_map=benchmark_map,
+            agreement_map=agreement_map,
         )
 
     return results
9 changes: 9 additions & 0 deletions src/gval/comparison/tabulation.py
@@ -243,6 +243,15 @@ def _crosstab_Datasets(agreement_map: xr.DataArray) -> DataFrame[Crosstab_df]:
     # loop variables
     previous_crosstab_df = None  # initializing to avoid having unset
     for i, b in enumerate(agreement_variable_names):
+        # Pass pairing dictionary to variable if necessary
+        if (
+            agreement_map[b].attrs.get("pairing_dictionary") is None
+            and agreement_map.attrs.get("pairing_dictionary") is not None
+        ):
+            agreement_map[b].attrs["pairing_dictionary"] = agreement_map.attrs[
+                "pairing_dictionary"
+            ]
+
         crosstab_df = _crosstab_2d_DataArrays(
             agreement_map=agreement_map[b], band_value=b
         )
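For context, the added block simply backfills a Dataset-level pairing_dictionary attribute onto each variable before cross-tabulation. A minimal standalone sketch of the same pattern (the dataset and attribute values below are illustrative, not taken from gval's test data):

    import numpy as np
    import xarray as xr

    # Dataset-level attribute that the individual variables lack
    ds = xr.Dataset(
        {
            "band_1": (("y", "x"), np.zeros((2, 2))),
            "band_2": (("y", "x"), np.ones((2, 2))),
        },
        attrs={"pairing_dictionary": {(0, 0): 0, (1, 1): 1}},
    )

    # Backfill the attribute onto each variable, mirroring the diff above
    for b in ds.data_vars:
        if (
            ds[b].attrs.get("pairing_dictionary") is None
            and ds.attrs.get("pairing_dictionary") is not None
        ):
            ds[b].attrs["pairing_dictionary"] = ds.attrs["pairing_dictionary"]

    assert all(
        ds[b].attrs["pairing_dictionary"] == ds.attrs["pairing_dictionary"]
        for b in ds.data_vars
    )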
221 changes: 142 additions & 79 deletions src/gval/utils/loading_datasets.py
@@ -9,13 +9,15 @@
 from numbers import Number
 import ast
 
+import pandas as pd
 import rioxarray as rxr
 import xarray as xr
 import numpy as np
+from shapely.geometry import MultiPoint, shape
 from tempfile import NamedTemporaryFile
 from rio_cogeo.cogeo import cog_translate
 from rio_cogeo.profiles import cog_profiles
-import pystac_client
+from pystac.item_collection import ItemCollection
 
 import stackstac
Expand Down Expand Up @@ -324,40 +326,22 @@ def _set_crs(stack: xr.DataArray, band_metadata: list = None) -> Number:


def get_stac_data(
GregoryPetrochenkov-NOAA marked this conversation as resolved.
Show resolved Hide resolved
url: str,
collection: str,
time: str,
stac_items: ItemCollection,
bands: list = None,
query: str = None,
time_aggregate: str = None,
max_items: int = None,
intersects: dict = None,
bbox: list = None,
resolution: int = None,
nodata_fill: Number = None,
) -> xr.Dataset:
"""
"""Transform STAC Items in to an xarray object

Parameters
----------
url : str
Address hosting the STAC API
collection : str
Name of collection to get (currently limited to one)
time : str
Single or range of values to query in the time dimension
stac_items : ItemCollection
STAC Item Collection returned from pystac client
bands: list, default = None
Bands to retrieve from service
query : str, default = None
String command to filter data
time_aggregate : str, default = None
Method to aggregate multiple time stamps
max_items : int, default = None
The maximum amount of records to retrieve
intersects : dict, default = None
Dictionary representing the type of geometry and its respective coordinates
bbox : list, default = None
Coordinates to filter the spatial range of request
resolution : int, default = 10
Resolution to get data from
nodata_fill : Number, default = None
@@ -368,75 +352,154 @@ def get_stac_data(
     xr.Dataset
         Xarray object with respective STAC API data
 
+    Raises
+    ------
+    ValueError
+        A valid aggregate must be used for time ranges
+
     """
 
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        # Call cataloging url, search, and convert to xarray
-        catalog = pystac_client.Client.open(url)
-
-        stac_items = catalog.search(
-            datetime=time,
-            collections=[collection],
-            max_items=max_items,
-            intersects=intersects,
-            bbox=bbox,
-            query=query,
-        ).get_all_items()
-
-        stack = stackstac.stack(stac_items, resolution=resolution)
-
-        # Only get unique time indices in case there are duplicates
-        _, idxs = np.unique(stack.coords["time"], return_index=True)
-        stack = stack[idxs]
-
-        # Aggregate if there is more than one time
-        if stack.coords["time"].shape[0] > 1:
-            crs = stack.rio.crs
-            if time_aggregate == "mean":
-                stack = stack.mean(dim="time")
-                stack.attrs["time_aggregate"] = "mean"
-            elif time_aggregate == "min":
-                stack = stack.min(dim="time")
-                stack.attrs["time_aggregate"] = "min"
-            elif time_aggregate == "max":
-                stack = stack.max(dim="time")
-                stack.attrs["time_aggregate"] = "max"
-            else:
-                raise ValueError("A valid aggregate must be used for time ranges")
-
-            stack.rio.write_crs(crs, inplace=True)
-        else:
-            stack = stack[0]
-            stack.attrs["time_aggregate"] = "none"
-
-        # Select specific bands
-        if bands is not None:
-            bands = [bands] if isinstance(bands, str) else bands
-            stack = stack.sel({"band": bands})
-
-        band_metadata = (
-            stack.coords["raster:bands"] if "raster:bands" in stack.coords else None
-        )
-        if "band" in stack.dims:
-            og_names = [name for name in stack.coords["band"]]
-            names = [f"band_{x + 1}" for x in range(len(stack.coords["band"]))]
-            stack = stack.assign_coords({"band": names}).to_dataset(dim="band")
-
-            for metadata, var, og_var in zip(band_metadata, stack.data_vars, og_names):
-                _set_nodata(stack[var], metadata, nodata_fill)
-                stack[var] = _set_crs(stack[var], band_metadata)
-                stack[var].attrs["original_name"] = og_var
-
-        else:
-            stack = stack.to_dataset(name="band_1")
-            _set_nodata(stack["band_1"], band_metadata, nodata_fill)
-            stack["band_1"] = _set_crs(stack["band_1"])
-            stack["band_1"].attrs["original_name"] = (
-                bands[0] if isinstance(bands, list) else bands
-            )
-
-        return stack
+    stack = stackstac.stack(stac_items, resolution=resolution)
+
+    # Only get unique time indices in case there are duplicates
+    _, idxs = np.unique(stack.coords["time"], return_index=True)
+    stack = stack[idxs]
+
+    # Aggregate if there is more than one time
+    if stack.coords["time"].shape[0] > 1:
+        crs = stack.rio.crs
+        if time_aggregate == "mean":
+            stack = stack.mean(dim="time")
+            stack.attrs["time_aggregate"] = "mean"
+        elif time_aggregate == "min":
+            stack = stack.min(dim="time")
+            stack.attrs["time_aggregate"] = "min"
+        elif time_aggregate == "max":
+            stack = stack.max(dim="time")
+            stack.attrs["time_aggregate"] = "max"
+        else:
+            raise ValueError("A valid aggregate must be used for time ranges")
+
+        stack.rio.write_crs(crs, inplace=True)
+    else:
+        stack = stack[0]
+        stack.attrs["time_aggregate"] = "none"
+
+    # Select specific bands
+    if bands is not None:
+        bands = [bands] if isinstance(bands, str) else bands
+        stack = stack.sel({"band": bands})
+
+    band_metadata = (
+        stack.coords["raster:bands"] if "raster:bands" in stack.coords else None
+    )
+    if "band" in stack.dims:
+        og_names = [name for name in stack.coords["band"]]
+        names = [f"band_{x + 1}" for x in range(len(stack.coords["band"]))]
+        stack = stack.assign_coords({"band": names}).to_dataset(dim="band")
+
+        for metadata, var, og_var in zip(band_metadata, stack.data_vars, og_names):
+            _set_nodata(stack[var], metadata, nodata_fill)
+            stack[var] = _set_crs(stack[var], band_metadata)
+            stack[var].attrs["original_name"] = og_var
+
+    else:
+        stack = stack.to_dataset(name="band_1")
+        _set_nodata(stack["band_1"], band_metadata, nodata_fill)
+        stack["band_1"] = _set_crs(stack["band_1"])
+        stack["band_1"].attrs["original_name"] = (
+            bands[0] if isinstance(bands, list) else bands
+        )
+
+    return stack
+
+
+def stac_to_df(
+    stac_items: ItemCollection,
+    assets: list = None,
+    column_allow_list: list = None,
+    column_block_list: list = None,
+) -> pd.DataFrame:
+    """Convert STAC Items into a DataFrame
+
+    Parameters
+    ----------
+    stac_items : ItemCollection
+        STAC Item Collection returned from pystac client
+    assets : list, default = None
+        Assets to keep (keep all if None)
+    column_allow_list : list, default = None
+        List of columns to allow in the result DataFrame
+    column_block_list : list, default = None
+        List of columns to remove from the result DataFrame
+
+    Returns
+    -------
+    pd.DataFrame
+        A DataFrame with rows for each unique item/asset combination
+
+    Raises
+    ------
+    ValueError
+        No entries in DataFrame due to nonexistent asset
+
+    """
+
+    item_dfs, compare_idx = [], 1
+
+    # Iterate through each STAC Item
+    for item in stac_items:
+        item_dict = item.to_dict()
+        item_columns = {}
+
+        # Get columns for all collection level and item level properties
+        for key, val in item_dict["properties"].items():
+            if not isinstance(val, list):
+                if isinstance(val, dict):
+                    for k, v in val.items():
+                        item_columns[k] = [v]
+                else:
+                    item_columns[key] = [val]
+
+        item_columns["bbox"] = MultiPoint(np.array(item_dict["bbox"]).reshape(2, 2)).wkt
+        item_columns["geometry"] = shape(item_dict["geometry"]).wkt
+
+        unique_keys = []
+        for k, v in item_dict["assets"].items():
+            for key in v.keys():
+                if key not in unique_keys:
+                    unique_keys.append(key)
+
+        # Create a new row for each asset and assign compare_id and map_id
+        asset_dfs = []
+        for k, v in item_dict["assets"].items():
+            if assets is None or k in assets:
+                asset_columns = item_columns.copy()
+
+                asset_columns["compare_id"] = compare_idx
+                asset_columns["map_id"] = v["href"]
+                compare_idx += 1
+                asset_columns["asset"] = [k]
+                for key in unique_keys:
+                    asset_columns[key] = [str(v.get(key, "N/a"))]
+
+                asset_dfs.append(pd.DataFrame(asset_columns))
+
+        item_dfs.append(pd.concat(asset_dfs))
+
+    # Concatenate the DataFrames and remove unwanted columns if allow and block lists exist
+    catalog_df = pd.concat(item_dfs, ignore_index=True)
+
+    if column_allow_list is not None:
+        catalog_df = catalog_df[column_allow_list]
+
+    if column_block_list is not None:
+        catalog_df = catalog_df.drop(column_block_list, axis=1)
+
+    return catalog_df


 def _create_circle_mask(
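Taken together, this PR splits searching from loading: the caller now performs the STAC search with pystac-client and hands the resulting ItemCollection to get_stac_data or stac_to_df. A hedged end-to-end sketch of the new flow; the catalog URL, collection, bbox, dates, and asset names are illustrative assumptions, and the import path is inferred from the file location rather than confirmed by the PR:

    import pystac_client

    from gval.utils.loading_datasets import get_stac_data, stac_to_df

    # The search now happens outside get_stac_data (illustrative endpoint/collection)
    catalog = pystac_client.Client.open("https://earth-search.aws.element84.com/v1")
    stac_items = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=[-105.78, 35.79, -105.72, 35.84],
        datetime="2023-06-01/2023-06-30",
        max_items=10,
    ).item_collection()

    # Stack the items into an xarray Dataset; per the diff, a time_aggregate of
    # "mean", "min", or "max" is required whenever more than one distinct
    # timestamp remains, otherwise get_stac_data raises ValueError
    ds = get_stac_data(
        stac_items,
        bands=["red"],
        time_aggregate="mean",
        resolution=10,
    )

    # Tabulate one row per item/asset combination, trimming columns with the
    # allow list ("asset" and "map_id" are created by stac_to_df; "datetime"
    # is a standard STAC item property)
    df = stac_to_df(
        stac_items,
        assets=["red"],
        column_allow_list=["asset", "map_id", "datetime"],
    )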