Skip to content

Commit

Permalink
scheduler: Fix cleanup failing due to improper generator handling
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 5, 2023
1 parent 32a75da commit 52bb67f
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 56 deletions.
26 changes: 20 additions & 6 deletions src/saturn_engine/utils/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ async def wait(self) -> set[asyncio.Task]:

return done

async def wait_all(self) -> set[asyncio.Task]:
if not self.tasks:
return set()
done, _ = await asyncio.wait(self.tasks)
self.tasks.difference_update(done)
return done

async def close(self, timeout: t.Optional[float] = None) -> None:
# Cancel the update event task.
self.updated_task.cancel()
Expand Down Expand Up @@ -120,25 +127,32 @@ def stop(self) -> None:
self.is_running = False
self.notify()

async def close(self, timeout: t.Optional[float] = None) -> None:
async def close(
self, timeout: t.Optional[float] = None, *, wait_all: bool = False
) -> None:
if self.is_running:
# Stop the runner.
self.stop()
# Wait for the running task to complete.
if self._runner_task:
await self._runner_task

if wait_all:
tasks = await asyncio.wait_for(self.wait_all(), timeout=timeout)
self._log_tasks(tasks)

# Clean the tasks.
await super().close(timeout=timeout)

async def run(self) -> None:
while self.is_running:
done = await self.wait()
for task in done:
if not task.cancelled() and isinstance(task.exception(), Exception):
self.logger.error(
"Task '%s' failed", task, exc_info=task.exception()
)
self._log_tasks(done)

def _log_tasks(self, tasks: set[asyncio.Task]) -> None:
for task in tasks:
if not task.cancelled() and isinstance(task.exception(), Exception):
self.logger.error("Task '%s' failed", task, exc_info=task.exception())


class DelayedThrottle(t.Generic[AsyncFNone]):
Expand Down
2 changes: 1 addition & 1 deletion src/saturn_engine/worker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def run_sync(self) -> None:
for resource in work_sync.resources.drop:
await self.resources_manager.remove(resource.key)
for queue in work_sync.queues.drop:
await self.executors_manager.remove_queue(queue)
self.executors_manager.remove_queue(queue)
for executor in work_sync.executors.drop:
await self.executors_manager.remove_executor(executor)

Expand Down
49 changes: 37 additions & 12 deletions src/saturn_engine/worker/executors/executable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import AsyncContextManager
from typing import Optional
import typing as t

import asyncio
import contextlib
Expand All @@ -26,7 +25,7 @@ def __init__(
message: PipelineMessage,
parker: Parkers,
output: dict[str, list[Topic]],
message_context: Optional[AsyncContextManager] = None,
message_context: t.Optional[t.AsyncContextManager] = None,
):
self.message = message
self._context = contextlib.AsyncExitStack()
Expand Down Expand Up @@ -103,7 +102,7 @@ async def run(self) -> AsyncGenerator[ExecutableMessage, None]:
try:
async for message in self.topic.run():
context = None
if isinstance(message, AsyncContextManager):
if isinstance(message, t.AsyncContextManager):
context = message
message = await message.__aenter__()

Expand All @@ -120,8 +119,9 @@ async def run(self) -> AsyncGenerator[ExecutableMessage, None]:
)
await self.services.s.hooks.message_polled.emit(executable_message)
await self.parkers.wait()
executable_message._context.enter_context(self.pending_context())
yield executable_message
with self.pending_context() as message_context:
executable_message._context.enter_context(message_context())
yield executable_message
finally:
await self.close()

Expand All @@ -144,14 +144,39 @@ async def wait_for_done(self) -> None:
await self.done.wait()

@contextlib.contextmanager
def pending_context(self) -> Iterator[None]:
def pending_context(self) -> Iterator[t.Callable[[], t.ContextManager]]:
# Yield a new contextmanager to be attached to the message.
# This allow tracking when all message from this job have been processed
# so we can properly clean the jobs resources afterward. We need to yield
# the context from a context so that if the job's yield has an exception
# in the case of cancellation we won't to mark the pending message as not
# being pending anymore.
self.pending_messages_count += 1
processed = False

