artificial "collect" process evaluates the whole graph
bossie authored Jul 18, 2024
1 parent 15b72a7 commit c63b66d
Showing 5 changed files with 126 additions and 13 deletions.
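
In short: instead of starting evaluation at the single node flagged with "result": True, the deserializer now wraps the submitted process graph in a synthetic, hidden `collect` node that references every end node, i.e. every node that no other node points to via "from_node". Starting evaluation there forces side-effect-only subtrees, such as a dangling `save_result` or `inspect` branch, to be evaluated as well. A minimal sketch of the transformation (node ids and arguments are made up for illustration):

# A submitted graph in which "inspect1" is an end node that does not feed the final result:
process_graph = {
    "load1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
    "inspect1": {"process_id": "inspect", "arguments": {"data": {"from_node": "load1"}}},
    "reduce1": {"process_id": "reduce_dimension", "arguments": {"data": {"from_node": "load1"}}, "result": True},
}

# _collect_end_nodes() (added in this commit) injects a synthetic top-level node, roughly:
#     "collect1": {
#         "process_id": "collect",
#         "arguments": {"end_nodes": [{"from_node": "inspect1"}, {"from_node": "reduce1"}]},
#     }
# Evaluation starts at "collect1", so both branches run; the hidden `collect` process
# then returns whatever value the "result": True node stored in env[ENV_FINAL_RESULT].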
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -21,10 +21,13 @@ and start a new "In Progress" section above it.

## In progress

## 0.107.3

- Support `save_result` processes in arbitrary subtrees in the process graph, i.e. those not necessarily contributing to the final result ([Open-EO/openeo-geopyspark-driver#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424))

## 0.107.2

- Fix default level of `inspect` process (defaults to `info`) ([#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424))
- Fix default level of `inspect` process (defaults to `info`) ([Open-EO/openeo-geopyspark-driver#424](https://github.com/Open-EO/openeo-geopyspark-driver/issues/424))
- `apply_polygon`: add support for `geometries` argument (in addition to the legacy, but still supported, `polygons` argument) ([Open-EO/openeo-processes#511](https://github.com/Open-EO/openeo-processes/issues/511))

## 0.107.1
73 changes: 63 additions & 10 deletions openeo_driver/ProcessGraphDeserializer.py
@@ -242,7 +242,8 @@ def _register_fallback_implementations_by_process_graph(process_registry: Proces
# Some (env) string constants to simplify code navigation
ENV_SOURCE_CONSTRAINTS = "source_constraints"
ENV_DRY_RUN_TRACER = "dry_run_tracer"
ENV_SAVE_RESULT= "save_result"
ENV_FINAL_RESULT = "final_result"
ENV_SAVE_RESULT = "save_result"


class SimpleProcessing(Processing):
@@ -266,6 +267,7 @@ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> Pr
        if spec not in self._registry_cache:
            registry = ProcessRegistry(spec_root=SPECS_ROOT / spec, argument_names=["args", "env"])
            _add_standard_processes(registry, _OPENEO_PROCESSES_PYTHON_WHITELIST)
            registry.add_hidden(collect)
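            # `collect` is registered as a hidden process: the injected top-level node
            # can invoke it, but it is presumably not advertised in the public process listing.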
            self._registry_cache[spec] = registry
        return self._registry_cache[spec]

@@ -301,16 +303,17 @@ def evaluate(self, process_graph: dict, env: EvalEnv = None):

    def validate(self, process_graph: dict, env: EvalEnv = None) -> List[dict]:
        dry_run_tracer = DryRunDataTracer()
        env = env.push({ENV_DRY_RUN_TRACER: dry_run_tracer})
        env = env.push({ENV_DRY_RUN_TRACER: dry_run_tracer, ENV_FINAL_RESULT: [None]})

        try:
            top_level_node = ProcessGraphVisitor.dereference_from_node_arguments(process_graph)
            result_node = process_graph[top_level_node]
            ProcessGraphVisitor.dereference_from_node_arguments(process_graph)
        except ProcessGraphVisitException as e:
            return [{"code": "ProcessGraphInvalid", "message": str(e)}]

        try:
            result = convert_node(result_node, env=env)
            collected_process_graph, top_level_node_id = _collect_end_nodes(process_graph)
            top_level_node = collected_process_graph[top_level_node_id]
            result = convert_node(top_level_node, env=env)
        except OpenEOApiException as e:
            return [{"code": e.code, "message": str(e)}]
        except Exception as e:
@@ -340,6 +343,44 @@ def extra_validation(
        return []


def _collect_end_nodes(process_graph: dict) -> (dict, str):
    end_node_ids = _end_node_ids(process_graph)
    top_level_node_id = "collect1"  # the node where evaluation starts (not necessarily the result node)

    collected_process_graph = dict(process_graph, **{top_level_node_id: {
        "process_id": "collect",
        "arguments": {
            "end_nodes": [{"from_node": end_node_id} for end_node_id in end_node_ids]
        }
    }})

    ProcessGraphVisitor.dereference_from_node_arguments(collected_process_graph)
    return collected_process_graph, top_level_node_id


def _end_node_ids(process_graph: dict) -> set:
    all_node_ids = set(process_graph.keys())

    def get_from_node_ids(value) -> set:
        if isinstance(value, dict):
            if "from_node" in value:
                return {value["from_node"]}
            else:
                return {node_id for v in value.values() for node_id in get_from_node_ids(v)}

        if isinstance(value, list):
            return {node_id for v in value for node_id in get_from_node_ids(v)}

        return set()

    from_node_ids = {node_id
                     for node in process_graph.values()
                     for argument_value in node["arguments"].values()
                     for node_id in get_from_node_ids(argument_value)}

    return all_node_ids - from_node_ids
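

# Illustrative example (not part of the diff): an end node is any node that no other
# node references via "from_node", searched recursively through argument dicts and
# lists. For a hypothetical graph
#
#     pg = {
#         "a": {"process_id": "load_collection", "arguments": {"id": "S2"}},
#         "b": {"process_id": "save_result", "arguments": {"data": {"from_node": "a"}, "format": "GTiff"}},
#         "c": {"process_id": "inspect", "arguments": {"data": {"from_node": "a"}}},
#     }
#
# _end_node_ids(pg) returns {"b", "c"}, since only "a" occurs as a "from_node" target.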


def evaluate(
    process_graph: dict,
    env: EvalEnv,
@@ -355,15 +396,17 @@
        _log.warning("No version in `evaluate()` env. Blindly assuming 1.0.0.")
        env = env.push({"version": "1.0.0"})

    top_level_node = ProcessGraphVisitor.dereference_from_node_arguments(process_graph)
    result_node = process_graph[top_level_node]
    collected_process_graph, top_level_node_id = _collect_end_nodes(process_graph)
    top_level_node = collected_process_graph[top_level_node_id]
    if ENV_SAVE_RESULT not in env:
        env = env.push({ENV_SAVE_RESULT: []})

    env = env.push({ENV_FINAL_RESULT: [None]})  # mutable, holds final result of process graph

    if do_dry_run:
        dry_run_tracer = do_dry_run if isinstance(do_dry_run, DryRunDataTracer) else DryRunDataTracer()
        _log.info("Doing dry run")
        convert_node(result_node, env=env.push({
        convert_node(top_level_node, env=env.push({
            ENV_DRY_RUN_TRACER: dry_run_tracer,
            ENV_SAVE_RESULT: [],  # otherwise dry run and real run append to the same mutable result list
            "node_caching": False
@@ -373,7 +416,7 @@
_log.info("Dry run extracted these source constraints: {s}".format(s=source_constraints))
env = env.push({ENV_SOURCE_CONSTRAINTS: source_constraints})

result = convert_node(result_node, env=env)
result = convert_node(top_level_node, env=env)
if len(env[ENV_SAVE_RESULT]) > 0:
if len(env[ENV_SAVE_RESULT]) == 1:
return env[ENV_SAVE_RESULT][0]
@@ -410,6 +453,10 @@ def convert_node(processGraph: Union[dict, list], env: EvalEnv = None):
                # TODO: this manipulates the process graph, while we often assume it's immutable.
                #       Adding complex data structures could also interfere with attempts to (re)encode the process graph as JSON again.
                processGraph["result_cache"] = process_result

            if processGraph.get('result', False):
                env[ENV_FINAL_RESULT][0] = process_result

            return process_result
        elif 'node' in processGraph:
            return convert_node(processGraph['node'], env=env)
@@ -424,7 +471,7 @@ def convert_node(processGraph: Union[dict, list], env: EvalEnv = None):
                raise ProcessParameterRequiredException(process="n/a", parameter=processGraph['from_parameter'])
        else:
            # TODO: Don't apply `convert_node` for some special cases (e.g. geojson objects)?
            return {k:convert_node(v, env=env) for k,v in processGraph.items()}
            return {k: convert_node(v, env=env) for k, v in processGraph.items()}
    elif isinstance(processGraph, list):
        return [convert_node(x, env=env) for x in processGraph]
    return processGraph
@@ -1623,6 +1670,7 @@ def apply_process(process_id: str, args: dict, namespace: Union[str, None], env:

    process_registry = env.backend_implementation.processing.get_process_registry(api_version=env["version"])
    process_function = process_registry.get_function(process_id, namespace=namespace)
    _log.debug(f"Applying process {process_id} to arguments {args}")
    return process_function(args=ProcessArgs(args, process_id=process_id), env=env)


@@ -2266,6 +2314,11 @@ def export_workspace(args: ProcessArgs, env: EvalEnv) -> SaveResult:
    return result


@custom_process
def collect(args: ProcessArgs, env: EvalEnv):
    return env[ENV_FINAL_RESULT][0]
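

# Note (illustrative, not part of the diff): ENV_FINAL_RESULT maps to a one-element
# list on purpose. EvalEnv.push() creates a new environment layer, but every layer
# shares the same list object, so a value written deep inside the evaluation stays
# visible here. convert_node() stores the output of the "result": True node in
# env[ENV_FINAL_RESULT][0]; this hidden `collect` process, invoked by the injected
# top-level node after all end nodes are evaluated, simply returns it as the overall
# result of the graph.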


# Finally: register some fallback implementation if possible
_register_fallback_implementations_by_process_graph(process_registry_100)
_register_fallback_implementations_by_process_graph(process_registry_2xx)
2 changes: 1 addition & 1 deletion openeo_driver/_version.py
@@ -1 +1 @@
__version__ = "0.107.2a1"
__version__ = "0.107.3a1"
55 changes: 55 additions & 0 deletions tests/test_dry_run.py
@@ -1667,6 +1667,7 @@ def test_load_result_constraints(dry_run_env, dry_run_tracer):
        )
    ]


def test_multiple_save_result(dry_run_env):
    pg = {
        "collection1": {
@@ -1738,6 +1739,60 @@ def test_multiple_save_result(dry_run_env):
    assert len(the_result) == 2


def test_non_result_subtrees_are_evaluated(dry_run_env, caplog):
    pg = {
        "loadcollection1": {
            "process_id": "load_collection",
            "arguments": {
                "id": "S2_FAPAR_CLOUDCOVER",
                "spatial_extent": {"west": 5, "south": 50, "east": 5.1, "north": 50.1},
                "temporal_extent": ["2024-07-11", "2024-07-21"],
                "bands": ["Flat:1"]
            }
        },
        "loadcollection2": {
            "process_id": "load_collection",
            "arguments": {
                "id": "S2_FOOBAR",
                "spatial_extent": {"west": 5, "south": 50, "east": 5.1, "north": 50.1},
                "temporal_extent": ["2024-07-11", "2024-07-21"],
                "bands": ["Flat:2"]
            }
        },
        "inspect1": {
            "process_id": "inspect",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "message": "intermediate result",
                "level": "warning"
            }
        },
        "mergecubes1": {
            "process_id": "merge_cubes",
            "arguments": {
                "cube1": {"from_node": "loadcollection1"},
                "cube2": {"from_node": "loadcollection2"},
            },
            "result": True
        },
        "saveresult2": {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "mergecubes1"},
                "format": "netCDF"
            }
        },
    }

    result = evaluate(pg, env=dry_run_env)

    # side-effect 1: output asset
    assert result.format == "netCDF"

    # side-effect 2: inspect log
    assert "intermediate result" in caplog.messages


def test_invalid_latlon_in_geojson(dry_run_env):
    init_cube = DataCube(PGNode("load_collection", id="S2_FOOBAR"), connection=None)

4 changes: 3 additions & 1 deletion tests/test_views_execute.py
@@ -32,6 +32,7 @@
from openeo_driver.ProcessGraphDeserializer import (
    custom_process,
    custom_process_from_process_graph,
    collect,
)
from openeo_driver.testing import (
    TEST_USER,
@@ -3893,7 +3894,7 @@ def test_if_merge_cubes(api):
"bands": ["B04"],
}},
"eq1": {"process_id": "eq", "arguments": {"x": 4, "y": 3}},
"errornode":{"process_id":"doesntExist"},
"errornode": {"process_id": "doesntExist", "arguments": {}},
"if1": {
"process_id": "if",
"arguments": {
@@ -4385,6 +4386,7 @@ def test_synchronous_processing_response_header_openeo_identifier(api):
@pytest.fixture
def custom_process_registry(backend_implementation) -> ProcessRegistry:
    process_registry = ProcessRegistry()
    process_registry.add_hidden(collect)
    with mock.patch.object(backend_implementation.processing, "get_process_registry", return_value=process_registry):
        yield process_registry

