Drop support for Python 3.8, 3.9 #128

Merged · 8 commits · Aug 6, 2024
11 changes: 5 additions & 6 deletions .github/pages/make_switcher.py
@@ -3,28 +3,27 @@
 from argparse import ArgumentParser
 from pathlib import Path
 from subprocess import CalledProcessError, check_output
-from typing import List, Optional


-def report_output(stdout: bytes, label: str) -> List[str]:
+def report_output(stdout: bytes, label: str) -> list[str]:
     ret = stdout.decode().strip().split("\n")
     print(f"{label}: {ret}")
     return ret


-def get_branch_contents(ref: str) -> List[str]:
+def get_branch_contents(ref: str) -> list[str]:
     """Get the list of directories in a branch."""
     stdout = check_output(["git", "ls-tree", "-d", "--name-only", ref])
     return report_output(stdout, "Branch contents")


-def get_sorted_tags_list() -> List[str]:
+def get_sorted_tags_list() -> list[str]:
     """Get a list of sorted tags in descending order from the repository."""
     stdout = check_output(["git", "tag", "-l", "--sort=-v:refname"])
     return report_output(stdout, "Tags list")


-def get_versions(ref: str, add: Optional[str]) -> List[str]:
+def get_versions(ref: str, add: str | None) -> list[str]:
     """Generate the file containing the list of all GitHub Pages builds."""
     # Get the directories (i.e. builds) from the GitHub Pages branch
     try:
@@ -41,7 +40,7 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]:
     tags = get_sorted_tags_list()

     # Make the sorted versions list from main branches and tags
-    versions: List[str] = []
+    versions: list[str] = []
     for version in ["master", "main"] + tags:
         if version in builds:
             versions.append(version)
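The typing changes in this file lean on PEP 585 (builtin generics such as list[str], Python 3.9+) and PEP 604 (the X | None union spelling, Python 3.10+), which is why the support floor has to move before they can land. A minimal sketch of the two spellings side by side (illustrative, not code from the PR):

    # Old spelling, needed on Python 3.8/3.9
    from typing import List, Optional

    def old_style(add: Optional[str]) -> List[str]:
        return [add] if add else []

    # New spelling, valid from Python 3.10 with no typing imports at all
    def new_style(add: str | None) -> list[str]:
        return [add] if add else []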
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
     strategy:
       matrix:
         runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
         include:
           # Include one that runs in the dev environment
           - runs-on: "ubuntu-latest"
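With 3.8 and 3.9 gone from the matrix, those interpreters are no longer exercised anywhere in CI. A common companion to such a change (purely hypothetical here, not something this PR adds) is an import-time guard that fails fast on unsupported interpreters:

    import sys

    # Hypothetical guard mirroring requires-python = ">=3.10" in pyproject.toml
    if sys.version_info < (3, 10):
        raise RuntimeError("scanspec requires Python 3.10 or newer")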
6 changes: 3 additions & 3 deletions docs/how-to/iterate-a-spec.rst
@@ -18,9 +18,9 @@ frame. You can get these by using the `Spec.midpoints()` method to produce a
 >>> for d in spec.midpoints():
 ...     print(d)
 ...
-{'x': 1.0}
-{'x': 1.5}
-{'x': 2.0}
+{'x': np.float64(1.0)}
+{'x': np.float64(1.5)}
+{'x': np.float64(2.0)}

 This is simple, but not particularly performant, as the numpy arrays of
 points are unpacked point by point into point dictionaries
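The doctest output changes because NumPy 2 prints scalars with an explicit type (NEP 51), and this PR bumps the dependency to numpy>=2. A quick sketch of the behaviour, assuming numpy>=2 is installed:

    import numpy as np

    x = np.array([1.0, 1.5, 2.0])
    print(repr(x[0]))   # np.float64(1.0) under NumPy 2 (NEP 51 scalar repr)
    print(float(x[0]))  # 1.0, converting to a builtin float restores the old look
    print({"x": x[0]})  # {'x': np.float64(1.0)}, matching the doctest above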
20 changes: 3 additions & 17 deletions pyproject.toml
@@ -7,31 +7,19 @@ name = "scanspec"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "License :: OSI Approved :: Apache Software License",
-    "Programming Language :: Python :: 3.7",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
 ]
 description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way"
-dependencies = [
-    "numpy>=1.19.3",
-    "click==8.1.3",
-    "pydantic<2.0",
-    "typing_extensions",
-]
+dependencies = ["numpy>=2", "click>=8.1", "pydantic<2.0", "httpx==0.26.0"]
 dynamic = ["version"]
 license.file = "LICENSE"
 readme = "README.md"
-requires-python = ">=3.7"
+requires-python = ">=3.10"

 [project.optional-dependencies]
 # Plotting
-plotting = [
-    # make sure a python 3.9 compatible scipy and matplotlib are selected
-    "scipy>=1.5.4",
-    "matplotlib>=3.2.2",
-]
+plotting = ["scipy", "matplotlib"]
 # REST service support
 service = ["fastapi==0.99", "uvicorn"]
 # For development tests/docs
@@ -131,8 +119,6 @@ extend-select = [
     "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
     "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
 ]
-# We use pydantic, so don't upgrade to py3.10 syntax yet
-pyupgrade.keep-runtime-typing = true
 ignore = [
     "B008", # We use function calls in service arguments
 ]
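Dropping pyupgrade.keep-runtime-typing lets ruff's UP rules rewrite annotations to the 3.10 spellings seen throughout this diff. A sketch of the before/after (illustrative; the real rewrites are in the commits above):

    # Before: 3.8-compatible typing, kept alive by keep-runtime-typing = true
    from typing import List, Optional

    def get_versions(ref: str, add: Optional[str]) -> List[str]: ...

    # After "ruff check --fix" with UP enabled and requires-python = ">=3.10"
    def get_versions(ref: str, add: str | None) -> list[str]: ...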
85 changes: 37 additions & 48 deletions src/scanspec/core.py
@@ -1,25 +1,12 @@
 from __future__ import annotations

+from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import field
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Generic, Literal, TypeVar, Union

 import numpy as np
 from pydantic import BaseConfig, Extra, Field, ValidationError, create_model
 from pydantic.error_wrappers import ErrorWrapper
-from typing_extensions import Literal

 __all__ = [
     "if_instance_do",
@@ -43,11 +30,11 @@ class StrictConfig(BaseConfig):


 def discriminated_union_of_subclasses(
-    super_cls: Optional[Type] = None,
+    super_cls: type | None = None,
     *,
     discriminator: str = "type",
-    config: Optional[Type[BaseConfig]] = None,
-) -> Union[Type, Callable[[Type], Type]]:
+    config: type[BaseConfig] | None = None,
+) -> type | Callable[[type], type]:
     """Add all subclasses of super_cls to a discriminated union.

     For all subclasses of super_cls, add a discriminator field to identify
@@ -114,7 +101,7 @@ def calculate(self) -> int:
         subclasses. Defaults to None.

     Returns:
-        Union[Type, Callable[[Type], Type]]: A decorator that adds the necessary
+        Type | Callable[[Type], Type]: A decorator that adds the necessary
         functionality to a class.
     """

Expand All @@ -130,12 +117,12 @@ def wrap(cls):


def _discriminated_union_of_subclasses(
super_cls: Type,
super_cls: type,
discriminator: str,
config: Optional[Type[BaseConfig]] = None,
) -> Union[Type, Callable[[Type], Type]]:
super_cls._ref_classes = set()
super_cls._model = None
config: type[BaseConfig] | None = None,
) -> type | Callable[[type], type]:
super_cls._ref_classes = set() # type: ignore
super_cls._model = None # type: ignore

def __init_subclass__(cls) -> None:
# Keep track of inherting classes in super class
@@ -157,7 +144,7 @@ def __validate__(cls, v: Any) -> Any:
             # needs to be done once, after all subclasses have been
             # declared
             if cls._model is None:
-                root = Union[tuple(cls._ref_classes)]  # type: ignore
+                root = Union[tuple(cls._ref_classes)]  # type: ignore # noqa
                 cls._model = create_model(
                     super_cls.__name__,
                     __root__=(root, Field(..., discriminator=discriminator)),
@@ -185,7 +172,7 @@ def __validate__(cls, v: Any) -> Any:
     return super_cls


-def if_instance_do(x: Any, cls: Type, func: Callable):
+def if_instance_do(x: Any, cls: type, func: Callable):
     """If x is of type cls then return func(x), otherwise return NotImplemented.

     Used as a helper when implementing operator overloading.
@@ -201,7 +188,7 @@ def if_instance_do(x: Any, cls: Type, func: Callable):

 #: Map of axes to float ndarray of points
 #: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
-AxesPoints = Dict[Axis, np.ndarray]
+AxesPoints = dict[Axis, np.ndarray]


 class Frames(Generic[Axis]):
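AxesPoints is now a plain builtin-generic alias; a sketch of what such a mapping holds, echoing the comment above (values invented for illustration):

    import numpy as np

    points: dict[str, np.ndarray] = {
        "xmotor": np.array([0.0, 1.0, 2.0]),  # positions for the x motor
        "ymotor": np.array([2.0, 2.0, 2.0]),  # y motor held at 2.0
    }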
@@ -234,9 +221,9 @@ class Frames(Generic[Axis]):
     def __init__(
         self,
         midpoints: AxesPoints[Axis],
-        lower: Optional[AxesPoints[Axis]] = None,
-        upper: Optional[AxesPoints[Axis]] = None,
-        gap: Optional[np.ndarray] = None,
+        lower: AxesPoints[Axis] | None = None,
+        upper: AxesPoints[Axis] | None = None,
+        gap: np.ndarray | None = None,
     ):
         #: The midpoints of scan frames for each axis
         self.midpoints = midpoints
@@ -253,7 +240,9 @@ def __init__(
             # We have a gap if upper[i] != lower[i+1] for any axes
             axes_gap = [
                 np.roll(upper, 1) != lower
-                for upper, lower in zip(self.upper.values(), self.lower.values())
+                for upper, lower in zip(
+                    self.upper.values(), self.lower.values(), strict=False
+                )
             ]
             self.gap = np.logical_or.reduce(axes_gap)
         # Check all axes and ordering are the same
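zip(..., strict=...) only exists from Python 3.10 (PEP 618), another reason for the new floor; strict=False simply spells out the old silent-truncation default. A quick illustration (not from the diff):

    # strict=False (the default) truncates to the shortest input
    list(zip([1, 2, 3], ["a", "b"], strict=False))  # [(1, 'a'), (2, 'b')]

    # strict=True raises ValueError instead of silently dropping the 3
    # list(zip([1, 2, 3], ["a", "b"], strict=True))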
@@ -270,7 +259,7 @@ def __init__(
             lengths.add(len(self.gap))
         assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}"

-    def axes(self) -> List[Axis]:
+    def axes(self) -> list[Axis]:
         """The axes which will move during the scan.

         These will be present in `midpoints`, `lower` and `upper`.
@@ -300,7 +289,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
                 return {k: v[dim_indices] for k, v in d.items()}
             return {}

-        def extract_gap(gaps: Iterable[np.ndarray]) -> Optional[np.ndarray]:
+        def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
             for gap in gaps:
                 if not calculate_gap:
                     return gap[dim_indices]
@@ -371,7 +360,7 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
 def _merge_frames(
     *stack: Frames[Axis],
     dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]],  # type: ignore
-    gap_merge=Callable[[Sequence[np.ndarray]], Optional[np.ndarray]],
+    gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
 ) -> Frames[Axis]:
     types = {type(fs) for fs in stack}
     assert len(types) == 1, f"Mismatching types for {stack}"
@@ -397,9 +386,9 @@ class SnakedFrames(Frames[Axis]):
     def __init__(
         self,
         midpoints: AxesPoints[Axis],
-        lower: Optional[AxesPoints[Axis]] = None,
-        upper: Optional[AxesPoints[Axis]] = None,
-        gap: Optional[np.ndarray] = None,
+        lower: AxesPoints[Axis] | None = None,
+        upper: AxesPoints[Axis] | None = None,
+        gap: np.ndarray | None = None,
     ):
         super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
         # Override first element of gap to be True, as subsequent runs
@@ -431,7 +420,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
         length = len(self)
         backwards = (indices // length) % 2
         snake_indices = np.where(backwards, (length - 1) - indices, indices) % length
-        cls: Type[Frames[Any]]
+        cls: type[Frames[Any]]
         if not calculate_gap:
             cls = Frames
             gap = self.gap[np.where(backwards, length - indices, indices) % length]
@@ -464,7 +453,7 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
     return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


-def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
+def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
     """Squash a stack of nested Frames into a single one.

     Args:
@@ -530,7 +519,7 @@ class Path(Generic[Axis]):
     """

     def __init__(
-        self, stack: List[Frames[Axis]], start: int = 0, num: Optional[int] = None
+        self, stack: list[Frames[Axis]], start: int = 0, num: int | None = None
     ):
         #: The Frames stack describing the scan, from slowest to fastest moving
         self.stack = stack
@@ -544,7 +533,7 @@ def __init__(
         if num is not None and start + num < self.end_index:
             self.end_index = start + num

-    def consume(self, num: Optional[int] = None) -> Frames[Axis]:
+    def consume(self, num: int | None = None) -> Frames[Axis]:
         """Consume at most num frames from the Path and return as a Frames object.

         >>> fx = SnakedFrames({"x": np.array([1, 2])})
@@ -613,18 +602,18 @@ class Midpoints(Generic[Axis]):
     >>> fy = Frames({"y": np.array([3, 4])})
     >>> mp = Midpoints([fy, fx])
     >>> for p in mp: print(p)
-    {'y': 3, 'x': 1}
-    {'y': 3, 'x': 2}
-    {'y': 4, 'x': 2}
-    {'y': 4, 'x': 1}
+    {'y': np.int64(3), 'x': np.int64(1)}
+    {'y': np.int64(3), 'x': np.int64(2)}
+    {'y': np.int64(4), 'x': np.int64(2)}
+    {'y': np.int64(4), 'x': np.int64(1)}
     """

-    def __init__(self, stack: List[Frames[Axis]]):
+    def __init__(self, stack: list[Frames[Axis]]):
         #: The stack of Frames describing the scan, from slowest to fastest moving
         self.stack = stack

     @property
-    def axes(self) -> List[Axis]:
+    def axes(self) -> list[Axis]:
         """The axes that will be present in each points dictionary."""
         axes = []
         for frames in self.stack:
@@ -635,7 +624,7 @@ def __len__(self) -> int:
         """The number of dictionaries that will be produced if iterated over."""
         return int(np.prod([len(frames) for frames in self.stack]))

-    def __iter__(self) -> Iterator[Dict[Axis, float]]:
+    def __iter__(self) -> Iterator[dict[Axis, float]]:
         """Yield {axis: midpoint} for each frame in the scan."""
         path = Path(self.stack)
         while len(path):