Skip to content

Commit

Permalink
Add runner --standalone option
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Jul 26, 2023
1 parent 7ec1ede commit 39c1e74
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 43 deletions.
2 changes: 1 addition & 1 deletion bin/worker
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

set -eo pipefail

exec python -m saturn_engine.worker.runner
exec python -m saturn_engine.worker.runner "$@"
3 changes: 2 additions & 1 deletion example/worker
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ ROOT_DIR="$( cd ${DIR}/.. && pwd )"
export SATURN_ENV="${SATURN_ENV:-development}"
export SATURN_SETTINGS="example.settings.config"
export PYTHONPATH="${PYTHONPATH}:$DIR/src"
exec bash ${ROOT_DIR}/bin/worker
export SATURN_STATIC_DEFINITIONS_DIRS="${SATURN_STATIC_DEFINITIONS_DIRS:-${DIR}/definitions}"
exec bash ${ROOT_DIR}/bin/worker --standalone
19 changes: 18 additions & 1 deletion src/saturn_engine/client/worker_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import abc
import socket

import aiohttp
Expand All @@ -15,7 +16,23 @@
from saturn_engine.utils.options import fromdict


class WorkerManagerClient:
class AbstractWorkerManagerClient(abc.ABC):
@abc.abstractmethod
async def lock(self) -> LockResponse:
pass

@abc.abstractmethod
async def sync(self, sync: JobsStatesSyncInput) -> JobsStatesSyncResponse:
pass

@abc.abstractmethod
async def fetch_cursors_states(
self, cursors: FetchCursorsStatesInput
) -> FetchCursorsStatesResponse:
pass


class WorkerManagerClient(AbstractWorkerManagerClient):
def __init__(
self,
http_client: aiohttp.ClientSession,
Expand Down
5 changes: 2 additions & 3 deletions src/saturn_engine/config_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ class RedisConfig:

@dataclasses.dataclass
class ServicesManagerConfig:
# Base services, should not be overriden.
base_services: list[str]
# Services to load
services: list[str]
# Check services type dependancies match loaded services. `False` value is
# needed to load fake services.
strict_services: bool


@dataclasses.dataclass
Expand Down
10 changes: 9 additions & 1 deletion src/saturn_engine/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ class config(SaturnConfig):
)

class services_manager(ServicesManagerConfig):
base_services = [
"saturn_engine.worker.services.http_client.HttpClient",
"saturn_engine.worker.services.api_client.ApiClient",
"saturn_engine.worker.services.job_state.service.JobStateService",
]
services = [
"saturn_engine.worker.services.tracing.Tracer",
"saturn_engine.worker.services.metrics.Metrics",
"saturn_engine.worker.services.loggers.Logger",
"saturn_engine.worker.services.rabbitmq.RabbitMQService",
]
strict_services = True

class rabbitmq(RabbitMQConfig):
url = os.environ.get("SATURN_AMQP_URL", "amqp://127.0.0.1/")
Expand Down Expand Up @@ -53,6 +57,10 @@ class redis(RedisConfig):
class tracer:
rate: float = 0.0

class databases:
engines: dict[str, t.Any] = {}
sync_engines: dict[str, t.Any] = {}


class client_config(config):
class services_manager(config.services_manager):
Expand Down
5 changes: 1 addition & 4 deletions src/saturn_engine/worker/resources/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from saturn_engine.worker.resources.manager import ResourceKey
from saturn_engine.worker.resources.manager import ResourceRateLimit
from saturn_engine.worker.services import Services
from saturn_engine.worker.services.tasks_runner import TasksRunnerService


@dataclasses.dataclass
Expand Down Expand Up @@ -104,9 +103,7 @@ async def open(self) -> None:

async def _open(self) -> None:
await super()._open()
self._sync_task = self.services.cast_service(
TasksRunnerService
).runner.create_task(
self._sync_task = self.services.s.tasks_runner.create_task(
self.poller(), name=f"provider-sync({self.definition.name})"
)

Expand Down
33 changes: 31 additions & 2 deletions src/saturn_engine/worker/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import signal
import sys

from saturn_engine.config import default_config_with_env

Expand All @@ -21,8 +22,34 @@ def unset_term_handler() -> None:
loop.remove_signal_handler(getattr(signal, signame))


async def async_main() -> None:
async def async_main(standalone: bool = False) -> None:
config = default_config_with_env()
if standalone:
base_services = config.c.services_manager.base_services.copy()
services = config.c.services_manager.services.copy()

try:
base_services.remove("saturn_engine.worker.services.api_client.ApiClient")
services.remove("saturn_engine.worker.services.databases.Databases")
except ValueError:
pass
base_services = [
"saturn_engine.worker.services.databases.Databases",
"saturn_engine.worker.services.api_client.StandaloneApiClient",
] + base_services

engines = config.r.get("databases", {}).get("sync_engines", {}).copy()
engines.setdefault("default", "sqlite:///standalone.db")
config = config.load_object(
{
"services_manager": {
"base_services": base_services,
"services": services,
},
"databases": {"sync_engines": engines},
}
)

broker = Broker(config)

def stop() -> None:
Expand All @@ -39,8 +66,10 @@ def main() -> None:
)
logger = logging.getLogger(__name__)

standalone = "--standalone" in sys.argv

