Skip to content

Commit

Permalink
JobState: Add method to fetch cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 7, 2023
1 parent f48a053 commit bc9ea93
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/saturn_engine/client/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def sync(self, sync: JobsStatesSyncInput) -> JobsStatesSyncResponse:
async def fetch_cursors_states(
self, cursors: FetchCursorsStatesInput
) -> FetchCursorsStatesResponse:
state_url = urlcat(self.base_url, "api/jobs/_states/cursors")
state_url = urlcat(self.base_url, "api/jobs/_states/fetch")
json = asdict(cursors)
async with self.http_client.put(state_url, json=json) as response:
async with self.http_client.post(state_url, json=json) as response:
return fromdict(await response.json(), FetchCursorsStatesResponse)
35 changes: 29 additions & 6 deletions src/saturn_engine/utils/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
K = t.TypeVar("K")
V = t.TypeVar("V")

AsyncFNone = t.TypeVar("AsyncFNone", bound=t.Callable[..., Awaitable[None]])
AsyncFNone = t.TypeVar("AsyncFNone", bound=t.Callable[..., Awaitable])


async def aiter2agen(iterator: AsyncIterator[T]) -> AsyncGenerator[T, None]:
Expand Down Expand Up @@ -168,6 +168,8 @@ def __init__(self, func: AsyncFNone, *, delay: float) -> None:
tuple[tuple[t.Any, ...], dict[str, t.Any]]
] = None

self._call_fut: t.Optional[asyncio.Future] = None

@property
def is_idle(self) -> bool:
return self.delayed_task is None
Expand All @@ -180,32 +182,53 @@ def is_waiting(self) -> bool:
def is_running(self) -> bool:
return not self.is_idle and not self.is_waiting

def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
def call_nowait(self, *args: t.Any, **kwargs: t.Any) -> None:
self.delayed_params = (args, kwargs)

if self.is_idle:
name = f"{self.func.__qualname__}.delayed"
self.delayed_task = asyncio.create_task(self._delay_call(), name=name)

def __call__(self, *args: t.Any, **kwargs: t.Any) -> asyncio.Future:
if self._call_fut is None:
self._call_fut = asyncio.Future()

self.call_nowait(*args, **kwargs)
return self._call_fut

async def _delay_call(self) -> None:
call_fut = None
try:
self.flush_event.clear()
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.flush_event.wait(), timeout=self.delay)
try:
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.flush_event.wait(), timeout=self.delay)
finally:
# Ensure we set call_fut, even if the wait_for raised an exception.
call_fut = self._call_fut

if not self.delayed_params:
if call_fut:
call_fut.cancel()
return

args, kwargs = self.delayed_params
self.delayed_params = None
await self.func(*args, **kwargs)
self._call_fut = None
result = await self.func(*args, **kwargs)
if call_fut:
call_fut.set_result(result)
except BaseException as e:
if call_fut:
call_fut.set_exception(e)
raise
finally:
self.delayed_task = None

# If __call__ was called while we were calling func, we requeue a new task.
if self.delayed_params:
args, kwargs = self.delayed_params
self.__call__(*args, **kwargs)
self.call_nowait(*args, **kwargs)

async def cancel(self) -> None:
self.delayed_params = None
Expand Down
40 changes: 39 additions & 1 deletion src/saturn_engine/worker/services/job_state/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@

import dataclasses
from collections import defaultdict

from saturn_engine.client.worker_manager import WorkerManagerClient
from saturn_engine.core import Cursor
from saturn_engine.core import JobId
from saturn_engine.core.api import FetchCursorsStatesInput
from saturn_engine.core.api import FetchCursorsStatesResponse
from saturn_engine.core.api import JobsStatesSyncInput
from saturn_engine.utils.asyncutils import DelayedThrottle

Expand All @@ -22,6 +27,31 @@ class Options:
auto_flush: bool = True


class CursorsStatesFetcher:
def __init__(self, *, client: WorkerManagerClient, fetch_delay: float = 0) -> None:
self.client = client
self.pending_queries: dict[JobId, set[Cursor]] = defaultdict(set)
self._delayed_fetch = DelayedThrottle(self._do_fetch, delay=fetch_delay)

async def _wait_fetch(self) -> None:
pass

async def _do_fetch(self) -> FetchCursorsStatesResponse:
queries = self.pending_queries
self.pending_queries = {}
cursors = {k: list(v) for k, v in queries.items()}
return await self.client.fetch_cursors_states(
FetchCursorsStatesInput(cursors=cursors)
)

async def fetch(
self, job_name: JobId, *, cursors: list[Cursor]
) -> dict[Cursor, dict]:
self.pending_queries[job_name].update(cursors)
result = await self._delayed_fetch()
return result.cursors.get(job_name, {})


class JobStateService(Service[Services, Options]):
name = "job_state"

Expand All @@ -33,6 +63,9 @@ class JobStateService(Service[Services, Options]):

async def open(self) -> None:
self._store = JobsStatesSyncStore()
self._cursors_fetcher = CursorsStatesFetcher(
client=self.services.api_client.client
)
self._delayed_flush = DelayedThrottle(
self.flush, delay=self.options.flush_delay
)
Expand Down Expand Up @@ -61,9 +94,14 @@ def set_job_cursor_state(
)
self._maybe_flush()

async def fetch_cursors_states(
self, job_name: JobId, *, cursors: list[Cursor]
) -> dict[Cursor, dict]:
return await self._cursors_fetcher.fetch(job_name, cursors=cursors)

def _maybe_flush(self) -> None:
if self.options.auto_flush:
self._delayed_flush()
self._delayed_flush.call_nowait()

