Skip to content

Commit

Permalink
services: Add JobState service and store
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 4, 2023
1 parent bba7d12 commit 1705bc3
Show file tree
Hide file tree
Showing 13 changed files with 413 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/saturn_engine/config_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ServicesManagerConfig:
@dataclasses.dataclass
class SaturnConfig:
env: Env
worker_id: str
# Worker Manager URL used by clients and workers.
worker_manager_url: str
services_manager: ServicesManagerConfig
Expand Down
2 changes: 2 additions & 0 deletions src/saturn_engine/default_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as t

import os
import socket

from .config import Env
from .config import RabbitMQConfig
Expand All @@ -13,6 +14,7 @@

class config(SaturnConfig):
env = Env(os.environ.get("SATURN_ENV", "development"))
worker_id = socket.gethostname()
worker_manager_url = os.environ.get(
"SATURN_WORKER_MANAGER_URL", "http://127.0.0.1:5000"
)
Expand Down
24 changes: 24 additions & 0 deletions src/saturn_engine/worker/services/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from saturn_engine.client.worker_manager import WorkerManagerClient

from . import BaseServices
from . import Service
from .http_client import HttpClient


class Services(BaseServices):
    """Services that must be available to `ApiClient`."""

    # HTTP client service; `ApiClient.open` reads its `session` attribute.
    http_client: HttpClient


class ApiClient(Service[Services, None]):
    """Service exposing a `WorkerManagerClient` through its `client` attribute.

    The client is built in `open` from the HTTP client service's session and
    the `worker_manager_url` / `worker_id` configuration values.
    """

    name = "api_client"

    Services = Services

    # Assigned by `open`; not set before the service has been opened.
    client: WorkerManagerClient

    async def open(self) -> None:
        """Create the worker manager client from the session and the config."""
        session = self.services.http_client.session
        settings = self.services.config.c
        self.client = WorkerManagerClient(
            http_client=session,
            base_url=settings.worker_manager_url,
            worker_id=settings.worker_id,
        )
Empty file.
79 changes: 79 additions & 0 deletions src/saturn_engine/worker/services/job_state/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import dataclasses

from saturn_engine.core import Cursor
from saturn_engine.core import JobId
from saturn_engine.core.api import JobsStatesSyncInput
from saturn_engine.utils.asyncutils import DelayedThrottle

from .. import BaseServices
from .. import Service
from ..api_client import ApiClient
from .store import JobsStates
from .store import JobsStatesSyncStore


class Services(BaseServices):
    """Services that must be available to `JobStateService`."""

    # Its `client` attribute is used by `JobStateService.flush_state` to sync.
    api_client: ApiClient


@dataclasses.dataclass
class Options:
    """Options of `JobStateService`."""

    # Delay handed to the `DelayedThrottle` wrapping `flush`
    # (presumably seconds -- TODO confirm against `DelayedThrottle`).
    flush_delay: float = 10.0
    # When False, updates are only stored; no flush is triggered by setters.
    auto_flush: bool = True