loop = asyncio.get_event_loop()
asyncio.run(async_main())
asyncio.run(async_main(standalone=standalone))
if tasks := asyncio.all_tasks(loop):
logger.error("Leftover tasks: %s", tasks)

Expand Down
12 changes: 3 additions & 9 deletions src/saturn_engine/worker/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from saturn_engine.worker.resources.manager import ResourcesManager

from .hooks import Hooks
from .tasks_runner import TasksRunnerService

__all__ = ("Config", "Service", "BaseServices")

Expand All @@ -23,6 +24,7 @@ class BaseServices:
config: Config
resources_manager: ResourcesManager
hooks: Hooks
tasks_runner: TasksRunnerService


TServices = TypeVar("TServices", bound=BaseServices)
Expand Down Expand Up @@ -69,9 +71,8 @@ async def close(self) -> None:


class ServicesNamespace(Namespace, Generic[T]):
def __init__(self, *args: Any, strict: bool, **kwargs: Any):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.strict = strict
self.s: T = cast(T, self)

def cast(self, interface: Type[U]) -> "ServicesNamespace[U]":
Expand All @@ -91,13 +92,6 @@ def cast_service(self, service_cls: Type[TService]) -> TService:
if not service:
raise ValueError(f"Namespace missing '{typ_import_name}' service")

if self.strict and service_cls is not service.__class__:
dependency_name = extra_inspect.get_import_name(service.__class__)
raise ValueError(
f"Service '{name}' expected to be '{typ_import_name}', "
f"got '{dependency_name}'"
)

return cast(TService, service)


Expand Down
38 changes: 38 additions & 0 deletions src/saturn_engine/worker/services/api_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import asyncio

from sqlalchemy.orm import sessionmaker

from saturn_engine.client.worker_manager import WorkerManagerClient
from saturn_engine.worker.services.databases import Databases
from saturn_engine.worker.services.tasks_runner import TasksRunnerService
from saturn_engine.worker.worker_manager import StandaloneWorkerManagerClient

from . import BaseServices
from . import Service
Expand All @@ -22,3 +29,34 @@ async def open(self) -> None:
base_url=self.services.config.c.worker_manager_url,
worker_id=self.services.config.c.worker_id,
)


class StandaloneServices(BaseServices):
databases: Databases


class StandaloneApiClient(Service[StandaloneServices, None]):
name = "api_client"

Services = StandaloneServices

client: StandaloneWorkerManagerClient

SYNC_DELAY = 60

async def open(self) -> None:
self.client = StandaloneWorkerManagerClient(
config=self.services.config,
sessionmaker=sessionmaker(self.services.databases.sync_engine()),
)

await self.client.init_db()
await self.client.sync_jobs()
self.services.tasks_runner.create_task(
self._sync_jobs(), name="StandaloneClient.sync-jobs"
)

async def _sync_jobs(self) -> None:
while True:
await asyncio.sleep(self.SYNC_DELAY)
await self.client.sync_jobs()
14 changes: 3 additions & 11 deletions src/saturn_engine/worker/services/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import chain

from saturn_engine.utils import inspect as extra_inspect
from saturn_engine.worker.services.tasks_runner import TasksRunnerService

from ..resources.manager import ResourcesManager
from . import BaseServices
Expand All @@ -16,19 +17,18 @@

class ServicesManager:
def __init__(self, config: Config) -> None:
self.strict = config.c.services_manager.strict_services
self.services: Services = ServicesNamespace(
config=config,
hooks=Hooks(),
resources_manager=ResourcesManager(),
strict=self.strict,
tasks_runner=TasksRunnerService(),
)
self.loaded_services: list[Service] = []
self.is_opened = False

# Load optional services based on config.
for service_cls_path in chain(
BASE_SERVICES, config.c.services_manager.services
config.c.services_manager.base_services, config.c.services_manager.services
):
service_cls = extra_inspect.import_name(service_cls_path)
self._load_service(service_cls)
Expand Down Expand Up @@ -81,11 +81,3 @@ async def _reload_service(self, service_cls: Type[TService]) -> TService:

def has_loaded(self, service_cls: Type[TService]) -> bool:
return service_cls.name in self.services


BASE_SERVICES: list[str] = [
"saturn_engine.worker.services.http_client.HttpClient",
"saturn_engine.worker.services.api_client.ApiClient",
"saturn_engine.worker.services.tasks_runner.TasksRunnerService",
"saturn_engine.worker.services.job_state.service.JobStateService",
]
13 changes: 9 additions & 4 deletions src/saturn_engine/worker/services/tasks_runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from saturn_engine.utils.asyncutils import TasksGroupRunner
import typing as t

import asyncio

from . import MinimalService
from saturn_engine.utils.asyncutils import TasksGroupRunner


class TasksRunnerService(MinimalService):
class TasksRunnerService:
name = "tasks_runner"

async def open(self) -> None:
def __init__(self) -> None:
self.runner = TasksGroupRunner(name="tasks-runner-service")
self.runner.start()

async def close(self) -> None:
await self.runner.close()

def create_task(self, coro: t.Coroutine, *, name: str) -> asyncio.Task:
return self.runner.create_task(coro, name=name)
Loading

0 comments on commit 39c1e74

Please sign in to comment.