Skip to content

Commit

Permalink
fix typing for SQlAlchemy 2.0 migration
Browse files Browse the repository at this point in the history
  • Loading branch information
Benoit Doyon authored and isra17 committed Apr 3, 2024
1 parent 612b999 commit 94dd172
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ disallow_incomplete_defs = true
disallow_untyped_defs = true
disallow_untyped_calls = true
namespace_packages = true
plugins = ["sqlalchemy.ext.mypy.plugin", "mypy_typing_asserts.mypy_plugin"]
plugins = ["mypy_typing_asserts.mypy_plugin"]

[[tool.mypy.overrides]]
module = [
Expand Down
2 changes: 1 addition & 1 deletion src/saturn_engine/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import Base
from .job import Job
from .job import JobCursorState
from .job_cursor_state import JobCursorState
from .queue import Queue

__all__ = [
Expand Down
6 changes: 4 additions & 2 deletions src/saturn_engine/models/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

Base = declarative_base()

class Base(DeclarativeBase):
pass
36 changes: 12 additions & 24 deletions src/saturn_engine/models/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from datetime import datetime

from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import backref
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.sql.sqltypes import Text
from sqlalchemy.types import JSON

import saturn_engine.models.queue as queue_model
from saturn_engine.core import Cursor
from saturn_engine.core import JobId
from saturn_engine.core.api import JobItem
Expand All @@ -20,13 +18,6 @@
from .types import UTCDateTime


class JobCursorState(Base):
__tablename__ = "job_cursor_states"
job_definition_name: Mapped[str] = Column(Text, primary_key=True)
cursor: Mapped[str] = Column(Text, primary_key=True)
state: Mapped[dict] = Column(JSON, nullable=False)


class Job(Base):
__tablename__ = "jobs"
__table_args__ = (
Expand All @@ -36,18 +27,18 @@ class Job(Base):
),
)

name: Mapped[str] = Column(Text, primary_key=True)
cursor = Column(Text, nullable=True)
completed_at: Mapped[Optional[datetime]] = Column(UTCDateTime, nullable=True) # type: ignore[assignment] # noqa: B950
started_at: Mapped[datetime] = Column(UTCDateTime, nullable=False) # type: ignore[assignment] # noqa: B950
queue_name: Mapped[str] = Column(Text, ForeignKey("queues.name"), nullable=False)
error = Column(Text, nullable=True)
queue: Mapped["Queue"] = relationship(
"Queue",
name: Mapped[str] = mapped_column(primary_key=True)
cursor: Mapped[Optional[str]] = mapped_column(nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(UTCDateTime, nullable=True)
started_at: Mapped[datetime] = mapped_column(UTCDateTime, nullable=False)
queue_name: Mapped[str] = mapped_column(ForeignKey("queues.name"), nullable=False)
error: Mapped[Optional[str]] = mapped_column(nullable=True)
queue: Mapped[queue_model.Queue] = relationship(
lambda: queue_model.Queue,
uselist=False,
backref=backref("job", uselist=False),
back_populates="job",
)
job_definition_name: Mapped[Optional[str]] = Column(Text, nullable=True)
job_definition_name: Mapped[Optional[str]] = mapped_column(nullable=True)

def __init__(
self,
Expand Down Expand Up @@ -82,6 +73,3 @@ def as_core_item(self) -> JobItem:
error=self.error,
**queue_args, # type: ignore[arg-type]
)


from .queue import Queue
13 changes: 13 additions & 0 deletions src/saturn_engine/models/job_cursor_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.sqltypes import Text
from sqlalchemy.types import JSON

from .base import Base


class JobCursorState(Base):
__tablename__ = "job_cursor_states"
job_definition_name: Mapped[str] = mapped_column(Text, primary_key=True)
cursor: Mapped[str] = mapped_column(Text, primary_key=True)
state: Mapped[dict] = mapped_column(JSON, nullable=False)
29 changes: 14 additions & 15 deletions src/saturn_engine/models/queue.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from typing import ClassVar
from typing import Optional

import dataclasses

from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.sql.sqltypes import DateTime

import saturn_engine.models.job as job_model
from saturn_engine.core import Cursor
from saturn_engine.core.api import QueueItemState
from saturn_engine.core.api import QueueItemWithState
from saturn_engine.core.types import JobId
from saturn_engine.worker_manager.config.static_definitions import StaticDefinitions

from .base import Base


class Queue(Base):
__allow_unmapped__ = True

__tablename__ = "queues"
__table_args__ = (
Index(
Expand All @@ -29,12 +31,12 @@ class Queue(Base):
),
)

name: Mapped[str] = Column(Text, primary_key=True)
assigned_at = Column(DateTime(timezone=True))
assigned_to = Column(Text)
job: ClassVar[Optional["Job"]] = None
_queue_item: ClassVar[Optional[QueueItemWithState]] = None
enabled = Column(Boolean, default=True, nullable=False)
name: Mapped[str] = mapped_column(primary_key=True)
assigned_at = mapped_column(DateTime(timezone=True))
assigned_to: Mapped[Optional[str]] = mapped_column()
job = relationship(lambda: job_model.Job, uselist=False, back_populates="queue")
_queue_item: Optional[QueueItemWithState] = None
enabled: Mapped[bool] = mapped_column(default=True, nullable=False)

@property
def queue_item(self) -> QueueItemWithState:
Expand All @@ -53,15 +55,12 @@ def join_definitions(self, static_definitions: StaticDefinitions) -> None:
static_definitions.job_definitions[
self.job.job_definition_name
].template,
name=self.name,
name=JobId(self.name),
).with_state(state)
else:
self._queue_item = dataclasses.replace(
static_definitions.jobs[self.job.name],
name=self.name,
name=JobId(self.name),
).with_state(state)
else:
raise NotImplementedError("Only support Job queue")


from .job import Job
6 changes: 3 additions & 3 deletions src/saturn_engine/stores/jobs_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from saturn_engine.core.api import QueueItem
from saturn_engine.core.api import StartJobInput
from saturn_engine.models import Job
from saturn_engine.models.job import JobCursorState
from saturn_engine.models.job_cursor_state import JobCursorState
from saturn_engine.models.queue import Queue
from saturn_engine.stores import queues_store
from saturn_engine.utils import utcnow
Expand Down Expand Up @@ -172,9 +172,9 @@ def sync_jobs_states(
session.execute(cursors_stmt)

if jobs_values:
session.bulk_update_mappings(Job, jobs_values)
session.execute(update(Job), jobs_values)
if queues_values:
session.bulk_update_mappings(Queue, queues_values)
session.execute(update(Queue), queues_values)


CursorsStates = dict[JobId, dict[Cursor, t.Optional[dict]]]
Expand Down
4 changes: 3 additions & 1 deletion src/saturn_engine/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
AnySession = AnySyncSession


def upsert(session: AnySession) -> t.Callable[[object], postgresql.Insert]:
def upsert(
session: AnySession,
) -> t.Callable[[t.Any], postgresql.Insert | sqlite.Insert]:
if not session.bind:
raise ValueError("Session is unbound")

Expand Down

0 comments on commit 94dd172

Please sign in to comment.