class JobStateService(Service[Services, Options]):
    """Collect job state updates and sync them through the API client.

    Updates are accumulated in a `JobsStatesSyncStore`. When the `auto_flush`
    option is enabled, every update calls the `DelayedThrottle` wrapping
    `flush`, which sends the accumulated state with the worker manager
    client's `sync` call.
    """

    name = "job_state"

    Services = Services
    Options = Options

    # Both attributes are assigned in `open`.
    _store: JobsStatesSyncStore
    _delayed_flush: DelayedThrottle

    async def open(self) -> None:
        """Create an empty store and the throttle wrapping `flush`."""
        self._store = JobsStatesSyncStore()
        throttle = DelayedThrottle(self.flush, delay=self.options.flush_delay)
        self._delayed_flush = throttle

    def set_job_cursor(self, job_name: JobId, *, cursor: Cursor) -> None:
        """Record the latest cursor of a job, then maybe trigger a flush."""
        self._store.set_job_cursor(job_name, cursor)
        self._maybe_flush()

    def set_job_completed(self, job_name: JobId) -> None:
        """Record a job as completed, then maybe trigger a flush."""
        self._store.set_job_completed(job_name)
        self._maybe_flush()

    def set_job_failed(self, job_name: JobId, *, error: Exception) -> None:
        """Record a job as failed, storing the error as "<type name>: <message>"."""
        message = f"{type(error).__name__}: {error}"
        self._store.set_job_failed(job_name, message)
        self._maybe_flush()

    def set_job_cursor_state(
        self,
        job_name: JobId,
        *,
        cursor: Cursor,
        cursor_state: dict,
    ) -> None:
        """Record the state attached to one cursor of a job, then maybe flush."""
        self._store.set_job_cursor_state(
            job_name,
            cursor=cursor,
            cursor_state=cursor_state,
        )
        self._maybe_flush()

    def _maybe_flush(self) -> None:
        """Call the throttled flush unless `auto_flush` is disabled."""
        if not self.options.auto_flush:
            return
        self._delayed_flush()

    async def flush(self) -> None:
        """Send the pending state, if any, through `flush_state`.

        The store's `flush` context restores the pending state when
        `flush_state` raises.
        """
        with self._store.flush() as pending:
            if pending.is_empty:
                return
            await self.flush_state(pending)

    async def flush_state(self, state: JobsStates) -> None:
        """Sync `state` with the worker manager client."""
        # The jobs mapping is a defaultdict, which breaks dataclasses, so it
        # is converted to a plain dict before being sent.
        plain_jobs = dict(state.jobs)
        state = dataclasses.replace(state, jobs=plain_jobs)
        await self.services.api_client.client.sync(JobsStatesSyncInput(state=state))

    async def close(self) -> None:
        """Flush the throttle when the service is closed."""
        await self._delayed_flush.flush()
100 changes: 100 additions & 0 deletions src/saturn_engine/worker/services/job_state/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import typing as t

import contextlib
import dataclasses
from collections import defaultdict

from saturn_engine.core import Cursor
from saturn_engine.core import JobId
from saturn_engine.core import api
from saturn_engine.utils import utcnow


@dataclasses.dataclass
class JobCompletion(api.JobCompletion):
    """`api.JobCompletion` extended with an in-place `merge`."""

    def merge(self, other: "JobCompletion") -> "JobCompletion":
        """Overwrite this completion with the values of `other`; return self."""
        self.completed_at, self.error = other.completed_at, other.error
        return self


@dataclasses.dataclass
class JobState(api.JobState):
    """`api.JobState` extended with an in-place `merge`."""

    completion: t.Optional[JobCompletion] = None

    def merge(self, other: "JobState") -> "JobState":
        """Fold `other` into this state, `other` taking precedence; return self.

        A falsy cursor or completion on `other` leaves the current value
        untouched; cursor states are merged key by key.
        """
        if other.cursor:
            self.cursor = other.cursor

        self.cursors_states.update(other.cursors_states)

        incoming = other.completion
        if incoming:
            if self.completion:
                self.completion = self.completion.merge(incoming)
            else:
                self.completion = incoming
        return self


@dataclasses.dataclass
class JobsStates(api.JobsStates):
    """`api.JobsStates` whose jobs mapping creates missing entries on access."""

    # A defaultdict so that `jobs[job_id]` creates an empty `JobState`.
    jobs: dict[JobId, JobState] = dataclasses.field(
        default_factory=lambda: defaultdict(JobState)
    )

    def merge(self, other: "JobsStates") -> "JobsStates":
        """Fold every job of `other` into this object and return self.

        Jobs already present are merged with `JobState.merge`; jobs only
        present in `other` are stored as-is.
        """
        for job_id, incoming in other.jobs.items():
            # The membership test avoids creating a default entry.
            self.jobs[job_id] = (
                self.jobs[job_id].merge(incoming)
                if job_id in self.jobs
                else incoming
            )
        return self

    @property
    def is_empty(self) -> bool:
        """True when no job state is recorded."""
        return not self.jobs


