Skip to content

Commit

Permalink
Refactor datacube object, minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
m-mohr committed Dec 15, 2023
1 parent f1e6a6d commit 17e4444
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/openeo_test_suite/lib/process_runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ def get_nodata_value(self) -> Any:
"""
Returns the nodata value of the backend.
"""
return None
return None
8 changes: 7 additions & 1 deletion src/openeo_test_suite/lib/process_runner/dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import inspect

import dask
from openeo_pg_parser_networkx import OpenEOProcessGraph, ProcessRegistry
from openeo_pg_parser_networkx.process_registry import Process
from openeo_processes_dask.process_implementations.core import process
Expand All @@ -26,6 +27,7 @@ def create_process_registry():

# not sure why this is needed
from openeo_processes_dask.process_implementations.math import e

processes_from_module.append(e)

specs_module = importlib.import_module("openeo_processes_dask.specs")
Expand Down Expand Up @@ -66,9 +68,13 @@ def encode_datacube(self, data):
return datacube_to_xarray(data)

def decode_data(self, data, expected):
    """
    Decode a backend result into plain, comparable Python data.

    Parameters:
        data: The raw result from the dask-based process runner; may be a
            dask array, a numpy value/array, or an xarray DataArray.
        expected: The expected value from the test example, used to guide
            the numpy-to-native conversion.

    Returns:
        The decoded value: dask arrays are materialized via ``compute()``,
        numpy scalars/arrays are converted to native Python types, and
        xarray DataArrays are converted to the datacube dict representation.
    """
    # Materialize lazy dask arrays before any further conversion.
    if isinstance(data, dask.array.core.Array):
        data = data.compute()

    # Convert numpy dtypes to native Python types (shape guided by `expected`),
    # then turn any xarray DataArray into the JSON-style datacube dict.
    data = numpy_to_native(data, expected)
    data = xarray_to_datacube(data)

    return data

def get_nodata_value(self):
return float('nan')
return float("nan")
30 changes: 16 additions & 14 deletions src/openeo_test_suite/lib/process_runner/util.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from dateutil.parser import parse
from datetime import datetime, timezone

import dask
import numpy as np
import xarray as xr
from dateutil.parser import parse


def numpy_to_native(data, expected):
if isinstance(data, dask.array.core.Array):
data = data.compute()

# Converting numpy dtypes to native python types
if isinstance(data, np.ndarray) or isinstance(data, np.generic):
if isinstance(expected, list):
Expand All @@ -28,7 +24,8 @@ def numpy_to_native(data, expected):
def datacube_to_xarray(cube):
coords = []
crs = None
for dim in cube["dimensions"]:
for name in cube["order"]:
dim = cube["dimensions"][name]
if dim["type"] == "temporal":
# date replace for older Python versions that don't support ISO parsing (only available since 3.11)
values = [
Expand All @@ -41,7 +38,7 @@ def datacube_to_xarray(cube):
else:
values = dim["values"]

coords.append((dim["name"], values))
coords.append((name, values))

da = xr.DataArray(cube["data"], coords=coords)
if crs is not None:
Expand All @@ -52,13 +49,12 @@ def datacube_to_xarray(cube):


def xarray_to_datacube(data):
if isinstance(data, dask.array.core.Array):
data = xr.DataArray(data.compute())

if not isinstance(data, xr.DataArray):
return data

dims = []
order = list(data.dims)

dims = {}
for c in data.coords:
type = "bands"
values = []
Expand All @@ -75,14 +71,20 @@ def xarray_to_datacube(data):
type = "spatial"
axis = "y"

dim = {"name": c, "type": type, "values": values}
dim = {"type": type, "values": values}
if axis is not None:
dim["axis"] = axis
if "crs" in data.attrs:
dim["reference_system"] = data.attrs["crs"] # todo: non-standardized
dims.append(dim)

cube = {"type": "datacube", "dimensions": dims, "data": data.values.tolist()}
dims[c] = dim

cube = {
"type": "datacube",
"order": order,
"dimensions": dims,
"data": data.values.tolist(),
}

if "nodata" in data.attrs:
cube["nodata"] = data.attrs["nodata"] # todo: non-standardized
Expand Down
2 changes: 1 addition & 1 deletion src/openeo_test_suite/lib/process_runner/vito.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def decode_data(self, data, expected):
return data

def get_nodata_value(self):
return float('nan')
return float("nan")
65 changes: 42 additions & 23 deletions src/openeo_test_suite/tests/processes/processing/test_example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import math
import warnings
from pathlib import Path, posixpath
from openeo_test_suite.lib.process_runner.util import isostr_to_datetime

import json5
import pytest
import xarray as xr
from deepdiff import DeepDiff

from openeo_test_suite.lib.process_runner.util import isostr_to_datetime

# glob path to the test files
examples_path = "assets/processes/tests/*.json5"

Expand Down Expand Up @@ -49,11 +50,7 @@ def test_process(connection, process_levels, processes, id, example, file, level
)
)
elif len(processes) > 0 and id not in processes:
pytest.skip(
"Skipping process {} because it is not in the specified processes: {}".format(
id, ", ".join(processes)
)
)
pytest.skip("Skipping process {} because it is not in the specified processes".format(id))

