-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add fanin inventory to create a single inventory out of many
- Loading branch information
Showing
6 changed files
with
190 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import typing as t | ||
|
||
import dataclasses | ||
import json | ||
from contextlib import AsyncExitStack | ||
|
||
import asyncstdlib as alib | ||
|
||
from saturn_engine.core.api import ComponentDefinition | ||
from saturn_engine.core.types import Cursor | ||
from saturn_engine.utils import iterators | ||
from saturn_engine.worker.inventory import Item | ||
from saturn_engine.worker.inventory import IteratorInventory | ||
from saturn_engine.worker.services import Services | ||
from saturn_engine.worker.work_factory import build_inventory | ||
|
||
|
||
class FanIn(IteratorInventory): | ||
@dataclasses.dataclass | ||
class Options: | ||
inputs: list[ComponentDefinition] | ||
|
||
def __init__(self, options: Options, services: Services, **kwargs: object) -> None: | ||
super().__init__() | ||
|
||
self.inputs = { | ||
input_def.name: build_inventory(input_def, services=services) | ||
for input_def in options.inputs | ||
} | ||
|
||
async def iterate(self, after: t.Optional[Cursor] = None) -> t.AsyncIterator[Item]: | ||
cursors = json.loads(after) if after else {} | ||
|
||
aiters = [ | ||
alib.map(lambda m: (k, m), i.iterate(after=cursors.get(k))) | ||
for k, i in self.inputs.items() | ||
] | ||
ctx = AsyncExitStack() | ||
scoped_aiters = [ | ||
await ctx.enter_async_context(alib.scoped_iter(i)) for i in aiters | ||
] | ||
|
||
async with ctx: | ||
async for name, message in iterators.fanin(*scoped_aiters): | ||
cursors[name] = message.cursor | ||
yield dataclasses.replace(message, cursor=Cursor(json.dumps(cursors))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import asyncstdlib as alib | ||
import pytest | ||
|
||
from saturn_engine.core.types import Cursor | ||
from saturn_engine.core.types import MessageId | ||
from saturn_engine.worker.inventories.fanin import FanIn | ||
from saturn_engine.worker.inventory import Item | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_fanin_inventory() -> None: | ||
inventory = FanIn.from_options( | ||
{ | ||
"inputs": [ | ||
{ | ||
"name": "a", | ||
"type": "StaticInventory", | ||
"options": {"items": [{"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}]}, | ||
}, | ||
{ | ||
"name": "b", | ||
"type": "StaticInventory", | ||
"options": {"items": [{"n": 4}, {"n": 5}]}, | ||
}, | ||
], | ||
"batch_size": 10, | ||
}, | ||
services=None, | ||
) | ||
messages = await alib.list(inventory.iterate()) | ||
assert {m.args["n"] for m in messages} == set(range(6)) | ||
|
||
messages = await alib.list(inventory.iterate(after=Cursor('{"a": "3", "b": "0"}'))) | ||
assert messages == [ | ||
Item(id=MessageId("1"), cursor='{"a": "3", "b": "1"}', args={"n": 5}) | ||
] |