class JobsStatesSyncStore:
    """Accumulate job state changes between two flushes.

    Setters write into the current `JobsStates`. `flush` hands that state to
    the caller and starts a fresh one, restoring the handed-out state if the
    caller fails.
    """

    def __init__(self) -> None:
        self._current_state = JobsStates()
        # State most recently handed out by `flush`, None when no flush is
        # in progress.
        self._flushing_state: t.Optional[JobsStates] = None

    def set_job_cursor(self, job_name: JobId, cursor: Cursor) -> None:
        """Set the cursor of `job_name` in the current state."""
        self._current_state.jobs[job_name].cursor = cursor

    def set_job_completed(self, job_name: JobId) -> None:
        """Mark `job_name` as completed now, without error."""
        self._current_state.jobs[job_name].completion = JobCompletion(
            completed_at=utcnow(),
        )

    def set_job_failed(self, job_name: JobId, error: str) -> None:
        """Mark `job_name` as completed now, with `error` as failure message."""
        self._current_state.jobs[job_name].completion = JobCompletion(
            completed_at=utcnow(),
            error=error,
        )

    def set_job_cursor_state(
        self,
        job_name: JobId,
        *,
        cursor: Cursor,
        cursor_state: dict,
    ) -> None:
        """Store `cursor_state` under `cursor` for `job_name`."""
        self._current_state.jobs[job_name].cursors_states[cursor] = cursor_state

    @contextlib.contextmanager
    def flush(self) -> t.Iterator[JobsStates]:
        """Retrieve the jobs states in a way that is safe for flushing.

        The yielded object is not updated while inside the context: changes
        made meanwhile go to a fresh state. If an error happens inside the
        context, the yielded state is restored and merged with any change
        that occurred during the flush, then the error is re-raised.
        """
        # Keep a local reference: an overlapping flush() rebinds and clears
        # `self._flushing_state`, and the restore below must still operate on
        # the state this context yielded rather than on None.
        flushing = self._current_state
        self._flushing_state = flushing
        self._current_state = JobsStates()

        try:
            yield flushing
        except BaseException:
            # Changes recorded during the flush are newer, so they are merged
            # on top of the restored state.
            self._current_state = flushing.merge(self._current_state)
            raise
        finally:
            # Only clear the marker if it still refers to this flush.
            if self._flushing_state is flushing:
                self._flushing_state = None
36 changes: 21 additions & 15 deletions src/saturn_engine/worker/services/manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Type

from itertools import chain

from saturn_engine.utils import inspect as extra_inspect

from ..resources.manager import ResourcesManager
Expand All @@ -14,22 +16,20 @@

class ServicesManager:
def __init__(self, config: Config) -> None:
self.strict = config.c.services_manager.strict_services
self.services: Services = ServicesNamespace(
config=config,
hooks=Hooks(),
resources_manager=ResourcesManager(),
strict=True,
strict=self.strict,
)
self.loaded_services: list[Service] = []
self.strict = config.c.services_manager.strict_services
self.is_opened = False

# Some services are required for saturn to work at all.
for service_cls in BASE_SERVICES:
self._load_service(service_cls)

# Load optional services based on config.
for service_cls_path in config.c.services_manager.services:
for service_cls_path in chain(
BASE_SERVICES, config.c.services_manager.services
):
service_cls = extra_inspect.import_name(service_cls_path)
self._load_service(service_cls)

Expand Down Expand Up @@ -70,16 +70,22 @@ def _load_service(self, service_cls: Type[TService]) -> TService:
self.services[service_cls.name] = service
return service

# Useful for tests loading mock service.
async def _reload_service(self, service_cls: Type[TService]) -> TService:
    """Replace the service registered under `service_cls.name` and open it.

    Any service previously registered under that name is removed from the
    loaded services and closed before the new one is loaded.
    """
    previous = self.services.pop(service_cls.name, None)
    if previous:
        self.loaded_services.remove(previous)
        await previous.close()

    replacement = self._load_service(service_cls)
    await replacement.open()
    return replacement

