Skip to content

Commit

Permalink
added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Aug 7, 2024
1 parent a7c804b commit e4170d2
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 29 deletions.
25 changes: 1 addition & 24 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,32 +135,10 @@ def uses_tagged_union(cls_or_func: T) -> T:
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
The function works by iterating over the type hints of the provided class or
function For each type hint, it checks if the type (or its origin,
in the case of generic types) is registered as a tagged union in the
global `_tagged_unions` dictionary. If a match is found, the
tagged union's `add_referrer` method is called to register the class
or function as a referrer.
Example:
class PointsRequest:
A request for generated scan points
spec: Spec
This will register `PointsRequest` and `Spec` with the corresponding tagged
union or
@uses_tagged_union
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
The decorator helps pydantic in understanding the core schema for the
function
"""
for k, v in get_type_hints(cls_or_func).items():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
Expand Down Expand Up @@ -198,8 +176,7 @@ def add_member(self, cls: type):
# called muliple times for the same member, do no process if it wouldn't
# change the member list
return
if cls is self._base_class:
return

self._members.append(cls)
union = self._make_union()
if union:
Expand Down
10 changes: 7 additions & 3 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import Generic
from collections.abc import Iterator, Mapping
from dataclasses import asdict, is_dataclass
from typing import Any, Generic

import numpy as np
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -66,6 +66,10 @@ def __sub__(self, other) -> DifferenceOf[Axis]:
def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:
return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o))

def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
Expand Down
2 changes: 1 addition & 1 deletion src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def bounded(


"""
Validation Decorator is requied as we are using custom pydantic core schema validators
Defers wrapping function with validate_call until class is fully instantiated
"""
Line.bounded = validate_call(Line.bounded) # type:ignore

Expand Down
16 changes: 15 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pydantic import ValidationError

from scanspec.regions import Circle, Rectangle, UnionOf
from scanspec.regions import Circle, Rectangle, Region, UnionOf
from scanspec.specs import Line, Mask, Spec, Spiral


Expand All @@ -15,6 +15,20 @@ def test_line_serializes() -> None:
assert Spec.deserialize(serialized) == ob


def test_circle_serializes() -> None:
ob = Circle("x", "y", x_middle=0, y_middle=1, radius=4)
serialized = {
"x_axis": "x",
"y_axis": "y",
"x_middle": 0.0,
"y_middle": 1.0,
"radius": 4.0,
"type": "Circle",
}
assert ob.serialize() == serialized
assert Region.deserialize(serialized) == ob


def test_masked_circle_serializes() -> None:
ob = Mask(Line("x", 0, 1, 4), Circle("x", "y", x_middle=0, y_middle=1, radius=4))
serialized = {
Expand Down

0 comments on commit e4170d2

Please sign in to comment.