Skip to content

Commit

Permalink
Move to pyright and fix type errors (#135)
Browse files Browse the repository at this point in the history
* Move to pyright and fix type errors
  • Loading branch information
callumforrester committed Aug 19, 2024
1 parent c15b8c3 commit 0f7329f
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 27 deletions.
32 changes: 20 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ dev = [
"scanspec[plotting]",
"scanspec[service]",
"copier",
"mypy",
"myst-parser",
"pipdeptree",
"pre-commit",
"pydata-sphinx-theme>=0.12",
"pyright",
"pytest",
"pytest-cov",
"ruff",
Expand Down Expand Up @@ -61,8 +61,9 @@ name = "Tom Cobb"
[tool.setuptools_scm]
write_to = "src/scanspec/_version.py"

[tool.mypy]
ignore_missing_imports = true # Ignore missing stubs in imported modules
[tool.pyright]
# strict = ["src", "tests"]
reportMissingImports = false # Ignore missing stubs in imported modules

[tool.pytest.ini_options]
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
Expand Down Expand Up @@ -95,12 +96,12 @@ passenv = *
allowlist_externals =
pytest
pre-commit
mypy
pyright
sphinx-build
sphinx-autobuild
commands =
pre-commit: pre-commit run --all-files {posargs}
type-checking: mypy src tests {posargs}
type-checking: pyright src tests {posargs}
tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs}
docs: sphinx-{posargs:build -E --keep-going} -T docs build/html
"""
Expand All @@ -111,14 +112,21 @@ line-length = 88

[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
"SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self
]
ignore = [
"B008", # We use function calls in service arguments
]

[tool.ruff.lint.per-file-ignores]
# By default, private member access is allowed in tests
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
# Remove this line to forbid private member access in tests
"tests/**/*" = ["SLF001"]
3 changes: 3 additions & 0 deletions src/scanspec/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def cli(ctx, log_level: str):

# if no command is supplied, print the help message
if ctx.invoked_subcommand is None:
# We need to prove that cli has been converted to a command
# by the click decorator to keep pyright happy.
assert isinstance(cli, click.Command)
click.echo(cli.get_help(ctx))


Expand Down
12 changes: 6 additions & 6 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@

StrictConfig: ConfigDict = {"extra": "forbid"}

C = TypeVar("C")
T = TypeVar("T", type, Callable)


def discriminated_union_of_subclasses(
super_cls: type,
super_cls: type[C],
discriminator: str = "type",
) -> type:
) -> type[C]:
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
Expand Down Expand Up @@ -137,9 +140,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
return super_cls


T = TypeVar("T", type, Callable)


def uses_tagged_union(cls_or_func: T) -> T:
"""
T = TypeVar("T", type, Callable)
Expand Down Expand Up @@ -562,7 +562,7 @@ def __init__(
self.lengths = np.array([len(f) for f in stack])
#: Index of the end frame, one more than the last index that will be
#: produced
self.end_index = np.prod(self.lengths)
self.end_index = int(np.prod(self.lengths))
if num is not None and start + num < self.end_index:
self.end_index = start + num

Expand Down
18 changes: 12 additions & 6 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs):
# Added here because of https://github.com/matplotlib/matplotlib/issues/21688
def do_3d_projection(self, renderer=None):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))

return np.min(zs)
Expand Down Expand Up @@ -109,11 +109,17 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
# Setup axes
if ndims > 2:
plt.figure(figsize=(6, 6))
plt_axes: Axes3D = plt.axes(projection="3d")
plt_axes = plt.axes(projection="3d")
plt_axes.grid(False)
plt_axes.set_zlabel(axes[-3])
plt_axes.set_ylabel(axes[-2])
plt_axes.view_init(elev=15)
if isinstance(plt_axes, Axes3D):
plt_axes.set_zlabel(axes[-3])
plt_axes.set_ylabel(axes[-2])
plt_axes.view_init(elev=15)
else:
raise TypeError(
"Expected matplotlib to create an Axes3D object, "
f"instead got: {plt_axes}"
)
elif ndims == 2:
plt.figure(figsize=(6, 6))
plt_axes = plt.axes()
Expand Down Expand Up @@ -208,7 +214,7 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
_plot_arrow(plt_axes, arrow_arr)
elif splines:
# Plot the starting arrow in the direction of the first point
arrow_arr = [(2 * a[0] - a[1], a[0]) for a in splines[0]]
arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]]
_plot_arrow(plt_axes, arrow_arr)
else:
# First point isn't moving, put a right caret marker
Expand Down
3 changes: 1 addition & 2 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from pydantic import Field
from pydantic.dataclasses import dataclass

Expand Down Expand Up @@ -127,7 +126,7 @@ class SmallestStepResponse:
@app.post("/valid", response_model=ValidResponse)
def valid(
spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]),
) -> ValidResponse | JSONResponse:
) -> ValidResponse:
"""Validate wether a ScanSpec can produce a viable scan.
Args:
Expand Down
3 changes: 2 additions & 1 deletion src/scanspec/sphinxext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

from docutils.statemachine import StringList
from matplotlib.sphinxext import plot_directive

from . import __version__
Expand All @@ -25,7 +26,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective):
"""Runs `plot_spec` on the ``spec`` definied in the content."""

def run(self):
self.content = (
self.content = StringList(
["# Example Spec", "", "from scanspec.plot import plot_spec"]
+ [str(x) for x in self.content]
+ ["plot_spec(spec)"]
Expand Down

0 comments on commit 0f7329f

Please sign in to comment.