def has_loaded(self, service_cls: Type[TService]) -> bool:
    """Return True when a service is registered under `service_cls.name`."""
    service_name = service_cls.name
    return service_name in self.services


from .http_client import HttpClient
from .job_store import JobStoreService
from .tasks_runner import TasksRunnerService

BASE_SERVICES: list[Type[Service]] = [
HttpClient,
JobStoreService,
TasksRunnerService,
BASE_SERVICES: list[str] = [
"saturn_engine.worker.services.http_client.HttpClient",
"saturn_engine.worker.services.api_client.ApiClient",
"saturn_engine.worker.services.tasks_runner.TasksRunnerService",
"saturn_engine.worker.services.job_store.JobStoreService",
]
8 changes: 2 additions & 6 deletions src/saturn_engine/worker/work_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from saturn_engine.worker.resources.manager import ResourceRateLimit
from saturn_engine.worker.resources.provider import ResourcesProvider
from saturn_engine.worker.services import Services
from saturn_engine.worker.services.http_client import HttpClient
from saturn_engine.worker.services.api_client import ApiClient

T = TypeVar("T")

Expand Down Expand Up @@ -62,11 +62,7 @@ def __init__(
self, *, services: Services, client: Optional[WorkerManagerClient] = None
) -> None:
self.logger = getLogger(__name__, self)
http_client = services.cast_service(HttpClient)
self.client: WorkerManagerClient = client or WorkerManagerClient(
http_client=http_client.session,
base_url=services.s.config.c.worker_manager_url,
)
self.client = client or services.cast_service(ApiClient).client
self.worker_items: WorkerItems = {}
self.worker_resources: dict[str, ResourceData] = {}
self.worker_resources_providers: dict[str, ResourcesProvider] = {}
Expand Down
3 changes: 3 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ class services_manager(ServicesManagerConfig):

class worker(WorkerConfig):
job_store_cls = "MemoryJobStore"

class job_state:
auto_flush = False
21 changes: 21 additions & 0 deletions tests/worker/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@
from saturn_engine.worker.executors.queue import ExecutorQueue
from saturn_engine.worker.pipeline_message import PipelineMessage
from saturn_engine.worker.resources.provider import ResourcesProvider
from saturn_engine.worker.services import MinimalService
from saturn_engine.worker.services import Services
from saturn_engine.worker.services.api_client import ApiClient
from saturn_engine.worker.services.manager import ServicesManager
from saturn_engine.worker.services.rabbitmq import RabbitMQService
from saturn_engine.worker.topics import MemoryTopic
from saturn_engine.worker.topics import Topic
from saturn_engine.worker.topics.memory import reset as reset_memory_queues
from saturn_engine.worker.work_manager import WorkManager
from tests.utils import HttpClientMock
from tests.utils import TimeForwardLoop
from tests.utils.metrics import MetricsCapture
from tests.utils.span_exporter import InMemorySpanExporter
Expand Down Expand Up @@ -401,3 +404,21 @@ async def executor(services_manager: ServicesManager) -> AsyncIterator[Executor]
)
yield executor
await executor.close()


@pytest.fixture
async def fake_http_client_service(
    http_client_mock: HttpClientMock,
    services_manager: ServicesManager,
) -> t.Any:
    """Swap the HTTP client service for one backed by `http_client_mock`.

    `ApiClient` is reloaded afterwards so that its `open` runs again once the
    fake service is registered.
    """

    class FakeHttpClient(MinimalService):
        """Stand-in registered under the "http_client" service name."""

        name = "http_client"

        async def open(self) -> None:
            self.session = http_client_mock.client()

        async def close(self) -> None:
            """Nothing to release."""

    # Order matters: the fake must be in place before ApiClient is reopened.
    for service_cls in (FakeHttpClient, ApiClient):
        await services_manager._reload_service(service_cls)
Empty file.
Loading

0 comments on commit 1705bc3

Please sign in to comment.