diff --git a/src/saturn_engine/worker/executors/queue.py b/src/saturn_engine/worker/executors/queue.py index f7ac6272..56e698c5 100644 --- a/src/saturn_engine/worker/executors/queue.py +++ b/src/saturn_engine/worker/executors/queue.py @@ -5,6 +5,7 @@ from saturn_engine.core import PipelineOutput from saturn_engine.core import PipelineResults +from saturn_engine.utils import ExceptionGroup from saturn_engine.utils.asyncutils import Cancellable from saturn_engine.utils.asyncutils import TasksGroupRunner from saturn_engine.utils.log import getLogger @@ -13,6 +14,7 @@ from saturn_engine.worker.services import Services from saturn_engine.worker.services.hooks import MessagePublished from saturn_engine.worker.services.hooks import PipelineEventsEmitted +from saturn_engine.worker.services.hooks import ResultsProcessed from saturn_engine.worker.topic import Topic from . import Executor @@ -77,25 +79,31 @@ async def scope( raise return result - output = await scope(processable) - try: - self.consuming_tasks.create_task( - self.services.s.hooks.pipeline_events_emitted.emit( - PipelineEventsEmitted( - events=output.events, xmsg=processable - ) - ), - name=f"emit-pipeline-events({processable})", - ) - processable.update_resources_used(output.resources) - self.consuming_tasks.create_task( - self.consume_output( - processable=processable, output=output.outputs - ), - name=f"consume-output({processable})", - ) - except Exception: - self.logger.exception("Error processing outputs") + results = await scope(processable) + self.consuming_tasks.create_task( + self.process_results(xmsg=processable, results=results) + ) + + async def process_results( + self, *, xmsg: ExecutableMessage, results: PipelineResults + ) -> None: + @self.services.s.hooks.results_processed.emit + async def scope(msg: ResultsProcessed) -> None: + msg.xmsg.update_resources_used(msg.results.resources) + + await self.consume_output(processable=xmsg, output=msg.results.outputs) + + await self.services.s.hooks.pipeline_events_emitted.emit( + PipelineEventsEmitted(events=msg.results.events, xmsg=xmsg) + ) + + with contextlib.suppress(Exception): + await scope( + ResultsProcessed( + xmsg=xmsg, + results=results, + ) + ) async def submit(self, processable: ExecutableMessage) -> None: # Get the lock to ensure we don't acquire resource if the submit queue @@ -154,6 +162,7 @@ async def consume_output( self, *, processable: ExecutableMessage, output: list[PipelineOutput] ) -> None: try: + errors = [] for item in output: topics = processable.output.get(item.channel, []) for topic in topics: @@ -173,10 +182,15 @@ async def scope(topic: Topic) -> None: await scope(topic) - with contextlib.suppress(Exception): + try: await scope( MessagePublished(xmsg=processable, topic=topic, output=item) ) + except Exception as e: + errors.append(e) + if errors: + raise ExceptionGroup("Failed to process outputs", errors) + finally: await processable.unpark() diff --git a/src/saturn_engine/worker/services/hooks.py b/src/saturn_engine/worker/services/hooks.py index dda98db0..df77bdec 100644 --- a/src/saturn_engine/worker/services/hooks.py +++ b/src/saturn_engine/worker/services/hooks.py @@ -30,6 +30,11 @@ T = t.TypeVar("T") +class ResultsProcessed(t.NamedTuple): + xmsg: "ExecutableMessage" + results: "PipelineResults" + + class MessagePublished(t.NamedTuple): xmsg: "ExecutableMessage" topic: "Topic" @@ -57,6 +62,7 @@ class HooksLists(t.Generic[T]): message_scheduled: list[T] = listfield() message_submitted: list[T] = listfield() message_executed: list[T] = listfield() + results_processed: list[T] = listfield() message_published: list[T] = listfield() output_blocked: list[T] = listfield() pipeline_events_emitted: list[T] = listfield() @@ -89,6 +95,7 @@ def load_hooks(self) -> HooksLists[t.Callable]: message_submitted=self._load_hooks(self.message_submitted), message_executed=self._load_hooks(self.message_executed), message_published=self._load_hooks(self.message_published), + results_processed=self._load_hooks(self.results_processed), output_blocked=self._load_hooks(self.output_blocked), pipeline_events_emitted=self._load_hooks(self.pipeline_events_emitted), work_queue_built=self._load_hooks(self.work_queue_built), @@ -107,6 +114,7 @@ def _load_hooks(hooks: list[str]) -> list[t.Callable]: message_submitted: AsyncEventHook["ExecutableMessage"] message_executed: AsyncContextHook["ExecutableMessage", "PipelineResults"] message_published: AsyncContextHook["MessagePublished", None] + results_processed: AsyncContextHook["ResultsProcessed", None] output_blocked: AsyncContextHook["Topic", None] pipeline_events_emitted: AsyncEventHook[PipelineEventsEmitted] @@ -143,6 +151,10 @@ def __init__(self, options: t.Optional[Options] = None) -> None: hooks.message_published + [self.on_message_published], error_handler=self.hook_failed.emit, ) + self.results_processed = AsyncContextHook( + hooks.results_processed + [self.on_results_processed], + error_handler=self.hook_failed.emit, + ) self.output_blocked = AsyncContextHook( hooks.output_blocked, error_handler=self.hook_failed.emit ) @@ -262,6 +274,20 @@ async def on_message_published( except Exception as e: await context.on_error(g, e) + async def on_results_processed( + self, + msg: ResultsProcessed, + ) -> t.AsyncGenerator[None, None]: + if not (context := self._msg_context(msg.xmsg, "results_processed")): + return + + g = await context.on_call(msg) + try: + result = yield + await context.on_result(g, result) + except Exception as e: + await context.on_error(g, e) + async def on_pipeline_events_emitted(self, events: PipelineEventsEmitted) -> None: await self._on_msg(events.xmsg, "pipeline_events_emitted", events) diff --git a/src/saturn_engine/worker/services/loggers/logger.py b/src/saturn_engine/worker/services/loggers/logger.py index 007570e1..518499cb 100644 --- a/src/saturn_engine/worker/services/loggers/logger.py +++ b/src/saturn_engine/worker/services/loggers/logger.py @@ -17,6 +17,7 @@ from saturn_engine.worker.executors.executable import ExecutableQueue from saturn_engine.worker.pipeline_message import PipelineMessage from saturn_engine.worker.services.hooks import MessagePublished +from saturn_engine.worker.services.hooks import ResultsProcessed from saturn_engine.worker.services.tracing import get_trace_context from .. import BaseServices @@ -77,6 +78,7 @@ async def open(self) -> None: self.services.hooks.message_submitted.register(self.on_message_submitted) self.services.hooks.message_executed.register(self.on_message_executed) self.services.hooks.message_published.register(self.on_message_published) + self.services.hooks.results_processed.register(self.on_results_processed) self.services.hooks.executor_initialized.register(on_executor_initialized) @property @@ -167,6 +169,22 @@ async def on_message_published( "Failed to publish message", extra={"data": self.published_data(event)} ) + async def on_results_processed( + self, event: ResultsProcessed + ) -> AsyncGenerator[None, None]: + try: + yield + except Exception: + self.message_logger.exception( + "Failed to process message results", + extra={"data": self.results_processed_data(event)}, + ) + + def results_processed_data(self, event: ResultsProcessed) -> dict[str, t.Any]: + return { + "from": pipeline_message_data(event.xmsg.message, verbose=self.verbose) + } | self.result_data(event.results) + def published_data(self, event: MessagePublished) -> dict[str, t.Any]: return { "from": pipeline_message_data(event.xmsg.message, verbose=self.verbose) @@ -176,6 +194,7 @@ def result_data(self, results: PipelineResults) -> dict[str, t.Any]: return { "output": [self.output_data(o) for o in results.outputs[:10]], "output_count": len(results.outputs), + "events_count": len(results.events), "resources": self.resources_used(results.resources), } diff --git a/tests/worker/services/test_hooks.py b/tests/worker/services/test_hooks.py index d170f147..bf1f8176 100644 --- a/tests/worker/services/test_hooks.py +++ b/tests/worker/services/test_hooks.py @@ -21,6 +21,7 @@ from saturn_engine.worker.services.hooks import ItemsBatch from saturn_engine.worker.services.hooks import MessagePublished from saturn_engine.worker.services.hooks import PipelineEventsEmitted +from saturn_engine.worker.services.hooks import ResultsProcessed def test_event_hook() -> None: @@ -492,6 +493,7 @@ async def handler(arg: str) -> t.AsyncGenerator[None, str]: mock_message_executed_handler = context_mock("message_executed") mock_message_executed_handler_msg = context_mock("message_executed_msg") mock_message_published_handler = context_mock("message_published") +mock_results_processed_handler = context_mock("results_processed") mock_output_blocked_handler = context_mock("output_blocked") mock_work_queue_built_handler = context_mock("work_queue_built") @@ -521,6 +523,7 @@ def hook_options() -> Hooks.Options: "message_executed", "message_published", "output_blocked", + "results_processed", "work_queue_built", ] @@ -547,6 +550,7 @@ async def test_hooks_service( xmsg = executable_maker() msg_events = PipelineEventsEmitted(xmsg=xmsg, events=[]) message_published = MessagePublished(xmsg=xmsg, topic=None, output=None) # type: ignore + results_processed = ResultsProcessed(xmsg=xmsg, results=None) # type: ignore await hooks.hook_failed.emit(error) HookMock.hook_failed.assert_awaited_once_with(error) @@ -580,6 +584,10 @@ async def scope(arg: t.Any) -> t.Any: HookMock.message_published.before.assert_awaited_once_with(message_published) HookMock.message_published.after.assert_awaited_once_with(message_published) + await hooks.results_processed.emit(scope)(results_processed) + HookMock.results_processed.before.assert_awaited_once_with(results_processed) + HookMock.results_processed.after.assert_awaited_once_with(results_processed) + await hooks.output_blocked.emit(scope)("topic") # type: ignore HookMock.output_blocked.before.assert_awaited_once_with("topic") HookMock.output_blocked.after.assert_awaited_once_with("topic") @@ -667,6 +675,18 @@ async def publish_scope(msg: MessagePublished) -> None: HookMock.message_published.before.assert_awaited_once_with(message_published) HookMock.message_published.error.assert_awaited_once_with(error) + @hooks.results_processed.emit + async def results_processed_scope(msg: ResultsProcessed) -> None: + await HookMock.results_processed() + return None + + results_processed = ResultsProcessed(xmsg=xmsg, results=None) # type: ignore + await results_processed_scope(results_processed) + + HookMock.results_processed.before.assert_awaited_once_with(results_processed) + HookMock.results_processed.after.assert_awaited_once_with(None) + HookMock.results_processed.assert_awaited_once_with() + @hooks.work_queue_built.emit async def work_queue_scope(queue_item: QueueItemWithState) -> ExecutableQueue: await HookMock.work_queue_built(queue_item) diff --git a/tests/worker/services/test_logger.py b/tests/worker/services/test_logger.py index 0ae5dcba..e44f33c1 100644 --- a/tests/worker/services/test_logger.py +++ b/tests/worker/services/test_logger.py @@ -92,6 +92,7 @@ async def test_logger_message_executed( "resources": {FakeResource._typename(): "r1"}, "pipeline": "tests.worker.services.test_logger.fake_pipeline", "result": { + "events_count": 0, "output_count": 1, "output": [ {