Add results_processed hook #326

Merged · 1 commit · Jul 31, 2023
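The change in brief: the executor queue's output handling moves into a dedicated process_results step, wrapped in a new results_processed context hook, so services (such as the logger service below) can observe result processing and its failures. A minimal sketch of the emit side, following the queue.py and test changes in this diff; the wrapper function name and the ExecutableMessage import path are assumptions, while the hook API itself is taken from the diff:

from saturn_engine.core import PipelineResults
from saturn_engine.worker.executors.executable import ExecutableMessage  # assumed path
from saturn_engine.worker.services.hooks import Hooks
from saturn_engine.worker.services.hooks import ResultsProcessed


async def emit_results_processed(
    hooks: Hooks, xmsg: ExecutableMessage, results: PipelineResults
) -> None:
    # The decorated scope runs between every registered handler's
    # before/after phases; handlers receive the ResultsProcessed payload.
    @hooks.results_processed.emit
    async def scope(msg: ResultsProcessed) -> None:
        msg.xmsg.update_resources_used(msg.results.resources)

    await scope(ResultsProcessed(xmsg=xmsg, results=results))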
src/saturn_engine/worker/executors/queue.py (34 additions, 20 deletions)
@@ -5,6 +5,7 @@

 from saturn_engine.core import PipelineOutput
 from saturn_engine.core import PipelineResults
+from saturn_engine.utils import ExceptionGroup
 from saturn_engine.utils.asyncutils import Cancellable
 from saturn_engine.utils.asyncutils import TasksGroupRunner
 from saturn_engine.utils.log import getLogger
@@ -13,6 +14,7 @@
 from saturn_engine.worker.services import Services
 from saturn_engine.worker.services.hooks import MessagePublished
 from saturn_engine.worker.services.hooks import PipelineEventsEmitted
+from saturn_engine.worker.services.hooks import ResultsProcessed
 from saturn_engine.worker.topic import Topic

 from . import Executor
@@ -77,25 +79,31 @@ async def scope(
                         raise
                     return result

-                output = await scope(processable)
-                try:
-                    self.consuming_tasks.create_task(
-                        self.services.s.hooks.pipeline_events_emitted.emit(
-                            PipelineEventsEmitted(
-                                events=output.events, xmsg=processable
-                            )
-                        ),
-                        name=f"emit-pipeline-events({processable})",
-                    )
-                    processable.update_resources_used(output.resources)
-                    self.consuming_tasks.create_task(
-                        self.consume_output(
-                            processable=processable, output=output.outputs
-                        ),
-                        name=f"consume-output({processable})",
-                    )
-                except Exception:
-                    self.logger.exception("Error processing outputs")
+                results = await scope(processable)
+                self.consuming_tasks.create_task(
+                    self.process_results(xmsg=processable, results=results)
+                )
+
+    async def process_results(
+        self, *, xmsg: ExecutableMessage, results: PipelineResults
+    ) -> None:
+        @self.services.s.hooks.results_processed.emit
+        async def scope(msg: ResultsProcessed) -> None:
+            msg.xmsg.update_resources_used(msg.results.resources)
+
+            await self.consume_output(processable=xmsg, output=msg.results.outputs)
+
+            await self.services.s.hooks.pipeline_events_emitted.emit(
+                PipelineEventsEmitted(events=msg.results.events, xmsg=xmsg)
+            )
+
+        with contextlib.suppress(Exception):
+            await scope(
+                ResultsProcessed(
+                    xmsg=xmsg,
+                    results=results,
+                )
+            )

     async def submit(self, processable: ExecutableMessage) -> None:
         # Get the lock to ensure we don't acquire resource if the submit queue
@@ -154,6 +162,7 @@ async def consume_output(
         self, *, processable: ExecutableMessage, output: list[PipelineOutput]
     ) -> None:
         try:
+            errors = []
             for item in output:
                 topics = processable.output.get(item.channel, [])
                 for topic in topics:
@@ -173,10 +182,15 @@ async def scope(topic: Topic) -> None:

                         await scope(topic)

-                    with contextlib.suppress(Exception):
+                    try:
                         await scope(
                             MessagePublished(xmsg=processable, topic=topic, output=item)
                         )
+                    except Exception as e:
+                        errors.append(e)
+
+            if errors:
+                raise ExceptionGroup("Failed to process outputs", errors)

         finally:
             await processable.unpark()
src/saturn_engine/worker/services/hooks.py (26 additions, 0 deletions)
@@ -30,6 +30,11 @@
 T = t.TypeVar("T")


+class ResultsProcessed(t.NamedTuple):
+    xmsg: "ExecutableMessage"
+    results: "PipelineResults"
+
+
 class MessagePublished(t.NamedTuple):
     xmsg: "ExecutableMessage"
     topic: "Topic"
@@ -57,6 +62,7 @@ class HooksLists(t.Generic[T]):
     message_scheduled: list[T] = listfield()
     message_submitted: list[T] = listfield()
     message_executed: list[T] = listfield()
+    results_processed: list[T] = listfield()
     message_published: list[T] = listfield()
     output_blocked: list[T] = listfield()
     pipeline_events_emitted: list[T] = listfield()
@@ -89,6 +95,7 @@ def load_hooks(self) -> HooksLists[t.Callable]:
             message_submitted=self._load_hooks(self.message_submitted),
             message_executed=self._load_hooks(self.message_executed),
             message_published=self._load_hooks(self.message_published),
+            results_processed=self._load_hooks(self.results_processed),
             output_blocked=self._load_hooks(self.output_blocked),
             pipeline_events_emitted=self._load_hooks(self.pipeline_events_emitted),
             work_queue_built=self._load_hooks(self.work_queue_built),
@@ -107,6 +114,7 @@ def _load_hooks(hooks: list[str]) -> list[t.Callable]:
     message_submitted: AsyncEventHook["ExecutableMessage"]
     message_executed: AsyncContextHook["ExecutableMessage", "PipelineResults"]
     message_published: AsyncContextHook["MessagePublished", None]
+    results_processed: AsyncContextHook["ResultsProcessed", None]
     output_blocked: AsyncContextHook["Topic", None]
     pipeline_events_emitted: AsyncEventHook[PipelineEventsEmitted]

@@ -143,6 +151,10 @@ def __init__(self, options: t.Optional[Options] = None) -> None:
             hooks.message_published + [self.on_message_published],
             error_handler=self.hook_failed.emit,
         )
+        self.results_processed = AsyncContextHook(
+            hooks.results_processed + [self.on_results_processed],
+            error_handler=self.hook_failed.emit,
+        )
         self.output_blocked = AsyncContextHook(
             hooks.output_blocked, error_handler=self.hook_failed.emit
         )
@@ -262,6 +274,20 @@ async def on_message_published(
         except Exception as e:
             await context.on_error(g, e)

+    async def on_results_processed(
+        self,
+        msg: ResultsProcessed,
+    ) -> t.AsyncGenerator[None, None]:
+        if not (context := self._msg_context(msg.xmsg, "results_processed")):
+            return
+
+        g = await context.on_call(msg)
+        try:
+            result = yield
+            await context.on_result(g, result)
+        except Exception as e:
+            await context.on_error(g, e)
+
     async def on_pipeline_events_emitted(self, events: PipelineEventsEmitted) -> None:
         await self._on_msg(events.xmsg, "pipeline_events_emitted", events)
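For reference, on_results_processed above follows the same async-generator protocol as the other context hooks: everything before the yield is the "before" phase, the yield suspends while the wrapped scope runs (the scope's return value comes back through the yield), and an exception in the scope is thrown back in at the yield point. A schematic handler under those semantics (illustrative only):

import typing as t

from saturn_engine.worker.services.hooks import ResultsProcessed


async def handler(msg: ResultsProcessed) -> t.AsyncGenerator[None, None]:
    # "before" phase: inspect msg.xmsg / msg.results prior to processing.
    try:
        result = yield  # the wrapped scope runs while the handler is suspended
        # "after" phase: result is the scope's return value (None for this hook).
        assert result is None
    except Exception:
        ...  # the scope raised; handle or re-raise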

src/saturn_engine/worker/services/loggers/logger.py (19 additions, 0 deletions)
@@ -17,6 +17,7 @@
 from saturn_engine.worker.executors.executable import ExecutableQueue
 from saturn_engine.worker.pipeline_message import PipelineMessage
 from saturn_engine.worker.services.hooks import MessagePublished
+from saturn_engine.worker.services.hooks import ResultsProcessed
 from saturn_engine.worker.services.tracing import get_trace_context

 from .. import BaseServices
@@ -77,6 +78,7 @@ async def open(self) -> None:
         self.services.hooks.message_submitted.register(self.on_message_submitted)
         self.services.hooks.message_executed.register(self.on_message_executed)
         self.services.hooks.message_published.register(self.on_message_published)
+        self.services.hooks.results_processed.register(self.on_results_processed)
         self.services.hooks.executor_initialized.register(on_executor_initialized)

     @property
@@ -167,6 +169,22 @@ async def on_message_published(
                 "Failed to publish message", extra={"data": self.published_data(event)}
             )

+    async def on_results_processed(
+        self, event: ResultsProcessed
+    ) -> AsyncGenerator[None, None]:
+        try:
+            yield
+        except Exception:
+            self.message_logger.exception(
+                "Failed to process message results",
+                extra={"data": self.results_processed_data(event)},
+            )
+
+    def results_processed_data(self, event: ResultsProcessed) -> dict[str, t.Any]:
+        return {
+            "from": pipeline_message_data(event.xmsg.message, verbose=self.verbose)
+        } | self.result_data(event.results)
+
     def published_data(self, event: MessagePublished) -> dict[str, t.Any]:
         return {
             "from": pipeline_message_data(event.xmsg.message, verbose=self.verbose)
@@ -176,6 +194,7 @@ def result_data(self, results: PipelineResults) -> dict[str, t.Any]:
         return {
             "output": [self.output_data(o) for o in results.outputs[:10]],
             "output_count": len(results.outputs),
+            "events_count": len(results.events),
             "resources": self.resources_used(results.resources),
         }

tests/worker/services/test_hooks.py (20 additions, 0 deletions)
@@ -21,6 +21,7 @@
 from saturn_engine.worker.services.hooks import ItemsBatch
 from saturn_engine.worker.services.hooks import MessagePublished
 from saturn_engine.worker.services.hooks import PipelineEventsEmitted
+from saturn_engine.worker.services.hooks import ResultsProcessed


 def test_event_hook() -> None:
@@ -492,6 +493,7 @@ async def handler(arg: str) -> t.AsyncGenerator[None, str]:
     mock_message_executed_handler = context_mock("message_executed")
     mock_message_executed_handler_msg = context_mock("message_executed_msg")
     mock_message_published_handler = context_mock("message_published")
+    mock_results_processed_handler = context_mock("results_processed")
     mock_output_blocked_handler = context_mock("output_blocked")
     mock_work_queue_built_handler = context_mock("work_queue_built")

@@ -521,6 +523,7 @@ def hook_options() -> Hooks.Options:
        "message_executed",
        "message_published",
        "output_blocked",
+        "results_processed",
        "work_queue_built",
    ]

@@ -547,6 +550,7 @@ async def test_hooks_service(
     xmsg = executable_maker()
     msg_events = PipelineEventsEmitted(xmsg=xmsg, events=[])
     message_published = MessagePublished(xmsg=xmsg, topic=None, output=None)  # type: ignore
+    results_processed = ResultsProcessed(xmsg=xmsg, results=None)  # type: ignore

     await hooks.hook_failed.emit(error)
     HookMock.hook_failed.assert_awaited_once_with(error)
@@ -580,6 +584,10 @@ async def scope(arg: t.Any) -> t.Any:
     HookMock.message_published.before.assert_awaited_once_with(message_published)
     HookMock.message_published.after.assert_awaited_once_with(message_published)

+    await hooks.results_processed.emit(scope)(results_processed)
+    HookMock.results_processed.before.assert_awaited_once_with(results_processed)
+    HookMock.results_processed.after.assert_awaited_once_with(results_processed)
+
     await hooks.output_blocked.emit(scope)("topic")  # type: ignore
     HookMock.output_blocked.before.assert_awaited_once_with("topic")
     HookMock.output_blocked.after.assert_awaited_once_with("topic")
@@ -667,6 +675,18 @@ async def publish_scope(msg: MessagePublished) -> None:
     HookMock.message_published.before.assert_awaited_once_with(message_published)
     HookMock.message_published.error.assert_awaited_once_with(error)

+    @hooks.results_processed.emit
+    async def results_processed_scope(msg: ResultsProcessed) -> None:
+        await HookMock.results_processed()
+        return None
+
+    results_processed = ResultsProcessed(xmsg=xmsg, results=None)  # type: ignore
+    await results_processed_scope(results_processed)
+
+    HookMock.results_processed.before.assert_awaited_once_with(results_processed)
+    HookMock.results_processed.after.assert_awaited_once_with(None)
+    HookMock.results_processed.assert_awaited_once_with()
+
     @hooks.work_queue_built.emit
     async def work_queue_scope(queue_item: QueueItemWithState) -> ExecutableQueue:
         await HookMock.work_queue_built(queue_item)
tests/worker/services/test_logger.py (1 addition, 0 deletions)
@@ -92,6 +92,7 @@ async def test_logger_message_executed(
         "resources": {FakeResource._typename(): "r1"},
         "pipeline": "tests.worker.services.test_logger.fake_pipeline",
         "result": {
+            "events_count": 0,
             "output_count": 1,
             "output": [
                 {