Skip to content

Commit

Permalink
Add fanin inventory to create a single inventory out of many
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 24, 2023
1 parent d38168f commit 13e6644
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/saturn_engine/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t

import builtins
import collections
import enum
import threading
Expand Down Expand Up @@ -194,3 +195,11 @@ def __setitem__(self, name: str, value: t.Any) -> None:

def assert_never(x: t.NoReturn) -> t.NoReturn:
raise AssertionError("Unhandled type: {}".format(type(x).__name__))


if not (ExceptionGroup := getattr(builtins, "ExceptionGroup", None)): # type: ignore

class ExceptionGroup(Exception):
def __init__(self, msg: str, errors: list[Exception]) -> None:
super().__init__(msg)
self.errors = errors
56 changes: 56 additions & 0 deletions src/saturn_engine/utils/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing as t

import asyncio
import contextlib

import asyncstdlib as alib

from . import ExceptionGroup

T = t.TypeVar("T")

Expand Down Expand Up @@ -58,3 +63,54 @@ async def async_enter(
continue
raise
yield (context, item)


@contextlib.asynccontextmanager
async def scoped_aiters(
*iterators: t.AsyncIterator[T],
) -> t.AsyncIterator[list[t.AsyncIterator[T]]]:
ctx = contextlib.AsyncExitStack()
scoped_aiters = [
await ctx.enter_async_context(alib.scoped_iter(i)) for i in iterators
]
async with ctx:
yield scoped_aiters


async def fanin(*iterators: t.AsyncIterator[T]) -> t.AsyncIterator[T]:
anext_tasks: dict[asyncio.Task, t.AsyncIterator[T]] = {
asyncio.create_task(alib.anext(i), name="fanin.anext"): i for i in iterators
}
errors: list[Exception] = []
while True:
if not anext_tasks:
break

done, _ = await asyncio.wait(
anext_tasks.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
iterator = anext_tasks.pop(task)
if task.cancelled():
continue

e = task.exception()
if e is None:
yield task.result()
elif isinstance(e, StopAsyncIteration):
continue
elif isinstance(e, Exception):
for task in anext_tasks:
if task not in done:
task.cancel()
errors.append(e)
else:
raise e

if not errors:
anext_tasks[
asyncio.create_task(alib.anext(iterator), name="fanin.anext")
] = iterator

if errors:
raise ExceptionGroup("One iterator failed", errors)
46 changes: 46 additions & 0 deletions src/saturn_engine/worker/inventories/fanin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import typing as t

import dataclasses
import json
from contextlib import AsyncExitStack

import asyncstdlib as alib

from saturn_engine.core.api import ComponentDefinition
from saturn_engine.core.types import Cursor
from saturn_engine.utils import iterators
from saturn_engine.worker.inventory import Item
from saturn_engine.worker.inventory import IteratorInventory
from saturn_engine.worker.services import Services
from saturn_engine.worker.work_factory import build_inventory


class FanIn(IteratorInventory):
@dataclasses.dataclass
class Options:
inputs: list[ComponentDefinition]

def __init__(self, options: Options, services: Services, **kwargs: object) -> None:
super().__init__()

self.inputs = {
input_def.name: build_inventory(input_def, services=services)
for input_def in options.inputs
}

async def iterate(self, after: t.Optional[Cursor] = None) -> t.AsyncIterator[Item]:
cursors = json.loads(after) if after else {}

aiters = [
alib.map(lambda m: (k, m), i.iterate(after=cursors.get(k)))
for k, i in self.inputs.items()
]
ctx = AsyncExitStack()
scoped_aiters = [
await ctx.enter_async_context(alib.scoped_iter(i)) for i in aiters
]

async with ctx:
async for name, message in iterators.fanin(*scoped_aiters):
cursors[name] = message.cursor
yield dataclasses.replace(message, cursor=Cursor(json.dumps(cursors)))
2 changes: 1 addition & 1 deletion src/saturn_engine/worker/inventories/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, options: Options, **kwargs: object) -> None:
self.items = options.items

async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]:
begin = int(after) + 1 if after else 0
begin = int(after) + 1 if after is not None else 0
return [
Item(id=MessageId(str(i)), args=args)
for i, args in enumerate(self.items[begin:], start=begin)
Expand Down
42 changes: 42 additions & 0 deletions tests/utils/test_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio

import asyncstdlib as alib
import pytest

from saturn_engine.utils import iterators

Expand Down Expand Up @@ -49,3 +50,44 @@ async def test_flatten() -> None:
flatten_it = iterators.async_flatten(iterator)
items = await alib.list(flatten_it)
assert items == [1, 2, 3, 4, 5]


async def test_fanin() -> None:
async def fast() -> t.AsyncIterator[int]:
for i in range(5):
await asyncio.sleep(1.1)
yield i

async def slow() -> t.AsyncIterator[int]:
for i in range(10, 14):
await asyncio.sleep(2)
yield i

assert await alib.list(iterators.fanin(fast(), slow())) == [
0,
10,
1,
2,
11,
3,
4,
12,
13,
]


async def test_fanin_fails() -> None:
async def nosleep() -> t.AsyncIterator[int]:
for i in range(5):
yield i

async def fail() -> t.AsyncIterator[int]:
yield 100
raise ValueError("Fail")

results = []
with pytest.raises(ExceptionGroup):
async for x in iterators.fanin(nosleep(), fail(), nosleep()):
results.append(x)

assert list(sorted(results)) == [0, 0, 1, 1, 100]
36 changes: 36 additions & 0 deletions tests/worker/inventories/test_fanin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncstdlib as alib
import pytest

from saturn_engine.core.types import Cursor
from saturn_engine.core.types import MessageId
from saturn_engine.worker.inventories.fanin import FanIn
from saturn_engine.worker.inventory import Item


@pytest.mark.asyncio
async def test_fanin_inventory() -> None:
inventory = FanIn.from_options(
{
"inputs": [
{
"name": "a",
"type": "StaticInventory",
"options": {"items": [{"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}]},
},
{
"name": "b",
"type": "StaticInventory",
"options": {"items": [{"n": 4}, {"n": 5}]},
},
],
"batch_size": 10,
},
services=None,
)
messages = await alib.list(inventory.iterate())
assert {m.args["n"] for m in messages} == set(range(6))

messages = await alib.list(inventory.iterate(after=Cursor('{"a": "3", "b": "0"}')))
assert messages == [
Item(id=MessageId("1"), cursor='{"a": "3", "b": "1"}', args={"n": 5})
]

0 comments on commit 13e6644

Please sign in to comment.