Skip to content

Commit

Permalink
Allow vanilla Pydantic serialisation/deserialisation of discriminated…
Browse files Browse the repository at this point in the history
… unions
  • Loading branch information
ZohebShaikh authored and DiamondJoseph committed Aug 9, 2024
1 parent f49320c commit 52a58ea
Show file tree
Hide file tree
Showing 8 changed files with 445 additions and 2,202 deletions.
2,404 changes: 335 additions & 2,069 deletions schema.json

Large diffs are not rendered by default.

158 changes: 45 additions & 113 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial
from inspect import isclass
from functools import lru_cache
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
get_origin,
get_type_hints,
)

import numpy as np
from pydantic import (
ConfigDict,
Field,
GetCoreSchemaHandler,
TypeAdapter,
)
from pydantic.dataclasses import rebuild_dataclass
from pydantic.fields import FieldInfo
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
Expand All @@ -43,13 +37,13 @@


def discriminated_union_of_subclasses(
cls,
cls: type,
discriminator: str = "type",
):
) -> type:
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
the type. Raw JSON should look like {"type": <type name>, params for
the type. Raw JSON should look like {<discriminator>: <type name>, params for
<type name>...}.
Example::
Expand Down Expand Up @@ -107,131 +101,69 @@ def calculate(self) -> int:
super_cls: The superclass of the union, Expression in the above example
discriminator: The discriminator that will be inserted into the
serialized documents for type determination. Defaults to "type".
config: A pydantic config class to be inserted into all
subclasses. Defaults to None.
Returns:
Type | Callable[[Type], Type]: A decorator that adds the necessary
Type: A decorator that adds the necessary
functionality to a class.
"""
tagged_union = _TaggedUnion(cls, discriminator)
_tagged_unions[cls] = tagged_union
cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator))
cls.__get_pydantic_core_schema__ = classmethod(
partial(__get_pydantic_core_schema__, tagged_union=tagged_union)
)
return cls

def add_subclass_to_union(subclass):
# Add a discriminator field to a subclass so it can
# be identified when deserializing
subclass.__annotations__ = {
**subclass.__annotations__,
discriminator: Literal[subclass.__name__], # type: ignore
}
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

T = TypeVar("T", type, Callable)
def default_handler(subclass, source_type: Any, handler: GetCoreSchemaHandler):
tagged_union.add_member(subclass)
return handler(subclass)

subclass.__get_pydantic_core_schema__ = classmethod(default_handler)

def deserialize_as(cls, obj):
return _tagged_unions[cls].type_adapter.validate_python(obj)
def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
return tagged_union.schema(handler)

cls.__init_subclass__ = classmethod(add_subclass_to_union)
cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union)
return cls

def uses_tagged_union(cls_or_func: T) -> T:
"""
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for k, v in get_type_hints(cls_or_func).items():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls_or_func, k)
return cls_or_func

T = TypeVar("T", type, Callable)


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
self._base_class = base_class
# The members of the tagged union, i.e. subclasses of the baseclasses
self._members: list[type] = []
# Classes and their field names that refer to this tagged union
self._referrers: dict[type | Callable, set[str]] = {}
self.type_adapter: TypeAdapter = TypeAdapter(None)
self._discriminator = discriminator

def _make_union(self):
if len(self._members) > 0:
return Union[tuple(self._members)] # type: ignore # noqa

def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
# Set the field to use the `type` discriminator on deserialize
# https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators
if isclass(cls):
assert isinstance(
field, FieldInfo
), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501
field.discriminator = self._discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._members: list[type] = []

def add_member(self, cls: type):
if cls in self._members:
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
# called muliple times for the same member, do no process if it wouldn't
# change the member list
return

self._members.append(cls)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set all the referrers
# to use this union
for referrer, fields in self._referrers.items():
if isclass(referrer):
for field in dataclasses.fields(referrer):
if field.name in fields:
field.type = union
self._set_discriminator(referrer, field.name, field.default)
rebuild_dataclass(referrer, force=True)
# Make a type adapter for use in deserialization
self.type_adapter = TypeAdapter(union)

def add_referrer(self, cls: type | Callable, attr_name: str):
self._referrers.setdefault(cls, set()).add(attr_name)
union = self._make_union()
if union:
# There are more than 1 subclasses in the union, so set the referrer
# (which is currently being constructed) to use it
# note that we use annotations as the class has not been turned into
# a dataclass yet
cls.__annotations__[attr_name] = union
self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None))


_tagged_unions: dict[type, _TaggedUnion] = {}


def __init_subclass__(discriminator: str, cls: type):
# Add a discriminator field to the class so it can
# be identified when deserailizing, and make sure it is last in the list
cls.__annotations__ = {
**cls.__annotations__,
discriminator: Literal[cls.__name__], # type: ignore
}
cls.type = Field(cls.__name__, repr=False) # type: ignore
# Replace any bare annotation with a discriminated union of subclasses
# and register this class as one that refers to that union so it can be updated
for k, v in get_type_hints(cls).items():
# This works for Expression[T] or Expression
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_referrer(cls, k)


