diff --git a/CHANGELOG.md b/CHANGELOG.md index 79fe028b..2aee94c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,10 +21,13 @@ and start a new "In Progress" section above it. ## In progress +## 0.107.3 + +- Support `save_result` processes in arbitrary subtrees in the process graph i.e. those not necessarily contributing to the final result ([Open-EO/openeo-geopyspark-driver#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424)) ## 0.107.2 -- Fix default level of `inspect` process (defaults to `info`) ([#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424)) +- Fix default level of `inspect` process (defaults to `info`) ([Open-EO/openeo-geopyspark-driver#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424)) - `apply_polygon`: add support for `geometries` argument (in addition to legacy, but still supported `polygons`) ([Open-EO/openeo-processes#511](https://github.com/Open-EO/openeo-processes/issues/511)) ## 0.107.1 diff --git a/openeo_driver/ProcessGraphDeserializer.py b/openeo_driver/ProcessGraphDeserializer.py index e5f11a84..6a173a93 100644 --- a/openeo_driver/ProcessGraphDeserializer.py +++ b/openeo_driver/ProcessGraphDeserializer.py @@ -242,7 +242,8 @@ def _register_fallback_implementations_by_process_graph(process_registry: Proces # Some (env) string constants to simplify code navigation ENV_SOURCE_CONSTRAINTS = "source_constraints" ENV_DRY_RUN_TRACER = "dry_run_tracer" -ENV_SAVE_RESULT= "save_result" +ENV_FINAL_RESULT = "final_result" +ENV_SAVE_RESULT = "save_result" class SimpleProcessing(Processing): @@ -266,6 +267,7 @@ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> Pr if spec not in self._registry_cache: registry = ProcessRegistry(spec_root=SPECS_ROOT / spec, argument_names=["args", "env"]) _add_standard_processes(registry, _OPENEO_PROCESSES_PYTHON_WHITELIST) + registry.add_hidden(collect) self._registry_cache[spec] = registry return self._registry_cache[spec] @@ -301,16 
+303,17 @@ def evaluate(self, process_graph: dict, env: EvalEnv = None): def validate(self, process_graph: dict, env: EvalEnv = None) -> List[dict]: dry_run_tracer = DryRunDataTracer() - env = env.push({ENV_DRY_RUN_TRACER: dry_run_tracer}) + env = env.push({ENV_DRY_RUN_TRACER: dry_run_tracer, ENV_FINAL_RESULT: [None]}) try: - top_level_node = ProcessGraphVisitor.dereference_from_node_arguments(process_graph) - result_node = process_graph[top_level_node] + ProcessGraphVisitor.dereference_from_node_arguments(process_graph) except ProcessGraphVisitException as e: return [{"code": "ProcessGraphInvalid", "message": str(e)}] try: - result = convert_node(result_node, env=env) + collected_process_graph, top_level_node_id = _collect_end_nodes(process_graph) + top_level_node = collected_process_graph[top_level_node_id] + result = convert_node(top_level_node, env=env) except OpenEOApiException as e: return [{"code": e.code, "message": str(e)}] except Exception as e: @@ -340,6 +343,44 @@ def extra_validation( return [] +def _collect_end_nodes(process_graph: dict) -> (dict, str): + end_node_ids = _end_node_ids(process_graph) + top_level_node_id = "collect1" # the node where evaluation starts (not necessarily the result node) + + collected_process_graph = dict(process_graph, **{top_level_node_id: { + "process_id": "collect", + "arguments": { + "end_nodes": [{"from_node": end_node_id} for end_node_id in end_node_ids] + } + }}) + + ProcessGraphVisitor.dereference_from_node_arguments(collected_process_graph) + return collected_process_graph, top_level_node_id + + +def _end_node_ids(process_graph: dict) -> set: + all_node_ids = set(process_graph.keys()) + + def get_from_node_ids(value) -> set: + if isinstance(value, dict): + if "from_node" in value: + return {value["from_node"]} + else: + return {node_id for v in value.values() for node_id in get_from_node_ids(v)} + + if isinstance(value, list): + return {node_id for v in value for node_id in get_from_node_ids(v)} + + return set() + 
+ from_node_ids = {node_id + for node in process_graph.values() + for argument_value in node["arguments"].values() + for node_id in get_from_node_ids(argument_value)} + + return all_node_ids - from_node_ids + + def evaluate( process_graph: dict, env: EvalEnv, @@ -355,15 +396,17 @@ def evaluate( _log.warning("No version in `evaluate()` env. Blindly assuming 1.0.0.") env = env.push({"version": "1.0.0"}) - top_level_node = ProcessGraphVisitor.dereference_from_node_arguments(process_graph) - result_node = process_graph[top_level_node] + collected_process_graph, top_level_node_id = _collect_end_nodes(process_graph) + top_level_node = collected_process_graph[top_level_node_id] if ENV_SAVE_RESULT not in env: env = env.push({ENV_SAVE_RESULT: []}) + env = env.push({ENV_FINAL_RESULT: [None]}) # mutable, holds final result of process graph + if do_dry_run: dry_run_tracer = do_dry_run if isinstance(do_dry_run, DryRunDataTracer) else DryRunDataTracer() _log.info("Doing dry run") - convert_node(result_node, env=env.push({ + convert_node(top_level_node, env=env.push({ ENV_DRY_RUN_TRACER: dry_run_tracer, ENV_SAVE_RESULT: [], # otherwise dry run and real run append to the same mutable result list "node_caching": False @@ -373,7 +416,7 @@ def evaluate( _log.info("Dry run extracted these source constraints: {s}".format(s=source_constraints)) env = env.push({ENV_SOURCE_CONSTRAINTS: source_constraints}) - result = convert_node(result_node, env=env) + result = convert_node(top_level_node, env=env) if len(env[ENV_SAVE_RESULT]) > 0: if len(env[ENV_SAVE_RESULT]) == 1: return env[ENV_SAVE_RESULT][0] @@ -410,6 +453,10 @@ def convert_node(processGraph: Union[dict, list], env: EvalEnv = None): # TODO: this manipulates the process graph, while we often assume it's immutable. # Adding complex data structures could also interfere with attempts to (re)encode the process graph as JSON again. 
@custom_process
def collect(args: ProcessArgs, env: EvalEnv):
    """
    Helper process that hands back whatever is stored in the first slot of ``env[ENV_FINAL_RESULT]``.

    The ``args`` of the process are not used here.
    """
    final_result_slot = env[ENV_FINAL_RESULT]
    return final_result_slot[0]
def test_non_result_subtrees_are_evaluated(dry_run_env, caplog):
    """
    Evaluate a process graph with two end nodes (`inspect1` and `saveresult2`),
    where `mergecubes1` is the node flagged as result, and check the effect of both end nodes.
    """

    def load_collection_node(collection_id: str, band: str) -> dict:
        # Build fresh containers on each call, so the nodes don't share mutable state.
        return {
            "process_id": "load_collection",
            "arguments": {
                "id": collection_id,
                "spatial_extent": {"west": 5, "south": 50, "east": 5.1, "north": 50.1},
                "temporal_extent": ["2024-07-11", "2024-07-21"],
                "bands": [band],
            },
        }

    inspect_node = {
        "process_id": "inspect",
        "arguments": {
            "data": {"from_node": "loadcollection1"},
            "message": "intermediate result",
            "level": "warning",
        },
    }
    merge_cubes_node = {
        "process_id": "merge_cubes",
        "arguments": {
            "cube1": {"from_node": "loadcollection1"},
            "cube2": {"from_node": "loadcollection2"},
        },
        "result": True,
    }
    save_result_node = {
        "process_id": "save_result",
        "arguments": {
            "data": {"from_node": "mergecubes1"},
            "format": "netCDF",
        },
    }

    pg = {
        "loadcollection1": load_collection_node("S2_FAPAR_CLOUDCOVER", band="Flat:1"),
        "loadcollection2": load_collection_node("S2_FOOBAR", band="Flat:2"),
        "inspect1": inspect_node,
        "mergecubes1": merge_cubes_node,
        "saveresult2": save_result_node,
    }

    result = evaluate(pg, env=dry_run_env)

    # side-effect 1: output asset
    assert result.format == "netCDF"

    # side-effect 2: inspect log
    logged_messages = caplog.messages
    assert "intermediate result" in logged_messages
@pytest.fixture
def custom_process_registry(backend_implementation) -> ProcessRegistry:
    """
    Fixture yielding a fresh `ProcessRegistry` (with `collect` added as hidden process)
    while `backend_implementation.processing.get_process_registry` is patched to return it.
    """
    registry = ProcessRegistry()
    registry.add_hidden(collect)
    registry_patch = mock.patch.object(
        backend_implementation.processing, "get_process_registry", return_value=registry
    )
    with registry_patch:
        yield registry