Skip to content

Commit

Permalink
Add TopicAdapter inventory that convert a topic into an inventory
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 25, 2023
1 parent d38168f commit 45f7a4e
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/saturn_engine/utils/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
AsyncFNone = t.TypeVar("AsyncFNone", bound=t.Callable[..., Awaitable])


@contextlib.asynccontextmanager
async def opened_acontext(ctx: t.AsyncContextManager, value: T) -> t.AsyncIterator[T]:
async with contextlib.AsyncExitStack() as stack:
stack.push_async_exit(ctx)
yield value


async def aiter2agen(iterator: AsyncIterator[T]) -> AsyncGenerator[T, None]:
"""
Convert an async iterator into an async generator.
Expand Down
46 changes: 46 additions & 0 deletions src/saturn_engine/worker/inventories/topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import typing as t

import dataclasses
from contextlib import AsyncExitStack

from saturn_engine.core.api import ComponentDefinition
from saturn_engine.core.types import Cursor
from saturn_engine.worker.inventory import Item
from saturn_engine.worker.inventory import IteratorInventory
from saturn_engine.worker.services import Services


@dataclasses.dataclass
class ItemAdapter(Item):
_context: AsyncExitStack = None # type: ignore[assignment]

async def __aexit__(self, *exc: t.Any) -> t.Optional[bool]:
return await self._context.__aexit__(*exc)


class TopicAdapter(IteratorInventory):
@dataclasses.dataclass
class Options:
topic: ComponentDefinition

def __init__(self, options: Options, services: Services, **kwargs: object) -> None:
# This import must be done late since work_factory depends on this module.
from saturn_engine.worker.work_factory import build_topic

self.topic = build_topic(options.topic, services=services)

async def iterate(self, after: t.Optional[Cursor] = None) -> t.AsyncIterator[Item]:
async for message_ctx in self.topic.run():
try:
async with AsyncExitStack() as stack:
message = await stack.enter_async_context(message_ctx)
yield ItemAdapter(
id=message.id,
cursor=None,
args=message.args,
tags=message.tags,
metadata=message.metadata,
_context=stack.pop_all(),
)
except Exception:
self.logger.exception("Failed to convert message")
9 changes: 8 additions & 1 deletion src/saturn_engine/worker/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from saturn_engine.core import Cursor
from saturn_engine.core import MessageId
from saturn_engine.utils.log import getLogger
from saturn_engine.utils.options import OptionsSchema

MISSING = object()
Expand Down Expand Up @@ -47,6 +48,12 @@ def __post_init__(self) -> None:
if self.cursor is MISSING:
self.cursor = Cursor(self.id)

async def __aenter__(self) -> "Item":
return self

async def __aexit__(self, *exc: t.Any) -> t.Optional[bool]:
return None


class MaxRetriesError(Exception):
pass
Expand Down Expand Up @@ -99,7 +106,7 @@ async def iterate(self, after: t.Optional[Cursor] = None) -> AsyncIterator[Item]

@cached_property
def logger(self) -> logging.Logger:
return logging.getLogger(__name__ + ".Inventory")
return getLogger(__name__, self)


class IteratorInventory(Inventory):
Expand Down
21 changes: 13 additions & 8 deletions src/saturn_engine/worker/job.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import typing as t

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from saturn_engine.core import MessageId
from saturn_engine.core import TopicMessage
from saturn_engine.core.api import QueueItemWithState
from saturn_engine.utils.log import getLogger
from saturn_engine.worker.inventories import Inventory
from saturn_engine.worker.inventory import Item
from saturn_engine.worker.services import Services
from saturn_engine.worker.services.job_state.service import JobStateService
from saturn_engine.worker.topics import Topic
Expand Down Expand Up @@ -35,14 +37,7 @@ async def run(self) -> AsyncGenerator[TopicOutput, None]:
try:
async for item in self.inventory.iterate(after=cursor):
cursor = item.cursor
message = TopicMessage(
id=MessageId(item.id),
args=item.args,
tags=item.tags,
metadata=item.metadata | {"job": {"cursor": cursor}},
)

yield message
yield self.item_to_topic(item)

if cursor:
self.state_service.set_job_cursor(
Expand All @@ -53,3 +48,13 @@ async def run(self) -> AsyncGenerator[TopicOutput, None]:
except Exception as e:
self.logger.exception("Exception raised from job")
self.state_service.set_job_failed(self.queue_item.name, error=e)

@asynccontextmanager
async def item_to_topic(self, item_ctx: Item) -> t.AsyncIterator[TopicMessage]:
async with item_ctx as item:
yield TopicMessage(
id=MessageId(item.id),
args=item.args,
tags=item.tags,
metadata=item.metadata | {"job": {"cursor": item.cursor}},
)
25 changes: 25 additions & 0 deletions tests/utils/test_asyncutils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import typing as t

import asyncio
from contextlib import AsyncExitStack
from contextlib import asynccontextmanager
from unittest.mock import Mock

import pytest

from saturn_engine.utils.asyncutils import DelayedThrottle
from saturn_engine.utils.asyncutils import opened_acontext


async def test_delayed_task() -> None:
Expand Down Expand Up @@ -157,3 +160,25 @@ async def wait_event(x: int) -> int:
assert (await t1) == 2
assert (await t2) == 2
assert (await t3) == 3


async def test_opened_acontext() -> None:
mock = Mock()

@asynccontextmanager
async def context() -> t.AsyncIterator[int]:
mock("before")
yield 1
mock("after")

stack = AsyncExitStack()
value = await stack.enter_async_context(context())
assert value == 1
mock.assert_called_once_with("before")
mock.reset_mock()

async with opened_acontext(stack, value) as opened_value:
assert opened_value == 1
mock.assert_not_called()

mock.assert_called_once_with("after")
51 changes: 51 additions & 0 deletions tests/worker/inventories/test_topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import typing as t

from contextlib import asynccontextmanager

import asyncstdlib as alib
import pytest

from saturn_engine.core.topic import TopicMessage
from saturn_engine.utils.inspect import get_import_name
from saturn_engine.worker.inventories.topic import TopicAdapter
from saturn_engine.worker.topic import Topic
from saturn_engine.worker.topic import TopicOutput


class FakeTopic(Topic):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
self.processing = 0
self.processed = 0

async def run(self) -> t.AsyncGenerator[TopicOutput, None]:
for x in range(10):
yield self.context(TopicMessage(args={"x": x}))

@asynccontextmanager
async def context(self, message: TopicMessage) -> t.AsyncIterator[TopicMessage]:
self.processing += 1
yield message
self.processed += 1


@pytest.mark.asyncio
async def test_static_inventory() -> None:
inventory = TopicAdapter.from_options(
{
"topic": {
"name": "topic",
"type": get_import_name(FakeTopic),
},
},
services=None,
)
topic: FakeTopic = t.cast(FakeTopic, inventory.topic)

messages = await alib.list(inventory.iterate())

assert topic.processing == 10
for i, ctx in enumerate(messages):
async with ctx as message:
assert message.args["x"] == i
assert message.cursor is None
assert topic.processed == i

0 comments on commit 45f7a4e

Please sign in to comment.