Skip to content

Commit

Permalink
Move inventory cursor only once an item have been successfuly processed
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Aug 3, 2023
1 parent 32d0063 commit cd2028c
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/saturn_engine/utils/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ async def close(self, timeout: t.Optional[float] = None) -> None:
if not task.done():
task.cancel()

if not self.tasks:
return

# Collect results to log errors.
done, pending = await asyncio.wait(self.tasks, timeout=timeout)
for task in done:
Expand Down
10 changes: 1 addition & 9 deletions src/saturn_engine/worker/inventories/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@
from saturn_engine.worker.services import Services


@dataclasses.dataclass
class ItemAdapter(Item):
_context: AsyncExitStack = None # type: ignore[assignment]

async def __aexit__(self, *exc: t.Any) -> t.Optional[bool]:
return await self._context.__aexit__(*exc)


class TopicAdapter(IteratorInventory):
@dataclasses.dataclass
class Options:
Expand All @@ -34,7 +26,7 @@ async def iterate(self, after: t.Optional[Cursor] = None) -> t.AsyncIterator[Ite
try:
async with AsyncExitStack() as stack:
message = await stack.enter_async_context(message_ctx)
yield ItemAdapter(
yield Item(
id=message.id,
cursor=None,
args=message.args,
Expand Down
8 changes: 7 additions & 1 deletion src/saturn_engine/worker/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import uuid
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack
from datetime import timedelta
from functools import cached_property

Expand All @@ -28,6 +29,9 @@ class Item:
cursor: t.Optional[Cursor] = MISSING # type: ignore[assignment]
tags: dict[str, str] = dataclasses.field(default_factory=dict)
metadata: dict[str, t.Any] = dataclasses.field(default_factory=dict)
_context: AsyncExitStack = dataclasses.field(
default_factory=AsyncExitStack, compare=False
)

# Hack to allow building object with `str` instead of new types `MessageId`
# and `Cursor`.
Expand All @@ -41,6 +45,7 @@ def __init__(
cursor: t.Optional[str] = None,
tags: dict[str, str] = None, # type: ignore[assignment]
metadata: dict[str, t.Any] = None, # type: ignore[assignment]
_context: AsyncExitStack = None, # type: ignore[assignment]
) -> None:
...

Expand All @@ -49,10 +54,11 @@ def __post_init__(self) -> None:
self.cursor = Cursor(self.id)

async def __aenter__(self) -> "Item":
await self._context.__aenter__()
return self

async def __aexit__(self, *exc: t.Any) -> t.Optional[bool]:
return None
return await self._context.__aexit__(*exc)


class MaxRetriesError(Exception):
Expand Down
52 changes: 41 additions & 11 deletions src/saturn_engine/worker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
self.services = services
self.queue_item = queue_item
self.state_service = services.cast_service(JobStateService)
self._pendings: dict[int, tuple[bool, Item]] = {}

async def run(self) -> AsyncGenerator[TopicOutput, None]:
cursor = self.queue_item.state.cursor
Expand All @@ -39,22 +40,51 @@ async def run(self) -> AsyncGenerator[TopicOutput, None]:
cursor = item.cursor
yield self.item_to_topic(item)

if cursor:
self.state_service.set_job_cursor(
self.queue_item.name, cursor=cursor
)

self.state_service.set_job_completed(self.queue_item.name)
except Exception as e:
self.logger.exception("Exception raised from job")
self.state_service.set_job_failed(self.queue_item.name, error=e)

@asynccontextmanager
async def item_to_topic(self, item_ctx: Item) -> t.AsyncIterator[TopicMessage]:
async with item_ctx as item:
yield TopicMessage(
id=MessageId(item.id),
args=item.args,
tags=item.tags,
metadata=item.metadata | {"job": {"cursor": item.cursor}},
self._set_item_pending(item_ctx)
try:
async with item_ctx as item:
yield TopicMessage(
id=MessageId(item.id),
args=item.args,
tags=item.tags,
metadata=item.metadata | {"job": {"cursor": item.cursor}},
)
except Exception:
self.logger.exception(
"Failed to process item",
extra={"data": {"message": {"id": item_ctx.id}}},
)
finally:
self._set_item_done(item_ctx)

def _set_item_pending(self, item: Item) -> None:
self._pendings[id(item)] = (True, item)

def _set_item_done(self, item: Item) -> None:
self._pendings[id(item)] = (False, item)

# Collect the serie of done item from the beginning.
items_done = []
for pending, item in self._pendings.values():
if pending:
break
items_done.append(item)

# Remove all done item from the pendings
for item in items_done:
del self._pendings[id(item)]

# Commit last cursor.
for item in reversed(items_done):
if item.cursor:
self.state_service.set_job_cursor(
self.queue_item.name, cursor=item.cursor
)
break
3 changes: 3 additions & 0 deletions src/saturn_engine/worker/services/job_state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ def flush(self) -> t.Iterator[JobsStates]:
except BaseException:
self._current_state = flushing_state.merge(self._current_state)
raise

def job_state(self, job_name: JobId) -> JobState:
return self._current_state.jobs[job_name]
139 changes: 139 additions & 0 deletions tests/worker/test_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import typing as t

import dataclasses
from contextlib import AsyncExitStack

import asyncstdlib as alib
import pytest

from saturn_engine.core.api import QueueItemWithState
from saturn_engine.worker.executors.executable import ExecutableQueue
from saturn_engine.worker.inventory import Cursor
from saturn_engine.worker.inventory import Inventory
from saturn_engine.worker.inventory import Item
from saturn_engine.worker.job import Job
from saturn_engine.worker.services.job_state.service import JobStateService
from saturn_engine.worker.services.manager import ServicesManager


class FakeInventory(Inventory):
name = "fake_inventory"

@dataclasses.dataclass
class Options:
items: list[Item]

def __init__(self, *args: t.Any, options: Options, **kwargs: t.Any) -> None:
self.options = options

async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]:
raise NotImplementedError()

async def iterate(self, after: t.Optional[Cursor] = None) -> t.AsyncIterator[Item]:
for item in self.options.items:
yield item


@pytest.mark.asyncio
async def test_inventory_set_cursor(
services_manager: ServicesManager,
fake_queue_item: QueueItemWithState,
executable_queue_maker: t.Callable[..., ExecutableQueue],
) -> None:
inventory = FakeInventory(
options=FakeInventory.Options(
items=[
Item(cursor=Cursor("1"), args={"x": 1}),
Item(cursor=Cursor("2"), args={"x": 1}),
]
)
)
job_id = fake_queue_item.name
job_state_store = services_manager.services.cast_service(JobStateService)._store
job = Job(
inventory=inventory,
queue_item=fake_queue_item,
services=services_manager.services,
)
xqueue = executable_queue_maker(definition=fake_queue_item, topic=job)

async for xmsg in xqueue.run():
async with xmsg._context:
pass

assert job_state_store.job_state(job_id).cursor == "2"


@pytest.mark.asyncio
async def test_inventory_set_cursor_after_completed(
services_manager: ServicesManager,
fake_queue_item: QueueItemWithState,
executable_queue_maker: t.Callable[..., ExecutableQueue],
) -> None:
def fail() -> None:
raise ValueError()

failing_stack = AsyncExitStack()
failing_stack.callback(fail)
inventory = FakeInventory(
options=FakeInventory.Options(
items=[
Item(cursor=Cursor("0"), args={"x": 1}),
Item(cursor=None, args={"x": 1}),
Item(cursor=Cursor("2"), args={"x": 1}, _context=failing_stack),
Item(cursor=None, args={"x": 1}),
Item(cursor=Cursor("4"), args={"x": 1}),
Item(cursor=Cursor("5"), args={"x": 1}),
Item(cursor=Cursor("6"), args={"x": 1}),
]
)
)
job_id = fake_queue_item.name
job_state_store = services_manager.services.cast_service(JobStateService)._store
job = Job(
inventory=inventory,
queue_item=fake_queue_item,
services=services_manager.services,
)
xqueue = executable_queue_maker(definition=fake_queue_item, topic=job)

xmsg_ctxs: list[AsyncExitStack] = []
async with alib.scoped_iter(xqueue.run()) as xrun:
async for xmsg in alib.islice(xrun, 7):
async with AsyncExitStack() as stack:
await stack.enter_async_context(xmsg._context)
xmsg_ctxs.append(stack.pop_all())

assert job_state_store.job_state(job_id).cursor is None
assert len(xmsg_ctxs) == 7

# .: Pending, R: Ready
# |0|1|2|3|4|5|6|
# -> |.|.|R|.|.|R|.|
# Nothing commited.
await xmsg_ctxs[2].aclose()
await xmsg_ctxs[5].aclose()
assert job_state_store.job_state(job_id).cursor is None

# .: Pending, R: Ready
# |0|1|2|3|4|5|6|
# -> |C|.|R|R|.|R|.|
# Message 0 is commited.
await xmsg_ctxs[3].aclose()
await xmsg_ctxs[0].aclose()
assert job_state_store.job_state(job_id).cursor == "0"

# .: Pending, R: Ready
# |0|1|2|3|4|5|6|
# -> |C|R|C|R|.|R|.|
# Message 2 is commited (Message 3 has no cursor)
await xmsg_ctxs[1].aclose()
assert job_state_store.job_state(job_id).cursor == "2"

# .: Pending, R: Ready
# |0|1|2|3|4|5|6|
# -> |C|R|C|R|R|R|C|
# Message 6 is commited
await xmsg_ctxs[6].aclose()
await xmsg_ctxs[4].aclose()
assert job_state_store.job_state(job_id).cursor == "6"

0 comments on commit cd2028c

Please sign in to comment.