async def flush(self) -> None:
with self._store.flush() as state:
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,11 @@ def register_hooks_handler(services: Services) -> AsyncMock:
async_context_mock_handler(_hooks_handler.message_published)
)
return _hooks_handler


class EqualAnyOrder:
def __init__(self, expected: t.Iterable):
self.expected = expected

def __eq__(self, other: t.Any) -> bool:
return list(sorted(self.expected)) == list(sorted(other))
63 changes: 53 additions & 10 deletions tests/utils/test_asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import asyncio
from unittest.mock import Mock

import pytest

from saturn_engine.utils.asyncutils import DelayedThrottle


Expand All @@ -22,9 +24,9 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:

# The call is delayed and done with the latest parameters.
func_wait.set()
delayed_func(1, a="b")
delayed_func.call_nowait(1, a="b")
await asyncio.sleep(4)
delayed_func(2, a="b")
delayed_func.call_nowait(2, a="b")
func_mock.assert_not_called()

await asyncio.sleep(2)
Expand All @@ -37,7 +39,7 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:
func_wait.clear()
func_mock.reset_mock()

delayed_func()
delayed_func.call_nowait()
await asyncio.sleep(4)
await delayed_func.cancel()
func_mock.assert_not_called()
Expand All @@ -48,7 +50,7 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:
func_wait.clear()
func_mock.reset_mock()

delayed_func()
delayed_func.call_nowait()
await asyncio.sleep(6)
func_mock.assert_called_once()
func_mock.reset_mock()
Expand All @@ -66,7 +68,7 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:
await delayed_func.flush()
func_mock.assert_not_called()

delayed_func()
delayed_func.call_nowait()
await delayed_func.flush()
func_mock.assert_called_once()
func_mock.reset_mock()
Expand All @@ -81,11 +83,11 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:
func_wait.clear()
func_mock.reset_mock()

delayed_func(1)
delayed_func.call_nowait(1)
await asyncio.sleep(6)
func_mock.assert_called_once_with(1)
func_mock.reset_mock()
delayed_func(2)
delayed_func.call_nowait(2)
func_wait.set()

await asyncio.sleep(6)
Expand All @@ -99,18 +101,59 @@ async def func(*args: t.Any, **kwargs: t.Any) -> None:
func_wait.clear()
func_mock.reset_mock()

delayed_func(1)
delayed_func.call_nowait(1)
await asyncio.sleep(6)
func_mock.assert_called_once_with(1)
func_mock.reset_mock()
delayed_func(2)
delayed_func.call_nowait(2)

await delayed_func.cancel()
await asyncio.sleep(6)
assert isinstance(func_mock.call_args.kwargs["error"], asyncio.CancelledError)
func_mock.reset_mock()
func_wait.set()

delayed_func(3)
delayed_func.call_nowait(3)
await asyncio.sleep(6)
func_mock.assert_called_once_with(3)


async def test_delayed_throttle_wait() -> None:
# We can await a call to get the latest call result.
async def return_arg(x: int) -> int:
return x

delayed_func = DelayedThrottle(return_arg, delay=5)

t1 = delayed_func(1)
await asyncio.sleep(1)
delayed_func.call_nowait(2)
assert (await t1) == 2

# Call getting cancelled while waiting see the exception.
t1 = delayed_func(1)
await asyncio.sleep(1)
await delayed_func.cancel()
with pytest.raises(asyncio.CancelledError):
await t1

# A call while another call is in progress will have its own future.
wait_func = asyncio.Event()
resume_func = asyncio.Event()

async def wait_event(x: int) -> int:
wait_func.set()
await resume_func.wait()
return x

delayed_func = DelayedThrottle(wait_event, delay=5)

t1 = delayed_func(1)
t2 = delayed_func(2)
await wait_func.wait()
t3 = delayed_func(3)
resume_func.set()

assert (await t1) == 2
assert (await t2) == 2
assert (await t3) == 3
44 changes: 44 additions & 0 deletions tests/worker/services/job_state/test_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing as t

import asyncio

import pytest

from saturn_engine.core import Cursor
Expand All @@ -8,6 +10,7 @@
from saturn_engine.worker.services.job_state.service import JobStateService
from saturn_engine.worker.services.manager import ServicesManager
from tests.conftest import FreezeTime
from tests.utils import EqualAnyOrder
from tests.utils import HttpClientMock


Expand Down Expand Up @@ -96,3 +99,44 @@ async def test_job_state_update(
http_client_mock.reset_mock()
await job_state_service.flush()
http_client_mock.put("http://127.0.0.1:5000/api/jobs/_states").assert_not_called()


async def test_job_state_fetch_cursors(
http_client_mock: HttpClientMock,
job_state_service: JobStateService,
) -> None:
http_client_mock.post(
"http://127.0.0.1:5000/api/jobs/_states/fetch"
).return_value = {
"cursors": {
"job-1": {
"a": {"x": 1},
"b": {"x": 2},
},
"job-2": {
"c": None,
},
}
}

fetch_1 = job_state_service.fetch_cursors_states(
JobId("job-1"), cursors=[Cursor("a"), Cursor("b")]
)
fetch_2 = job_state_service.fetch_cursors_states(
JobId("job-2"), cursors=[Cursor("c")]
)
states_1, states_2 = await asyncio.gather(fetch_1, fetch_2)

assert states_1 == {"a": {"x": 1}, "b": {"x": 2}}
assert states_2 == {"c": None}

http_client_mock.post(
"http://127.0.0.1:5000/api/jobs/_states/fetch"
).assert_called_once_with(
json={
"cursors": {
"job-1": EqualAnyOrder(["a", "b"]),
"job-2": ["c"],
}
}
)

0 comments on commit bc9ea93

Please sign in to comment.