# check whether the process is available
try:
Expand Down Expand Up @@ -98,15 +95,22 @@ def test_process(connection, process_levels, processes, id, example, file, level
elif returns:
check_return_value(example, result, connection, file)
else:
pytest.skip("Test for process {} doesn't provide an expected result for arguments: {}".format(id, example["arguments"]))
pytest.skip(
"Test for process {} doesn't provide an expected result for arguments: {}".format(
id, example["arguments"]
)
)


def prepare_arguments(arguments, process_id, connection, file):
for name in arguments:
arguments[name] = prepare_argument(arguments[name], process_id, name, connection, file)

arguments[name] = prepare_argument(
arguments[name], process_id, name, connection, file
)

return arguments


def prepare_argument(arg, process_id, name, connection, file):
# handle external references to files
if isinstance(arg, dict) and "$ref" in arg:
Expand All @@ -128,8 +132,10 @@ def prepare_argument(arg, process_id, name, connection, file):
arg = connection.encode_process_graph(arg, process_id, name)
else:
for key in arg:
arg[key] = prepare_argument(arg[key], process_id, name, connection, file)

arg[key] = prepare_argument(
arg[key], process_id, name, connection, file
)

elif isinstance(arg, list):
for i in range(len(arg)):
arg[i] = prepare_argument(arg[i], process_id, name, connection, file)
Expand All @@ -142,15 +148,15 @@ def prepare_argument(arg, process_id, name, connection, file):
return arg


def prepare_results(connection, file, example, result = None):
def prepare_results(connection, file, example, result=None):
# go through the example and result recursively and convert datetimes to iso strings
# could be used for more conversions in the future...

if isinstance(example, dict):
# handle external references to files
if isinstance(example, dict) and "$ref" in example:
example = load_ref(example["$ref"], file)

if "type" in example:
if example["type"] == "datetime":
example = isostr_to_datetime(example["value"])
Expand All @@ -165,14 +171,18 @@ def prepare_results(connection, file, example, result = None):
if key not in result:
(example[key], _) = prepare_results(connection, file, example[key])
else:
(example[key], result[key]) = prepare_results(connection, file, example[key], result[key])

(example[key], result[key]) = prepare_results(
connection, file, example[key], result[key]
)

elif isinstance(example, list):
for i in range(len(example)):
if i >= len(result):
(example[i], _) = prepare_results(connection, file, example[i])
else:
(example[i], result[i]) = prepare_results(connection, file, example[i], result[i])
(example[i], result[i]) = prepare_results(
connection, file, example[i], result[i]
)

return (example, result)

Expand Down Expand Up @@ -205,7 +215,9 @@ def check_non_json_values(value):


def check_exception(example, result):
assert isinstance(result, Exception), "Excpected an exception, but got {}".format(result)
assert isinstance(result, Exception), "Excpected an exception, but got {}".format(
result
)
if isinstance(example["throws"], str):
if result.__class__.__name__ != example["throws"]:
warnings.warn(
Expand All @@ -218,32 +230,37 @@ def check_exception(example, result):


def check_return_value(example, result, connection, file):
assert not isinstance(result, Exception), "Unexpected exception: {} ".format(str(result))
assert not isinstance(result, Exception), "Unexpected exception: {} ".format(
str(result)
)

# handle custom types of data
result = connection.decode_data(result, example["returns"])

# decode special types (currently mostly datetimes and nodata)
(example["returns"], result) = prepare_results(connection, file, example["returns"], result)
(example["returns"], result) = prepare_results(
connection, file, example["returns"], result
)

delta = example["delta"] if "delta" in example else 0.0000000001

if isinstance(example["returns"], dict):
assert isinstance(result, dict), "Expected a dict but got {}".format(type(result))
assert isinstance(result, dict), "Expected a dict but got {}".format(
type(result)
)
exclude_regex_paths = []
exclude_paths = []
ignore_order_func = None
if "type" in example["returns"] and example["returns"]["type"] == "datacube":
# todo: non-standardized
exclude_regex_paths.append(
r"root\['dimensions'\]\[\d+\]\['reference_system'\]"
r"root\['dimensions'\]\[[^\]]+\]\['reference_system'\]"
)
# todo: non-standardized
exclude_paths.append("root['nodata']")
# ignore data if operation is not changing data
if example["returns"]["data"] is None:
exclude_paths.append("root['data']")
ignore_order_func = lambda level: "dimensions" in level.path()

diff = DeepDiff(
example["returns"],
Expand All @@ -257,7 +274,9 @@ def check_return_value(example, result, connection, file):
)
assert {} == diff, "Differences: {}".format(str(diff))
elif isinstance(example["returns"], list):
assert isinstance(result, list), "Expected a list but got {}".format(type(result))
assert isinstance(result, list), "Expected a list but got {}".format(
type(result)
)
diff = DeepDiff(
example["returns"],
result,
Expand Down

0 comments on commit 17e4444

Please sign in to comment.