From 3fb98d543f18e3213898ac761810b8bbc29e8f8a Mon Sep 17 00:00:00 2001 From: David Muhr Date: Fri, 13 Oct 2023 14:32:03 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Unify=20typecheck=20and=20shapec?= =?UTF-8?q?heck=20(#3)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Taskfile.yml | 32 ++++++---- pyproject.toml | 4 +- safecheck/__init__.py | 21 +++---- safecheck/_protocol.py | 125 ++++++++++++++++++++++++++++++++++++++++ safecheck/_typecheck.py | 25 ++++++++ tests/test_benchmark.py | 68 ++++++++++++++++++++++ tests/test_protocol.py | 112 +++++++++++++++++++++++++++++++++++ tests/test_safecheck.py | 59 ++----------------- 8 files changed, 368 insertions(+), 78 deletions(-) create mode 100644 safecheck/_protocol.py create mode 100644 safecheck/_typecheck.py create mode 100644 tests/test_benchmark.py create mode 100644 tests/test_protocol.py diff --git a/Taskfile.yml b/Taskfile.yml index f748d45..d006632 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -1,19 +1,23 @@ -version: '3' +version: "3" env: PROJECT: safecheck - TESTS: tests + +vars: CONDA: micromamba + TESTS: tests tasks: env-create: - cmd: $CONDA env create -n $PROJECT --file env.yml --yes + cmd: $CONDA env create -n {{.PROJECT}} --file env.yml --yes env-remove: - cmd: $CONDA env remove -n $PROJECT --yes + cmd: $CONDA env remove -n {{.PROJECT}} --yes poetry-install: - cmd: curl -sSL https://install.python-poetry.org | python - + cmds: + - curl -sSL https://install.python-poetry.org | python - + - poetry config virtualenvs.create false poetry-remove: cmd: curl -sSL https://install.python-poetry.org | python - --uninstall @@ -21,6 +25,12 @@ tasks: poetry-update-dev: cmd: poetry add pytest@latest pytest-html@latest hypothesis@latest coverage@latest pytest-cov@latest pytest-benchmark@latest coverage-badge@latest ruff@latest pre-commit@latest black@latest pyright@latest typing-extensions@latest bandit@latest safety@latest numpy@latest torch@latest jax@latest -G dev + poetry-use: + cmds: + - | + $CONDA activate {{.PROJECT}} + poetry env use system + pre-commit-install: cmd: poetry run pre-commit install @@ -29,18 +39,18 @@ tasks: format: cmds: - - poetry run ruff check $PROJECT --fix - - poetry run black --config pyproject.toml $PROJECT $TESTS + - poetry run ruff check {{.PROJECT}} --fix + - poetry run black --config pyproject.toml {{.PROJECT}} {{.TESTS}} test: cmds: - - poetry run pytest -c pyproject.toml --cov-report=html --cov=$PROJECT $TESTS/ + - poetry run pytest -c pyproject.toml --cov-report=html --cov={{.PROJECT}} {{.TESTS}}/ - poetry run coverage-badge -o assets/coverage.svg -f lint: cmds: - - poetry run ruff check $PROJECT - - poetry run black --diff --check --config pyproject.toml $PROJECT $TESTS + - poetry run ruff check {{.PROJECT}} + - poetry run black --diff --check --config pyproject.toml {{.PROJECT}} {{.TESTS}} typing: cmd: poetry run pyright @@ -48,7 +58,7 @@ tasks: safety: cmds: - poetry run safety check --full-report - - poetry run bandit -ll --recursive $PROJECT $TESTS + - poetry run bandit -ll --recursive {{.PROJECT}} {{.TESTS}} check: cmds: diff --git a/pyproject.toml b/pyproject.toml index 52c859e..c8aa95c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "safecheck" -version = "0.0.3" +version = "0.1.0" description = "Utilities for typechecking, shapechecking and dispatch." readme = "README.md" authors = ["David Muhr "] @@ -99,6 +99,8 @@ force-exclude = true ignore = [ "D203", # one blank line required before class docstring "D213", # multi line summary should start at second line + "ANN101", # missing type annotation for `self` in method + "B905", # `zip()` without an explicit `strict=` parameter ] [tool.ruff.isort] diff --git a/safecheck/__init__.py b/safecheck/__init__.py index fd4a49f..89f131d 100644 --- a/safecheck/__init__.py +++ b/safecheck/__init__.py @@ -8,11 +8,7 @@ packages. For example, it should be easily possible to switch from beartype to typeguard for runtime type checking. """ -# re-export everything necessary from beartype, never use beartype itself -from beartype import ( - beartype as typecheck, -) -from beartype._data.hint.datahinttyping import BeartypeableT, BeartypeReturn +# re-export everything necessary from beartype, never use beartype itself. from beartype.door import ( die_if_unbearable as assert_instance, is_bearable as is_instance, @@ -49,7 +45,6 @@ UInt16, # type: ignore[reportGeneralTypeIssues] UInt32, # type: ignore[reportGeneralTypeIssues] UInt64, # type: ignore[reportGeneralTypeIssues] - jaxtyped as _shapecheck, ) # re-export everything necessary from plum, never use plum itself. @@ -65,10 +60,17 @@ promote, ) +from ._protocol import ( + implements, + protocol, +) +from ._typecheck import typecheck + __all__ = [ # decorators (runtime type-checking) "typecheck", - "shapecheck", + "implements", + "protocol", # introspection "is_instance", # like "isinstance(...)" "assert_instance", # like "assert isinstance(...)" @@ -136,11 +138,6 @@ ... -def shapecheck(fn: BeartypeableT) -> BeartypeReturn: - """``shapecheck`` implies typecheck.""" - return _shapecheck(typecheck(fn)) - - def get_version() -> str: """Return the package version or "unknown" if no version can be found.""" from importlib import metadata diff --git a/safecheck/_protocol.py b/safecheck/_protocol.py new file mode 100644 index 0000000..b5fe9a5 --- /dev/null +++ b/safecheck/_protocol.py @@ -0,0 +1,125 @@ +from collections.abc import Callable +from inspect import Parameter, _empty, signature # type: ignore[reportPrivateUsage] +from typing import Any + +from ._typecheck import typecheck + +__all__ = [ + "implements", + "protocol", +] + +CallableAny = Callable[..., Any] + + +class FunctionProtocol: + def __init__( + self, + return_annotation: type, + parameters: list[Parameter], + ) -> None: + super().__init__() + self.return_annotation = return_annotation + self.parameters = parameters + + +class InvalidProtocolError(Exception): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class ProtocolImplementationError(Exception): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +def protocol(func: CallableAny) -> FunctionProtocol: + sig = signature(func) + params = list(sig.parameters.values()) + if sig.return_annotation is _empty: + msg = "Cannot construct a protocol with missing return type annotation." + raise InvalidProtocolError(msg) + + for parameter in params: + if parameter.annotation is _empty: + msg = f"Cannot construct a protocol with missing type annotation, found {parameter}." + raise InvalidProtocolError(msg) + + if parameter.default is not _empty: + msg = f"Unexpected default value found in protocol definition, found {parameter}." + raise InvalidProtocolError(msg) + + return FunctionProtocol(sig.return_annotation, params) + + +def implements(protocol: FunctionProtocol) -> Callable[[CallableAny], CallableAny]: + if not isinstance(protocol, FunctionProtocol): # type: ignore[reportUnnecessaryIsInstance] + msg = ( + f"A protocol implementation using `implements` expects a FunctionProtocol parameter, " + f"but found {type(protocol)}. Did you use `@implements` without parameters? Use " + f"@implements(protocol) instead." + ) + raise ProtocolImplementationError(msg) + + def decorator(func: CallableAny) -> CallableAny: + sig = signature(func) + size = len(protocol.parameters) + + # check if the updated return annotation matches the protocol return annotation + return_annotation = protocol.return_annotation if sig.return_annotation is _empty else sig.return_annotation + if return_annotation != (proto_return := protocol.return_annotation): + msg = ( + f"Cannot implement a protocol without matching return types, but found return type " + f"{return_annotation} for a protocol with return type {proto_return}." + ) + raise ProtocolImplementationError(msg) + + # check if the updated shared parameters exactly match the protocol parameters + sig_params = list(sig.parameters.values()) + shared_params = update_annotations(protocol.parameters, sig_params) + if strip_defaults(shared_params[:size]) != (proto_params := protocol.parameters): + msg = ( + f"Cannot implement a protocol without matching parameter types, but found parameters " + f"{sig_params} for a protocol with parameters {proto_params}." + ) + raise ProtocolImplementationError(msg) + + # check if the other parameters all have default values + other_params = sig_params[size:] + if any(p.default is _empty for p in other_params): + msg = ( + f"Cannot implement a protocol that requires substitution, if any parameters not " + f"included in the protocol do not have a default value, found: {other_params}." + ) + raise ProtocolImplementationError(msg) + + # replace the function signature + final_parameters = shared_params + other_params + func.__signature__ = sig.replace( # type: ignore[reportFunctionMemberAccess] + parameters=final_parameters, + return_annotation=return_annotation, + ) + + # replace the function annotations (used by runtime type checker) + param_annotations = {p.name: p.annotation for p in final_parameters if p.annotation is not _empty} + return_annotation = {} if return_annotation is _empty else {"return": return_annotation} + func.__annotations__ = param_annotations | return_annotation + return typecheck(func) + + return decorator + + +def strip_defaults(params: list[Parameter]) -> list[Parameter]: + params = params.copy() + """Strip the default values for the parameters in the list, which are irrelevant for the comparison.""" + for param in params: + setattr(param, "_default", _empty) # noqa[B010] + + return params + + +def update_annotations(reference: list[Parameter], params: list[Parameter]) -> list[Parameter]: + for ref, param in zip(reference, params): + if param.annotation is _empty: + setattr(param, "_annotation", ref.annotation) # noqa[B010] + return params diff --git a/safecheck/_typecheck.py b/safecheck/_typecheck.py new file mode 100644 index 0000000..af5de6c --- /dev/null +++ b/safecheck/_typecheck.py @@ -0,0 +1,25 @@ +from beartype import beartype as _typecheck +from beartype._data.hint.datahinttyping import BeartypeableT, BeartypeReturn +from jaxtyping import jaxtyped as _shapecheck + + +def typecheck(fn: BeartypeableT) -> BeartypeReturn: + """Typecheck a function without jaxtyping annotations, otherwise additionally shapecheck the function. + + :param fn: Any function or method. + :return: Typechecked function or method. + :raises: BeartypeException if a call to the function does not satisfy the typecheck. + """ + # check if there is any annotation requiring a shapecheck, i.e. any jaxtyping annotation that is not "..." + # this check is significantly slower than the string-based check implemented below (~+50%), but this should + # only be relevant in tight loops. + # for annotation in fn.__annotations__.values(): + # if getattr(annotation, "dim_str", "") != "...": + + # simply check if there is any mention of jaxtyping in the annotations, this adds barely any overhead to + # a base call of beartype's @beartype + if "jaxtyping" in str(fn.__annotations__): + # shapecheck implies typecheck + return _shapecheck(_typecheck(fn)) + + return _typecheck(fn) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 0000000..692c1fc --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,68 @@ +import numpy +from beartype import beartype + +from safecheck import * + +args = list(range(10)) +args_shaped = numpy.random.randn(10, 100) # dim0=number of args, dim1=size of arg + + +def decorate(f): + return f + + +def f(*_: int) -> None: + ... + + +def f_shaped(*_: Shaped[NumpyArray, "n"]) -> None: + ... + + +def test_no_overhead(benchmark): + benchmark(f, *args) + + +def test_no_overhead_shaped(benchmark): + benchmark(f_shaped, *args_shaped) + + +def test_minimal_overhead(benchmark): + benchmark(decorate(f), *args) + + +def test_minimal_overhead_shaped(benchmark): + benchmark(decorate(f_shaped), *args_shaped) + + +def test_beartype(benchmark): + benchmark(beartype(f), *args) + + +def test_beartype_shaped(benchmark): + benchmark(beartype(f_shaped), *args_shaped) + + +def test_typecheck(benchmark): + benchmark(typecheck(f), *args) + + +def test_typecheck_shaped(benchmark): + benchmark(typecheck(f_shaped), *args_shaped) + + +def test_dispatch(benchmark): + dispatch = Dispatcher() + benchmark(dispatch(f), *args) + + +def test_dispatch_shaped(benchmark): + benchmark(dispatch(f_shaped), *args_shaped) + + +def test_protocol(benchmark): + benchmark(implements(protocol(f))(f), *args) + + +def test_protocol_shaped(benchmark): + benchmark(implements(protocol(f_shaped))(f_shaped), *args_shaped) diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..4530931 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,112 @@ +from typing import Any + +import pytest + +from safecheck import * + + +def valid_protocol(): + def pos_kw(a: int, b: float) -> float: + return a + b + + def kw_only(*, a: int, b: float) -> float: + ... + + def pos_only(a: int, /, b: float) -> float: + ... + + def pos_varargs(*args: Any) -> Any: + ... + + def kw_varargs(**kwargs: Any) -> Any: + ... + + return [pos_kw, kw_only, pos_only, pos_varargs, kw_varargs] + + +def valid_implementation(): + # same specification as the protocol + def f_equal(a: int, b: float) -> float: + ... + + # implementation with default parameter + def f_default(a: int, b: float = 0.0) -> float: + ... + + # implementation without annotations + def f_empty(a, b): + ... + + # implementation with partial annotations + def f_partial(a: int, b) -> float: + ... + + return [f_equal, f_default, f_empty, f_partial] + + +@pytest.mark.parametrize("f", valid_implementation()) +def test_valid_implementation(f): + @protocol + def p(a: int, b: float) -> float: + ... + + implements(p)(f) + + +def invalid_protocol(): + def missing_parameter_annotation(a, b: float) -> float: + ... + + def missing_return_annotation(a: int, b: float): + ... + + def unexpected_default_value(a: int, b: float = 0.0) -> float: + ... + + return [missing_parameter_annotation, missing_return_annotation, unexpected_default_value] + + +@pytest.mark.parametrize("f", invalid_protocol()) +def test_invalid_protocol(f): + from safecheck._protocol import InvalidProtocolError + + with pytest.raises(InvalidProtocolError): + protocol(f) + + +def invalid_implementation(): + def wrong_parameter(a: int, b: int) -> float: + ... + + def wrong_return(a: int, b: float) -> int: + ... + + def wrong_kind_kw(a, *, b): + ... + + def wrong_kind_pos(a, /, b): + ... + + def additional_missing_default(a, b, c): + ... + + return [wrong_parameter, wrong_return, wrong_kind_kw, wrong_kind_pos, additional_missing_default] + + +@pytest.mark.parametrize("f", invalid_implementation()) +def test_invalid_implementation(f): + from safecheck._protocol import ProtocolImplementationError + + @protocol + def p(a: int, b: float) -> float: + ... + + with pytest.raises(ProtocolImplementationError): + implements(p)(f) + + +def test_invalid_implementation_call(): + from safecheck._protocol import ProtocolImplementationError + + with pytest.raises(ProtocolImplementationError): + implements(None) diff --git a/tests/test_safecheck.py b/tests/test_safecheck.py index 78a1ee2..f88cf62 100644 --- a/tests/test_safecheck.py +++ b/tests/test_safecheck.py @@ -1,5 +1,3 @@ -import time - import pytest from beartype.roar import BeartypeCallHintParamViolation @@ -9,9 +7,6 @@ from safecheck import * -benchmark_sleep = 0.001 # time in s per benchmark trial run -benchmark_args = numpy.random.randn(10, 100) # dim0=number of args, dim1=size of arg - np_array = numpy.random.randint(low=0, high=1, size=(1,)) torch_array = torch.randint(low=0, high=1, size=(1,)) jax_array = jax.random.randint(key=jax.random.PRNGKey(0), minval=0, maxval=1, shape=(1,)) @@ -32,7 +27,7 @@ @pytest.mark.parametrize("array_type", array_types.keys()) def test_array_type(array_type): - @shapecheck + @typecheck def f(array: Int[array_type, "..."]) -> Int[array_type, "..."]: return array @@ -51,7 +46,7 @@ def f(array: Int[array_type, "..."]) -> Int[array_type, "..."]: @pytest.mark.parametrize("array_type", data_types.keys()) @pytest.mark.parametrize("data_type", next(iter(data_types.values())).keys()) def test_data_type(array_type, data_type): - @shapecheck + @typecheck def f(array: data_type[array_type, "..."]) -> data_type[array_type, "..."]: return array @@ -114,17 +109,17 @@ def test_array_type_dispatch_with_shapecheck(array_type): dispatch = Dispatcher() @dispatch - @shapecheck + @typecheck def f(_: Shaped[NumpyArray, "..."]) -> str: return "numpy" @dispatch - @shapecheck + @typecheck def f(_: Shaped[TorchArray, "..."]) -> str: return "torch" @dispatch - @shapecheck + @typecheck def f(_: Shaped[JaxArray, "..."]) -> str: return "jax" @@ -173,47 +168,3 @@ def f(_: Bool[JaxArray, "..."]) -> str: return "jax_bool" assert data_types_str[array_type][data_type] == f(data_types[array_type][data_type]) - - -def test_no_overhead(benchmark): - def f(*_): - time.sleep(benchmark_sleep) - - benchmark(f, *benchmark_args) - - -def test_minimal_decorator(benchmark): - def decorate(f): - return f - - @decorate - def f(*_): - time.sleep(benchmark_sleep) - - benchmark(f, *benchmark_args) - - -def test_typecheck_decorator(benchmark): - @typecheck - def f(*_: NumpyArray): - time.sleep(benchmark_sleep) - - benchmark(f, *benchmark_args) - - -def test_shapecheck_decorator(benchmark): - @shapecheck - def f(*_: Shaped[NumpyArray, "n"]) -> None: - time.sleep(benchmark_sleep) - - benchmark(f, *benchmark_args) - - -def test_dispatch_decorator(benchmark): - dispatch = Dispatcher() - - @dispatch - def f(*_: Shaped[NumpyArray, "n"]): - time.sleep(benchmark_sleep) - - benchmark(f, *benchmark_args)