Skip to content

Commit

Permalink
service(State): Allow to override the updated state cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed May 28, 2024
1 parent 7353f14 commit f1b2e4a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/saturn_engine/core/job_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import dataclasses

from saturn_engine.core import Cursor
from saturn_engine.core.pipeline import PipelineEvent


@dataclasses.dataclass
class CursorStateUpdated(PipelineEvent):
state: dict
cursor: Cursor | None = None


class CursorState:
Expand Down
4 changes: 3 additions & 1 deletion src/saturn_engine/worker/services/job_state/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ async def on_pipeline_events_emitted(self, pevents: PipelineEventsEmitted) -> No
continue

message = pevents.xmsg.message.message
cursor = message.metadata.get("job_state", {}).get("state_cursor")
cursor = event.cursor or message.metadata.get("job_state", {}).get(
"state_cursor"
)
if not cursor:
continue

Expand Down
20 changes: 19 additions & 1 deletion tests/worker/services/job_state/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,22 @@ async def fake_work_builder(queue: QueueItemWithState) -> t.Any:
async with xmsg._context:
xmsgs.append(xmsg)
results.append(xmsg.message.execute())

# Emit a new state.
await services_manager.services.s.hooks.pipeline_events_emitted.emit(
PipelineEventsEmitted(
xmsg=xmsg, events=[CursorStateUpdated(state={"x": len(xmsgs) * 10})]
)
)

# Emit a new state with a custom cursor.
await services_manager.services.s.hooks.pipeline_events_emitted.emit(
PipelineEventsEmitted(
xmsg=xmsg,
events=[CursorStateUpdated(cursor=Cursor("42"), state={"x": 42})],
)
)

# Assert everything loaded and ran in order.
msgs = [i.message.message for i in xmsgs]
assert [i.id for i in msgs] == ["0", "1", "2", "3", "4"]
Expand Down Expand Up @@ -287,7 +297,14 @@ async def fake_work_builder(queue: QueueItemWithState) -> t.Any:
# Rerun the inventory, expect new states to be loaded
inventory = FakeInventory(
options=FakeInventory.Options(
data=[5, 4, 3, 2, (1, {"job_state": {"state_cursor": "a"}})]
data=[
5,
4,
3,
2,
(1, {"job_state": {"state_cursor": "a"}}),
(42, {"job_state": {"state_cursor": "42"}}),
]
)
)
job = Job(
Expand All @@ -310,6 +327,7 @@ async def fake_work_builder(queue: QueueItemWithState) -> t.Any:
{"x": 30},
{"x": 40},
{"x": 50},
{"x": 42},
]


Expand Down

0 comments on commit f1b2e4a

Please sign in to comment.