def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion
):
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
tagged_union.add_member(cls)
return handler(source_type)
for member in self._members:
if member != cls:
rebuild_dataclass(member, force=True)

def schema(self, handler):
return tagged_union_schema(
make_schema(tuple(self._members), handler),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)


@lru_cache(1)
def make_schema(members: tuple[type, ...], handler):
return {member.__name__: handler(member) for member in members}


def if_instance_do(x: Any, cls: type, func: Callable):
Expand Down
3 changes: 1 addition & 2 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mpl_toolkits.mplot3d import Axes3D, proj3d
from scipy import interpolate

from .core import Path, uses_tagged_union
from .core import Path
from .regions import Circle, Ellipse, Polygon, Rectangle, Region, find_regions
from .specs import DURATION, Spec

Expand Down Expand Up @@ -86,7 +86,6 @@ def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int
yield unscaled_splines


@uses_tagged_union
def plot_spec(spec: Spec[Any], title: str | None = None):
"""Plot a spec, drawing the path taken through the scan.
Expand Down
8 changes: 3 additions & 5 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
from typing import Any, Generic

import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.dataclasses import dataclass

from .core import (
AxesPoints,
Axis,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
if_instance_do,
)
Expand Down Expand Up @@ -71,9 +70,8 @@ def serialize(self) -> Mapping[str, Any]:
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
return deserialize_as(Region, obj)
def deserialize(obj: Mapping[str, Any]) -> Region:
return TypeAdapter(Region).validate_python(obj)


def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray:
Expand Down
7 changes: 1 addition & 6 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
from scanspec.core import AxesPoints, Frames, Path

from .specs import Line, Spec

Expand All @@ -27,7 +27,6 @@


@dataclass
@uses_tagged_union
class ValidResponse:
"""Response model for spec validation."""

Expand All @@ -44,7 +43,6 @@ class PointsFormat(str, Enum):


@dataclass
@uses_tagged_union
class PointsRequest:
"""A request for generated scan points."""

Expand Down Expand Up @@ -125,7 +123,6 @@ class SmallestStepResponse:


@app.post("/valid", response_model=ValidResponse)
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
Expand Down Expand Up @@ -198,7 +195,6 @@ def bounds(


@app.post("/gap", response_model=GapResponse)
@uses_tagged_union
def gap(
spec: Spec = Body(
...,
Expand All @@ -224,7 +220,6 @@ def gap(


@app.post("/smalleststep", response_model=SmallestStepResponse)
@uses_tagged_union
def smallest_step(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> SmallestStepResponse:
Expand Down
10 changes: 4 additions & 6 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)

import numpy as np
from pydantic import Field, validate_call
from pydantic import Field, TypeAdapter, validate_call
from pydantic.dataclasses import dataclass

from .core import (
Expand All @@ -18,7 +18,6 @@
Path,
SnakedFrames,
StrictConfig,
deserialize_as,
discriminated_union_of_subclasses,
gap_between_frames,
if_instance_do,
Expand Down Expand Up @@ -107,13 +106,12 @@ def concat(self, other: Spec) -> Concat[Axis]:
return Concat(self, other)

def serialize(self) -> Mapping[str, Any]:
"""Serialize the spec to a dictionary."""
"""Serialize the Spec to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the spec from a dictionary."""
return deserialize_as(Spec, obj)
def deserialize(obj: Mapping[str, Any]) -> Spec:
return TypeAdapter(Spec).validate_python(obj)


@dataclass(config=StrictConfig)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from pydantic import BaseModel, TypeAdapter

from scanspec.specs import Line, Spec


class Foo(BaseModel):
spec: Spec


simple_foo = Foo(spec=Line("x", 1, 2, 5))
nested_foo = Foo(spec=Line("x", 1, 2, 5) * Line("y", 1, 2, 5))


@pytest.mark.parametrize("model", [simple_foo, nested_foo])
def test_model_validation(model: Foo):
# To/from Python dict
as_dict = model.model_dump()
deserialized = Foo.model_validate(as_dict)
assert deserialized == model

# To/from Json dict
as_json = model.model_dump_json()
deserialized = Foo.model_validate_json(as_json)
assert deserialized == model


@pytest.mark.parametrize("model", [simple_foo, nested_foo])
def test_type_adapter(model: Foo):
type_adapter = TypeAdapter(Foo)

# To/from Python dict
as_dict = model.model_dump()
deserialized = type_adapter.validate_python(as_dict)
assert deserialized == model

# To/from Json dict
as_json = model.model_dump_json()
deserialized = type_adapter.validate_json(as_json)
assert deserialized == model
Loading

0 comments on commit 52a58ea

Please sign in to comment.