diff --git a/pyproject.toml b/pyproject.toml index 27fe6d79..3a97c56d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ disallow_incomplete_defs = true disallow_untyped_defs = true disallow_untyped_calls = true namespace_packages = true -plugins = ["sqlalchemy.ext.mypy.plugin", "mypy_typing_asserts.mypy_plugin"] +plugins = ["mypy_typing_asserts.mypy_plugin"] [[tool.mypy.overrides]] module = [ diff --git a/src/saturn_engine/models/__init__.py b/src/saturn_engine/models/__init__.py index 0845f9cb..ee4a6ae8 100644 --- a/src/saturn_engine/models/__init__.py +++ b/src/saturn_engine/models/__init__.py @@ -1,6 +1,6 @@ from .base import Base from .job import Job -from .job import JobCursorState +from .job_cursor_state import JobCursorState from .queue import Queue __all__ = [ diff --git a/src/saturn_engine/models/base.py b/src/saturn_engine/models/base.py index 59be7030..fa2b68a5 100644 --- a/src/saturn_engine/models/base.py +++ b/src/saturn_engine/models/base.py @@ -1,3 +1,5 @@ -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() + +class Base(DeclarativeBase): + pass diff --git a/src/saturn_engine/models/job.py b/src/saturn_engine/models/job.py index 6121982d..65511681 100644 --- a/src/saturn_engine/models/job.py +++ b/src/saturn_engine/models/job.py @@ -3,14 +3,12 @@ from datetime import datetime from sqlalchemy import CheckConstraint -from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped -from sqlalchemy.orm import backref +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.sql.sqltypes import Text -from sqlalchemy.types import JSON +import saturn_engine.models.queue as queue_model from saturn_engine.core import Cursor from saturn_engine.core import JobId from saturn_engine.core.api import JobItem @@ -20,13 +18,6 @@ from .types import UTCDateTime -class JobCursorState(Base): - __tablename__ = "job_cursor_states" - job_definition_name: Mapped[str] = Column(Text, primary_key=True) - cursor: Mapped[str] = Column(Text, primary_key=True) - state: Mapped[dict] = Column(JSON, nullable=False) - - class Job(Base): __tablename__ = "jobs" __table_args__ = ( @@ -36,18 +27,18 @@ class Job(Base): ), ) - name: Mapped[str] = Column(Text, primary_key=True) - cursor = Column(Text, nullable=True) - completed_at: Mapped[Optional[datetime]] = Column(UTCDateTime, nullable=True) # type: ignore[assignment] # noqa: B950 - started_at: Mapped[datetime] = Column(UTCDateTime, nullable=False) # type: ignore[assignment] # noqa: B950 - queue_name: Mapped[str] = Column(Text, ForeignKey("queues.name"), nullable=False) - error = Column(Text, nullable=True) - queue: Mapped["Queue"] = relationship( - "Queue", + name: Mapped[str] = mapped_column(primary_key=True) + cursor: Mapped[Optional[str]] = mapped_column(nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(UTCDateTime, nullable=True) + started_at: Mapped[datetime] = mapped_column(UTCDateTime, nullable=False) + queue_name: Mapped[str] = mapped_column(ForeignKey("queues.name"), nullable=False) + error: Mapped[Optional[str]] = mapped_column(nullable=True) + queue: Mapped[queue_model.Queue] = relationship( + lambda: queue_model.Queue, uselist=False, - backref=backref("job", uselist=False), + back_populates="job", ) - job_definition_name: Mapped[Optional[str]] = Column(Text, nullable=True) + job_definition_name: Mapped[Optional[str]] = mapped_column(nullable=True) def __init__( self, @@ -82,6 +73,3 @@ def as_core_item(self) -> JobItem: error=self.error, **queue_args, # type: ignore[arg-type] ) - - -from .queue import Queue diff --git a/src/saturn_engine/models/job_cursor_state.py b/src/saturn_engine/models/job_cursor_state.py new file mode 100644 index 00000000..5f13698e --- /dev/null +++ b/src/saturn_engine/models/job_cursor_state.py @@ -0,0 +1,13 @@ +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.sqltypes import Text +from sqlalchemy.types import JSON + +from .base import Base + + +class JobCursorState(Base): + __tablename__ = "job_cursor_states" + job_definition_name: Mapped[str] = mapped_column(Text, primary_key=True) + cursor: Mapped[str] = mapped_column(Text, primary_key=True) + state: Mapped[dict] = mapped_column(JSON, nullable=False) diff --git a/src/saturn_engine/models/queue.py b/src/saturn_engine/models/queue.py index 74ea9850..d974ee86 100644 --- a/src/saturn_engine/models/queue.py +++ b/src/saturn_engine/models/queue.py @@ -1,25 +1,27 @@ -from typing import ClassVar from typing import Optional import dataclasses -from sqlalchemy import Boolean -from sqlalchemy import Column from sqlalchemy import Index -from sqlalchemy import Text from sqlalchemy import text from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship from sqlalchemy.sql.sqltypes import DateTime +import saturn_engine.models.job as job_model from saturn_engine.core import Cursor from saturn_engine.core.api import QueueItemState from saturn_engine.core.api import QueueItemWithState +from saturn_engine.core.types import JobId from saturn_engine.worker_manager.config.static_definitions import StaticDefinitions from .base import Base class Queue(Base): + __allow_unmapped__ = True + __tablename__ = "queues" __table_args__ = ( Index( @@ -29,12 +31,12 @@ class Queue(Base): ), ) - name: Mapped[str] = Column(Text, primary_key=True) - assigned_at = Column(DateTime(timezone=True)) - assigned_to = Column(Text) - job: ClassVar[Optional["Job"]] = None - _queue_item: ClassVar[Optional[QueueItemWithState]] = None - enabled = Column(Boolean, default=True, nullable=False) + name: Mapped[str] = mapped_column(primary_key=True) + assigned_at = mapped_column(DateTime(timezone=True)) + assigned_to: Mapped[Optional[str]] = mapped_column() + job = relationship(lambda: job_model.Job, uselist=False, back_populates="queue") + _queue_item: Optional[QueueItemWithState] = None + enabled: Mapped[bool] = mapped_column(default=True, nullable=False) @property def queue_item(self) -> QueueItemWithState: @@ -53,15 +55,12 @@ def join_definitions(self, static_definitions: StaticDefinitions) -> None: static_definitions.job_definitions[ self.job.job_definition_name ].template, - name=self.name, + name=JobId(self.name), ).with_state(state) else: self._queue_item = dataclasses.replace( static_definitions.jobs[self.job.name], - name=self.name, + name=JobId(self.name), ).with_state(state) else: raise NotImplementedError("Only support Job queue") - - -from .job import Job diff --git a/src/saturn_engine/stores/jobs_store.py b/src/saturn_engine/stores/jobs_store.py index a547553b..2a40b3a4 100644 --- a/src/saturn_engine/stores/jobs_store.py +++ b/src/saturn_engine/stores/jobs_store.py @@ -15,7 +15,7 @@ from saturn_engine.core.api import QueueItem from saturn_engine.core.api import StartJobInput from saturn_engine.models import Job -from saturn_engine.models.job import JobCursorState +from saturn_engine.models.job_cursor_state import JobCursorState from saturn_engine.models.queue import Queue from saturn_engine.stores import queues_store from saturn_engine.utils import utcnow @@ -172,9 +172,9 @@ def sync_jobs_states( session.execute(cursors_stmt) if jobs_values: - session.bulk_update_mappings(Job, jobs_values) + session.execute(update(Job), jobs_values) if queues_values: - session.bulk_update_mappings(Queue, queues_values) + session.execute(update(Queue), queues_values) CursorsStates = dict[JobId, dict[Cursor, t.Optional[dict]]] diff --git a/src/saturn_engine/utils/sqlalchemy.py b/src/saturn_engine/utils/sqlalchemy.py index 4e1a2784..693c2b26 100644 --- a/src/saturn_engine/utils/sqlalchemy.py +++ b/src/saturn_engine/utils/sqlalchemy.py @@ -11,7 +11,9 @@ AnySession = AnySyncSession -def upsert(session: AnySession) -> t.Callable[[object], postgresql.Insert]: +def upsert( + session: AnySession, +) -> t.Callable[[t.Any], postgresql.Insert | sqlite.Insert]: if not session.bind: raise ValueError("Session is unbound")