@contextlib.contextmanager
def message_context() -> Iterator[None]:
nonlocal processed

try:
yield
finally:
if not processed:
self.message_processed()
processed = True

try:
yield
finally:
self.pending_messages_count -= 1
if self.is_closed and self.pending_messages_count == 0:
self.done.set()
yield message_context
except BaseException:
if not processed:
self.message_processed()
processed = True
raise

def message_processed(self) -> None:
self.pending_messages_count -= 1
if self.is_closed and self.pending_messages_count == 0:
self.done.set()

@cached_property
def config(self) -> LazyConfig:
Expand Down
18 changes: 13 additions & 5 deletions src/saturn_engine/worker/executors/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from saturn_engine.core import api
from saturn_engine.utils.asyncutils import TasksGroupRunner
from saturn_engine.utils.log import getLogger
from saturn_engine.worker.services import Services

from . import Executor
Expand All @@ -21,27 +22,30 @@ def __init__(
self.services = services
self.executors: dict[str, ExecutorWorker] = {}
self.executors_tasks_group = TasksGroupRunner(name="executors")
self.logger = getLogger(__name__, self)

def start(self) -> None:
self.executors_tasks_group.start()

async def close(self) -> None:
self.logger.debug("Closing executors")
await asyncio.gather(
*[executor.close() for executor in self.executors.values()]
)
await self.executors_tasks_group.close()
self.logger.debug("Stopping executors tasks")
await self.executors_tasks_group.close(wait_all=True)

def add_queue(self, queue: ExecutableQueue) -> None:
executor = self.executors.get(queue.executor)
if not executor:
raise ValueError("Executor missing")
executor.add_schedulable(queue)

async def remove_queue(self, queue: ExecutableQueue) -> None:
def remove_queue(self, queue: ExecutableQueue) -> None:
executor = self.executors.get(queue.executor)
if not executor:
return
await executor.remove_schedulable(queue)
executor.remove_schedulable(queue)

def add_executor(self, executor_definition: api.ComponentDefinition) -> None:
if executor_definition.name in self.executors:
Expand Down Expand Up @@ -79,6 +83,7 @@ def __init__(
services=services,
)
self.scheduler: Scheduler[ExecutableMessage] = Scheduler()
self.logger = getLogger(__name__, self)

@classmethod
def from_item(
Expand All @@ -104,13 +109,16 @@ async def run(self) -> None:
async for message in self.scheduler.run():
await self.services.s.hooks.message_scheduled.emit(message)
await self.executor_queue.submit(message)
self.logger.debug("Executor worker done")

async def close(self) -> None:
self.logger.debug("Closing scheduler")
await self.scheduler.close()
self.logger.debug("Closing executor queue")
await self.executor_queue.close()

def add_schedulable(self, schedulable: ExecutableQueue) -> None:
self.scheduler.add(schedulable)

async def remove_schedulable(self, schedulable: ExecutableQueue) -> None:
await self.scheduler.remove(schedulable)
def remove_schedulable(self, schedulable: ExecutableQueue) -> None:
self.scheduler.remove(schedulable)
55 changes: 27 additions & 28 deletions src/saturn_engine/worker/executors/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing as t

import asyncio
import contextlib
import dataclasses
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
Expand All @@ -23,6 +22,7 @@ class ScheduleSlot(t.Generic[T]):
generator: AsyncGenerator[T, None]
task: asyncio.Task
order: int = 0
is_running: bool = True


class Scheduler(t.Generic[T]):
Expand All @@ -34,6 +34,7 @@ def __init__(self) -> None:
self.schedule_slots = {}
self.tasks = {}
self.tasks_group = TasksGroup()
self.is_running = False

def add(self, item: Schedulable[T]) -> None:
generator = t.cast(AsyncGenerator[T, None], item.iterable.__aiter__())
Expand All @@ -44,39 +45,33 @@ def add(self, item: Schedulable[T]) -> None:
self.tasks[task] = item
self.tasks_group.add(task)

async def remove(self, item: Schedulable[T]) -> None:
schedule_slot = self.schedule_slots.pop(item, None)
if schedule_slot is None:
return
await self.stop_slot(schedule_slot)
def remove(self, item: Schedulable[T]) -> None:
schedule_slot = self.schedule_slots.get(item)
if schedule_slot:
self.stop_slot(schedule_slot)

async def close(self) -> None:
cleanup_tasks = []
self.is_running = False

await self.tasks_group.close()
for item in self.schedule_slots.values():
cleanup_tasks.append(self.stop_slot(item))
self.stop_slot(item)

await asyncio.gather(*cleanup_tasks)
self.schedule_slots.clear()
self.tasks.clear()

async def stop_slot(self, schedule_slot: ScheduleSlot[T]) -> None:
def stop_slot(self, schedule_slot: ScheduleSlot[T]) -> None:
schedule_slot.is_running = False
if not schedule_slot.task.done():
try:
schedule_slot.task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await schedule_slot.task
except Exception:
self.logger.exception("Failed to cancel item: %s", schedule_slot)
schedule_slot.task.cancel()

async def stop_slot_generator(self, schedule_slot: ScheduleSlot[T]) -> None:
try:
self.logger.debug("Closing slot: %s", schedule_slot)
await schedule_slot.generator.aclose()
except Exception:
self.logger.exception("Failed to close item: %s", schedule_slot)

async def run(self) -> AsyncIterator[T]:
while True:
self.is_running = True
while self.is_running or self.tasks_group.tasks:
done = await self.tasks_group.wait()
if not done:
continue
Expand All @@ -103,7 +98,7 @@ async def process_task(self, task: asyncio.Task) -> AsyncIterator[T]:
if exception is None:
yield task.result()
elif isinstance(exception, StopAsyncIteration):
await self.remove(item)
self.remove(item)
elif isinstance(exception, asyncio.CancelledError):
pass
elif exception:
Expand All @@ -122,13 +117,17 @@ async def process_task(self, task: asyncio.Task) -> AsyncIterator[T]:
raise
else:
# Requeue the __anext__ task to process next item.
self._requeue_task(item)

def _requeue_task(self, item: Schedulable[T]) -> None:
schedule_slot = self.schedule_slots.get(item)
if schedule_slot is None:
return

schedule_slot = self.schedule_slots.get(item)
if schedule_slot:
if not schedule_slot.is_running:
del self.schedule_slots[item]
await self.stop_slot_generator(schedule_slot)
else:
await self._requeue_task(item=item, schedule_slot=schedule_slot)

async def _requeue_task(
self, *, item: Schedulable[T], schedule_slot: ScheduleSlot[T]
) -> None:
name = f"scheduler.anext({item.name})"
anext = t.cast(Coroutine[t.Any, t.Any, T], schedule_slot.generator.__anext__())
new_task = asyncio.create_task(anext, name=name)
Expand Down
12 changes: 8 additions & 4 deletions tests/worker/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def test_scheduler(
assert messages == {sentinel.schedulable1: 5, sentinel.schedulable2: 5}

# Removing an item should cancel its task.
await scheduler.remove(schedulable2)
scheduler.remove(schedulable2)

messages.clear()
async for item in alib.islice(generator, 10):
Expand Down Expand Up @@ -141,7 +141,11 @@ async def error_close() -> AsyncGenerator:

schedulable = make_schedulable(iterable=error_close())
scheduler.add(schedulable)
async for item in alib.islice(scheduler.run(), 10):
pass
await scheduler.close()
async with alib.scoped_iter(scheduler.run()) as generator:
async for item in alib.islice(generator, 10):
pass
await scheduler.close()
async for item in generator:
raise AssertionError()

close_mock.assert_called_once()

0 comments on commit 52bb67f

Please sign in to comment.