Skip to content

Commit

Permalink
fix typing for SQlAlchemy 2.0 migration
Browse files Browse the repository at this point in the history
  • Loading branch information
Benoit Doyon committed Mar 22, 2024
1 parent 7784358 commit 524830c
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 50 deletions.
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ packages = [

[tool.poetry.dependencies]
python = "^3.10"
SQLAlchemy = {extras = ["mypy", "asyncio"], version = ">=1.4.29"}
SQLAlchemy = {extras = ["mypy"], version = "^2.0.28"}
aiosqlite = "^0.17.0"
asyncstdlib = "^3.10.2"
aio-pika = ">=8.0"
Expand Down Expand Up @@ -133,7 +133,7 @@ disallow_incomplete_defs = true
disallow_untyped_defs = true
disallow_untyped_calls = true
namespace_packages = true
plugins = ["sqlalchemy.ext.mypy.plugin", "mypy_typing_asserts.mypy_plugin"]
plugins = ["mypy_typing_asserts.mypy_plugin"]

[[tool.mypy.overrides]]
module = [
Expand Down
2 changes: 1 addition & 1 deletion src/saturn_engine/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import Base
from .job import Job
from .job import JobCursorState
from .job_cursor_state import JobCursorState
from .queue import Queue

__all__ = [
Expand Down
6 changes: 4 additions & 2 deletions src/saturn_engine/models/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

Base = declarative_base()

class Base(DeclarativeBase):
pass
36 changes: 12 additions & 24 deletions src/saturn_engine/models/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from datetime import datetime

from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import backref
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.sql.sqltypes import Text
from sqlalchemy.types import JSON

import saturn_engine.models.queue as queue_model
from saturn_engine.core import Cursor
from saturn_engine.core import JobId
from saturn_engine.core.api import JobItem
Expand All @@ -20,13 +18,6 @@
from .types import UTCDateTime


class JobCursorState(Base):
__tablename__ = "job_cursor_states"
job_definition_name: Mapped[str] = Column(Text, primary_key=True)
cursor: Mapped[str] = Column(Text, primary_key=True)
state: Mapped[dict] = Column(JSON, nullable=False)


class Job(Base):
__tablename__ = "jobs"
__table_args__ = (
Expand All @@ -36,18 +27,18 @@ class Job(Base):
),
)

name: Mapped[str] = Column(Text, primary_key=True)
cursor = Column(Text, nullable=True)
completed_at: Mapped[Optional[datetime]] = Column(UTCDateTime, nullable=True) # type: ignore[assignment] # noqa: B950
started_at: Mapped[datetime] = Column(UTCDateTime, nullable=False) # type: ignore[assignment] # noqa: B950
queue_name: Mapped[str] = Column(Text, ForeignKey("queues.name"), nullable=False)
error = Column(Text, nullable=True)
queue: Mapped["Queue"] = relationship(
"Queue",
name: Mapped[str] = mapped_column(primary_key=True)
cursor: Mapped[Optional[str]] = mapped_column(nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(UTCDateTime, nullable=True)
started_at: Mapped[datetime] = mapped_column(UTCDateTime, nullable=False)
queue_name: Mapped[str] = mapped_column(ForeignKey("queues.name"), nullable=False)
error: Mapped[Optional[str]] = mapped_column(nullable=True)
queue: Mapped[queue_model.Queue] = relationship(
lambda: queue_model.Queue,
uselist=False,
backref=backref("job", uselist=False),
back_populates="job",
)
job_definition_name: Mapped[Optional[str]] = Column(Text, nullable=True)
job_definition_name: Mapped[Optional[str]] = mapped_column(nullable=True)

def __init__(
self,
Expand Down Expand Up @@ -82,6 +73,3 @@ def as_core_item(self) -> JobItem:
error=self.error,
**queue_args, # type: ignore[arg-type]
)


from .queue import Queue
13 changes: 13 additions & 0 deletions src/saturn_engine/models/job_cursor_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.sqltypes import Text
from sqlalchemy.types import JSON

from .base import Base


class JobCursorState(Base):
__tablename__ = "job_cursor_states"
job_definition_name: Mapped[str] = mapped_column(Text, primary_key=True)
cursor: Mapped[str] = mapped_column(Text, primary_key=True)
state: Mapped[dict] = mapped_column(JSON, nullable=False)
29 changes: 14 additions & 15 deletions src/saturn_engine/models/queue.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from typing import ClassVar
from typing import Optional

import dataclasses

from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.sql.sqltypes import DateTime

import saturn_engine.models.job as job_model
from saturn_engine.core import Cursor
from saturn_engine.core.api import QueueItemState
from saturn_engine.core.api import QueueItemWithState
from saturn_engine.core.types import JobId
from saturn_engine.worker_manager.config.static_definitions import StaticDefinitions

from .base import Base


class Queue(Base):
__allow_unmapped__ = True

__tablename__ = "queues"
__table_args__ = (
Index(
Expand All @@ -29,12 +31,12 @@ class Queue(Base):
),
)

name: Mapped[str] = Column(Text, primary_key=True)
assigned_at = Column(DateTime(timezone=True))
assigned_to = Column(Text)
job: ClassVar[Optional["Job"]] = None
_queue_item: ClassVar[Optional[QueueItemWithState]] = None
enabled = Column(Boolean, default=True, nullable=False)
name: Mapped[str] = mapped_column(primary_key=True)
assigned_at = mapped_column(DateTime(timezone=True))
assigned_to: Mapped[Optional[str]] = mapped_column()
job = relationship(lambda: job_model.Job, uselist=False, back_populates="queue")
_queue_item: Optional[QueueItemWithState] = None
enabled: Mapped[bool] = mapped_column(default=True, nullable=False)

@property
def queue_item(self) -> QueueItemWithState:
Expand All @@ -53,15 +55,12 @@ def join_definitions(self, static_definitions: StaticDefinitions) -> None:
static_definitions.job_definitions[
self.job.job_definition_name
].template,
name=self.name,
name=JobId(self.name),
).with_state(state)
else:
self._queue_item = dataclasses.replace(
static_definitions.jobs[self.job.name],
name=self.name,
name=JobId(self.name),
).with_state(state)
else:
raise NotImplementedError("Only support Job queue")


from .job import Job
6 changes: 3 additions & 3 deletions src/saturn_engine/stores/jobs_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from saturn_engine.core.api import QueueItem
from saturn_engine.core.api import StartJobInput
from saturn_engine.models import Job
from saturn_engine.models.job import JobCursorState
from saturn_engine.models.job_cursor_state import JobCursorState
from saturn_engine.models.queue import Queue
from saturn_engine.stores import queues_store
from saturn_engine.utils import utcnow
Expand Down Expand Up @@ -172,9 +172,9 @@ def sync_jobs_states(
session.execute(cursors_stmt)

if jobs_values:
session.bulk_update_mappings(Job, jobs_values)
session.execute(update(Job), jobs_values)
if queues_values:
session.bulk_update_mappings(Queue, queues_values)
session.execute(update(Queue), queues_values)


CursorsStates = dict[JobId, dict[Cursor, t.Optional[dict]]]
Expand Down
4 changes: 3 additions & 1 deletion src/saturn_engine/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
AnySession = AnySyncSession


def upsert(session: AnySession) -> t.Callable[[object], postgresql.Insert]:
def upsert(
session: AnySession,
) -> t.Callable[[t.Any], postgresql.Insert | sqlite.Insert]:
if not session.bind:
raise ValueError("Session is unbound")

Expand Down

0 comments on commit 524830c

Please sign in to comment.