Set pyright to strict mode and fix type errors
callumforrester committed Sep 13, 2024
1 parent 0c589dd commit 509c056
Showing 15 changed files with 524 additions and 2,429 deletions.
2,072 changes: 1 addition & 2,071 deletions schema.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/scanspec/cli.py
@@ -18,7 +18,7 @@
)
@click.version_option(prog_name="scanspec", message="%(version)s")
@click.pass_context
-def cli(ctx, log_level: str):
+def cli(ctx: click.Context, log_level: str):
"""Top level scanspec command line interface."""
level = getattr(logging, log_level.upper(), None)
logging.basicConfig(format="%(levelname)s:%(message)s", level=level)
@@ -48,7 +48,7 @@ def plot(spec: str):
@click.option(
"--port", default=8080, help="The port that the scanspec service will be hosted on."
)
-def service(cors, port):
+def service(cors: bool, port: int):
"""Run up a REST service."""
from scanspec.service import run_app

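Aside, not part of the diff: the cli.py changes simply annotate the click callback parameters so pyright can check them under strict mode. A minimal self-contained sketch of the same pattern, using a hypothetical serve command:

    import click

    @click.command()
    @click.option("--cors", is_flag=True, help="Allow CORS requests.")
    @click.option("--port", default=8080, help="Port to serve on.")
    def serve(cors: bool, port: int) -> None:
        # Annotated parameters let pyright check the body under strict mode;
        # unannotated ones are reported as unknown types
        click.echo(f"Serving on port {port} (cors={cors})")

    if __name__ == "__main__":
        serve()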
116 changes: 71 additions & 45 deletions src/scanspec/core.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import itertools
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
@@ -10,17 +11,18 @@
TypeVar,
get_origin,
get_type_hints,
+overload,
)

import numpy as np
+import numpy.typing as npt
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
"Axis",
"AxesPoints",
"Frames",
"SnakedFrames",
@@ -36,7 +38,9 @@
StrictConfig: ConfigDict = {"extra": "forbid"}

C = TypeVar("C")
T = TypeVar("T", type, Callable)
T = TypeVar("T")

+GapArray = npt.NDArray[np.bool]
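Aside, not part of the diff: GapArray names the boolean arrays that mark where gaps fall between frames. numpy.typing.NDArray is parameterized by the scalar type; a sketch of how such an alias is used (np.bool_ is the spelling that works on both NumPy 1.x and 2.x, np.bool is its NumPy 2.x alias):

    import numpy as np
    import numpy.typing as npt

    GapArray = npt.NDArray[np.bool_]

    def count_gaps(gap: GapArray) -> int:
        # Boolean arrays count their True elements with count_nonzero
        return int(np.count_nonzero(gap))

    gap: GapArray = np.array([True, False, True])
    print(count_gaps(gap))  # 2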


def discriminated_union_of_subclasses(
@@ -117,7 +121,7 @@ def calculate(self) -> int:
tagged_union = _TaggedUnion(super_cls, discriminator)
_tagged_unions[super_cls] = tagged_union

-def add_subclass_to_union(subclass):
+def add_subclass_to_union(subclass: type[C]):
# Add a discriminator field to a subclass so it can
# be identified when deserializing
subclass.__annotations__ = {
@@ -126,7 +130,9 @@
}
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

-def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
+def get_schema_of_union(
+cls: type[C], source_type: Any, handler: GetCoreSchemaHandler
+):
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
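Aside, not part of the diff: get_schema_of_union hooks pydantic's schema generation so that every subclass registers itself with the tagged union when its schema is built. The standard pydantic v2 equivalent, written out by hand with a Literal discriminator field, looks roughly like this (Line and Static here are illustrative models, not scanspec's real implementation):

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter

    class Line(BaseModel):
        type: Literal["Line"] = "Line"
        start: float
        stop: float

    class Static(BaseModel):
        type: Literal["Static"] = "Static"
        value: float

    # The "type" field is the discriminator: its value picks the member
    Spec = Annotated[Union[Line, Static], Field(discriminator="type")]
    print(TypeAdapter(Spec).validate_python({"type": "Line", "start": 0, "stop": 1}))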
@@ -140,7 +146,17 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
return super_cls


-def uses_tagged_union(cls_or_func: T) -> T:
+@overload
+def uses_tagged_union(cls_or_func: type[C]) -> type[C]: ...
+
+
+@overload
+def uses_tagged_union(cls_or_func: Callable[..., T]) -> Callable[..., T]: ...
+
+
+def uses_tagged_union(
+cls_or_func: type[C] | Callable[..., T],
+) -> type[C] | Callable[..., T]:
"""
Decorator that processes the type hints of a class or function to detect and
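Aside, not part of the diff: the old constrained TypeVar T = TypeVar("T", type, Callable) was too loose for strict mode, because it erased the decorated object's own type. The two overloads above preserve it: a class comes back as type[C] and a callable keeps Callable[..., T]. The pattern in isolation, with a hypothetical register decorator:

    from collections.abc import Callable
    from typing import TypeVar, overload

    C = TypeVar("C")
    T = TypeVar("T")

    @overload
    def register(obj: type[C]) -> type[C]: ...
    @overload
    def register(obj: Callable[..., T]) -> Callable[..., T]: ...

    def register(obj: type[C] | Callable[..., T]) -> type[C] | Callable[..., T]:
        # A no-op decorator: checkers see the precise overloaded types,
        # while at runtime the object passes straight through
        return obj

    @register
    class Point: ...

    @register
    def distance(a: float, b: float) -> float:
        return abs(a - b)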
@@ -162,13 +178,13 @@ def uses_tagged_union(cls_or_func: T) -> T:


class _TaggedUnion:
-def __init__(self, base_class: type, discriminator: str):
+def __init__(self, base_class: type[Any], discriminator: str):
self._base_class = base_class
# Classes and their field names that refer to this tagged union
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
-self._references: set[type | Callable] = set()
+self._references: set[type | Callable[..., Any]] = set()

def add_member(self, cls: type):
if cls in self._subclasses:
@@ -180,12 +196,12 @@ def add_member(self, cls: type):
for ref in self._references:
_TaggedUnion._rebuild(ref)

-def add_reference(self, cls_or_func: type | Callable):
+def add_reference(self, cls_or_func: type | Callable[..., Any]):
self._references.add(cls_or_func)

@staticmethod
# https://github.com/bluesky/scanspec/issues/133
-def _rebuild(cls_or_func: type | Callable):
+def _rebuild(cls_or_func: type[Any] | Callable[..., Any]):
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
rebuild_dataclass(cls_or_func, force=True)
@@ -201,11 +217,13 @@ def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:


@lru_cache(1)
-def make_schema(members: tuple[type, ...], handler):
+def make_schema(
+members: tuple[type[Any], ...], handler: Callable[[Any], CoreSchema]
+) -> dict[str, CoreSchema]:
return {member.__name__: handler(member) for member in members}


-def if_instance_do(x: Any, cls: type, func: Callable):
+def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]):
"""If x is of type cls then return func(x), otherwise return NotImplemented.
Used as a helper when implementing operator overloading.
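Aside, not part of the diff: the new signature (x: C, cls: type[C], func: Callable[[C], T]) ties the three parameters together so strict mode can infer what func receives. A sketch of how such a helper supports operator overloading, with a hypothetical Vec class:

    from collections.abc import Callable
    from typing import Any, TypeVar

    C = TypeVar("C")
    T = TypeVar("T")

    def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]) -> T:
        if isinstance(x, cls):
            return func(x)
        # NotImplemented tells Python to try the reflected operator next
        return NotImplemented  # type: ignore

    class Vec:
        def __init__(self, x: float):
            self.x = x

        def __add__(self, other: Any) -> "Vec":
            return if_instance_do(other, Vec, lambda o: Vec(self.x + o.x))

    print((Vec(1) + Vec(2)).x)  # 3.0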
@@ -219,9 +237,12 @@ def if_instance_do(x: Any, cls: type, func: Callable):
#: A type variable for an `axis_` that can be specified for a scan
Axis = TypeVar("Axis")

+#: Alternative axis variable to be used when two are required in the same type binding
+OtherAxis = TypeVar("OtherAxis")

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
-AxesPoints = dict[Axis, np.ndarray]
+AxesPoints = dict[Axis, npt.NDArray[np.floating[Any]]]
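Aside, not part of the diff: because AxesPoints is parameterized by the Axis TypeVar, it behaves as a generic alias, so AxesPoints[str] means a dict from str to float arrays. Sketch:

    from typing import Any, TypeVar

    import numpy as np
    import numpy.typing as npt

    Axis = TypeVar("Axis")
    AxesPoints = dict[Axis, npt.NDArray[np.floating[Any]]]

    def mean_of(points: AxesPoints[str], axis: str) -> float:
        # All values are float arrays, so np.mean is well-typed here
        return float(np.mean(points[axis]))

    pts: AxesPoints[str] = {"x": np.array([0.0, 1.0, 2.0])}
    print(mean_of(pts, "x"))  # 1.0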


class Frames(Generic[Axis]):
@@ -256,7 +277,7 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
-gap: np.ndarray | None = None,
+gap: GapArray | None = None,
):
#: The midpoints of scan frames for each axis
self.midpoints = midpoints
@@ -304,7 +325,9 @@ def __len__(self) -> int:
# All axespoints arrays are same length, pick the first one
return len(self.gap)

-def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
+def extract(
+self, indices: npt.NDArray[np.signedinteger[Any]], calculate_gap: bool = True
+) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.
Args:
@@ -322,7 +345,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
return {k: v[dim_indices] for k, v in d.items()}
return {}

-def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
+def extract_gap(gaps: Iterable[GapArray]) -> GapArray | None:
for gap in gaps:
if not calculate_gap:
return gap[dim_indices]
@@ -354,7 +377,7 @@ def concat_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = np.concatenate(self.lower[ax], other.lower[ax])
return {a: np.concatenate([d[a] for d in ds]) for a in self.axes()}

-def concat_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
+def concat_gap(gaps: Sequence[GapArray]) -> GapArray:
g = np.concatenate(gaps)
# Calc the first frame
g[0] = gap_between_frames(other, self)
@@ -382,7 +405,7 @@ def zip_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = {**self.lower[ax], **other.lower[ax]}
return dict(kv for d in ds for kv in d.items())

-def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
+def zip_gap(gaps: Sequence[GapArray]) -> GapArray:
# Gap if either frames has a gap. E.g.
# gap[i] = self.gap[i] | other.gap[i]
return np.logical_or.reduce(gaps)
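Aside, not part of the diff: np.logical_or.reduce folds element-wise OR over any number of boolean arrays, which is exactly "gap if either frames has a gap":

    import numpy as np
    import numpy.typing as npt

    GapArray = npt.NDArray[np.bool_]

    a: GapArray = np.array([True, False, False])
    b: GapArray = np.array([False, False, True])
    # Equivalent to a | b, but generalizes to any number of arrays
    print(np.logical_or.reduce([a, b]))  # [ True False  True]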
@@ -392,24 +415,24 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:

def _merge_frames(
*stack: Frames[Axis],
-dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]],  # type: ignore
-gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
+dict_merge: Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]],  # type: ignore
+gap_merge: Callable[[Sequence[GapArray]], GapArray | None],
) -> Frames[Axis]:
types = {type(fs) for fs in stack}
assert len(types) == 1, f"Mismatching types for {stack}"
cls = types.pop()

-# If any lower or upper are different, apply to those
-kwargs = {}
-for a in ("lower", "upper"):
-if any(fs.midpoints is not getattr(fs, a) for fs in stack):
-kwargs[a] = dict_merge([getattr(fs, a) for fs in stack])
-
# Apply to midpoints, force calculation of gap
return cls(
midpoints=dict_merge([fs.midpoints for fs in stack]),
gap=gap_merge([fs.gap for fs in stack]),
-**kwargs,
+# If any lower or upper are different, apply to those
+lower=dict_merge([fs.lower for fs in stack])
+if any(fs.midpoints is not fs.lower for fs in stack)
+else None,
+upper=dict_merge([fs.upper for fs in stack])
+if any(fs.midpoints is not fs.upper for fs in stack)
+else None,
)
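Aside, not part of the diff: the _merge_frames signature change above also fixes a real bug that strict mode exposed. The old dict_merge=Callable[...] bound the typing expression as the default value of an untyped keyword parameter, where dict_merge: Callable[...] (an annotation) was intended. The same mistake in isolation:

    from collections.abc import Callable, Sequence

    # Bug: '=' makes the typing expression the DEFAULT VALUE, so merge is
    # an untyped parameter whose default is the alias object itself
    def combine_bad(items: Sequence[int], merge=Callable[[Sequence[int]], int]):
        return merge(items)  # TypeError at runtime unless merge is passed

    # Fix: ':' declares an annotation and forces callers to pass a function
    def combine(items: Sequence[int], merge: Callable[[Sequence[int]], int]) -> int:
        return merge(items)

    print(combine([1, 2, 3], merge=sum))  # 6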


@@ -421,19 +444,23 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
-gap: np.ndarray | None = None,
+gap: GapArray | None = None,
):
super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
# Override first element of gap to be True, as subsequent runs
# of snake scans are always joined end -> start
self.gap[0] = False

@classmethod
-def from_frames(cls, frames: Frames[Axis]) -> SnakedFrames[Axis]:
+def from_frames(
+cls: type[SnakedFrames[Any]], frames: Frames[OtherAxis]
+) -> SnakedFrames[OtherAxis]:
"""Create a snaked version of a `Frames` object."""
return cls(frames.midpoints, frames.lower, frames.upper, frames.gap)

-def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
+def extract(
+self, indices: npt.NDArray[np.int32], calculate_gap: bool = True
+) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.
Args:
@@ -461,23 +488,23 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
cls = type(self)
gap = None

-# If lower or upper are different, apply to those
-kwargs = {}
-if self.midpoints is not self.lower:
-# If going backwards select from the opposite bound
-kwargs["lower"] = {
-k: np.where(backwards, self.upper[k][snake_indices], v[snake_indices])
-for k, v in self.lower.items()
-}
-if self.midpoints is not self.upper:
-kwargs["upper"] = {
-k: np.where(backwards, self.lower[k][snake_indices], v[snake_indices])
-for k, v in self.upper.items()
-}
-
-# Apply to midpoints
-return cls(
-{k: v[snake_indices] for k, v in self.midpoints.items()}, gap=gap, **kwargs
-)
+# Apply to midpoints
+return cls(
+{k: v[snake_indices] for k, v in self.midpoints.items()},
+gap=gap,
+# If lower or upper are different, apply to those
+lower={
+k: np.where(backwards, self.upper[k][snake_indices], v[snake_indices])
+for k, v in self.lower.items()
+}
+if self.midpoints is not self.lower
+else None,
+upper={
+k: np.where(backwards, self.lower[k][snake_indices], v[snake_indices])
+for k, v in self.upper.items()
+}
+if self.midpoints is not self.upper
+else None,
+)
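Aside, not part of the diff: the rewritten SnakedFrames.extract keeps the same np.where trick as before; where the scan runs backwards, each frame's bound is taken from the opposite array. The core of it:

    import numpy as np

    backwards = np.array([False, True, False])
    lower = np.array([0.0, 10.0, 20.0])
    upper = np.array([1.0, 11.0, 21.0])

    # Per-element choice: upper bound where backwards, lower bound otherwise
    print(np.where(backwards, upper, lower))  # [ 0. 11. 20.]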


Expand All @@ -486,7 +513,9 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


-def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
+def squash_frames(
+stack: list[Frames[Axis]], check_path_changes: bool = True
+) -> Frames[Axis]:
"""Squash a stack of nested Frames into a single one.
Args:
@@ -648,10 +677,7 @@ def __init__(self, stack: list[Frames[Axis]]):
@property
def axes(self) -> list[Axis]:
"""The axes that will be present in each points dictionary."""
-axes = []
-for frames in self.stack:
-axes += frames.axes()
-return axes
+return list(itertools.chain(*(frames.axes() for frames in self.stack)))
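Aside, not part of the diff: the axes property replaces list accumulation in a loop with itertools.chain, which flattens the per-Frames axis lists in one pass. Both spellings side by side:

    import itertools

    stack = [["x", "y"], ["z"]]

    # Old: accumulate with a loop
    axes: list[str] = []
    for frames in stack:
        axes += frames

    # New: flatten in one pass; chain.from_iterable(stack) is equivalent
    assert list(itertools.chain(*stack)) == axes == ["x", "y", "z"]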

def __len__(self) -> int:
"""The number of dictionaries that will be produced if iterated over."""
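Aside, not part of the diff: the strict-mode switch named in the commit title is in one of the changed files not shown above, presumably pyright's configuration (typeCheckingMode = "strict" in pyproject.toml's [tool.pyright] table, or pyrightconfig.json). Strict mode turns unannotated parameters and implicit Any into errors, which is what drives most of the annotation-only changes in this diff. For example:

    # Under strict mode pyright reports the unannotated parameter
    # (reportMissingParameterType) as an error:
    def bad(port):
        return port + 1

    # Fully annotated, the same function passes:
    def good(port: int) -> int:
        return port + 1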
