diff --git a/src/saturn_engine/client/worker_manager.py b/src/saturn_engine/client/worker_manager.py index ceec4989..64e98cc0 100644 --- a/src/saturn_engine/client/worker_manager.py +++ b/src/saturn_engine/client/worker_manager.py @@ -21,14 +21,16 @@ def __init__( http_client: aiohttp.ClientSession, base_url: str, worker_id: Optional[str] = None, + selector: Optional[str] = None, ) -> None: self.worker_id: str = worker_id or socket.gethostname() + self.selector: Optional[str] = selector self.http_client = http_client self.base_url = base_url async def lock(self) -> LockResponse: lock_url = urlcat(self.base_url, "api/lock") - json = asdict(LockInput(worker_id=self.worker_id)) + json = asdict(LockInput(worker_id=self.worker_id, selector=self.selector)) async with self.http_client.post(lock_url, json=json) as response: return fromdict(await response.json(), LockResponse) diff --git a/src/saturn_engine/config_definitions.py b/src/saturn_engine/config_definitions.py index b7b0162a..a5e1eaa4 100644 --- a/src/saturn_engine/config_definitions.py +++ b/src/saturn_engine/config_definitions.py @@ -50,6 +50,8 @@ class SaturnConfig: # Worker Manager URL used by clients and workers. worker_id: str worker_manager_url: str + # If set, select jobs matching the selector regex. + selector: t.Optional[str] services_manager: ServicesManagerConfig worker_manager: WorkerManagerConfig rabbitmq: RabbitMQConfig diff --git a/src/saturn_engine/core/api.py b/src/saturn_engine/core/api.py index f6994f77..436b8938 100644 --- a/src/saturn_engine/core/api.py +++ b/src/saturn_engine/core/api.py @@ -91,6 +91,7 @@ class LockResponse: @dataclasses.dataclass class LockInput: worker_id: str + selector: t.Optional[str] = None @dataclasses.dataclass diff --git a/src/saturn_engine/default_config.py b/src/saturn_engine/default_config.py index cea56e1e..0011a4c5 100644 --- a/src/saturn_engine/default_config.py +++ b/src/saturn_engine/default_config.py @@ -14,6 +14,7 @@ class config(SaturnConfig): env = Env(os.environ.get("SATURN_ENV", "development")) worker_id = socket.gethostname() + selector: t.Optional[str] = os.environ.get("SATURN_SELECTOR") worker_manager_url = os.environ.get( "SATURN_WORKER_MANAGER_URL", "http://127.0.0.1:5000" ) diff --git a/src/saturn_engine/stores/queues_store.py b/src/saturn_engine/stores/queues_store.py index a5f87dad..82626f9f 100644 --- a/src/saturn_engine/stores/queues_store.py +++ b/src/saturn_engine/stores/queues_store.py @@ -1,3 +1,5 @@ +import typing as t + import datetime from sqlalchemy import or_ @@ -25,7 +27,11 @@ def get_assigned_queues( session: AnySyncSession, worker_id: str, assigned_after: datetime.datetime, + selector: t.Optional[str] = None, ) -> list[Queue]: + extra_filters = [] + if selector: + extra_filters.append(Queue.name.regexp_match(selector)) assigned_jobs: list[Queue] = ( session.execute( select(Queue) @@ -34,6 +40,7 @@ def get_assigned_queues( Queue.enabled.is_(True), Queue.assigned_to == worker_id, Queue.assigned_at >= assigned_after, + *extra_filters, ) .order_by(Queue.name) ) @@ -47,8 +54,12 @@ def get_unassigned_queues( *, session: AnySyncSession, assigned_before: datetime.datetime, + selector: t.Optional[str] = None, limit: int, ) -> list[Queue]: + extra_filters = [] + if selector: + extra_filters.append(Queue.name.regexp_match(selector)) unassigned_queues: list[Queue] = ( session.execute( select(Queue) @@ -59,6 +70,7 @@ def get_unassigned_queues( Queue.assigned_at.is_(None), Queue.assigned_at < assigned_before, ), + *extra_filters, ) .limit(limit) ) diff --git a/src/saturn_engine/worker/services/api_client.py b/src/saturn_engine/worker/services/api_client.py index 98aa2f55..44bd8a23 100644 --- a/src/saturn_engine/worker/services/api_client.py +++ b/src/saturn_engine/worker/services/api_client.py @@ -21,4 +21,5 @@ async def open(self) -> None: http_client=self.services.http_client.session, base_url=self.services.config.c.worker_manager_url, worker_id=self.services.config.c.worker_id, + selector=self.services.config.c.selector, ) diff --git a/src/saturn_engine/worker_manager/services/lock.py b/src/saturn_engine/worker_manager/services/lock.py index e76b1e87..e7a4d618 100644 --- a/src/saturn_engine/worker_manager/services/lock.py +++ b/src/saturn_engine/worker_manager/services/lock.py @@ -34,6 +34,7 @@ def lock_jobs( queues_store.get_assigned_queues( session=session, worker_id=lock_input.worker_id, + selector=lock_input.selector, assigned_after=assignation_expiration_cutoff, ) ) @@ -52,6 +53,7 @@ def lock_jobs( session=session, assigned_before=assignation_expiration_cutoff, limit=max_assigned_items - len(assigned_items), + selector=lock_input.selector, ) ) diff --git a/tests/worker_manager/api/test_lock.py b/tests/worker_manager/api/test_lock.py index a69b52c9..4c404176 100644 --- a/tests/worker_manager/api/test_lock.py +++ b/tests/worker_manager/api/test_lock.py @@ -177,6 +177,10 @@ def create_job( "job-10", } + resp = client.post("/api/lock", json={"worker_id": "worker-2", "selector": "j.*-9"}) + assert resp.status_code == 200 + assert ids(resp) == {"job-9"} + def test_api_lock_with_resources( client: FlaskClient,