Skip to content

Commit

Permalink
🎨 Unify typecheck and shapecheck (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
davnn committed Oct 13, 2023
1 parent b3c02c2 commit 3fb98d5
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 78 deletions.
32 changes: 21 additions & 11 deletions Taskfile.yml
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
version: '3'
version: "3"

env:
PROJECT: safecheck
TESTS: tests

vars:
CONDA: micromamba
TESTS: tests

tasks:
env-create:
cmd: $CONDA env create -n $PROJECT --file env.yml --yes
cmd: $CONDA env create -n {{.PROJECT}} --file env.yml --yes

env-remove:
cmd: $CONDA env remove -n $PROJECT --yes
cmd: $CONDA env remove -n {{.PROJECT}} --yes

poetry-install:
cmd: curl -sSL https://install.python-poetry.org | python -
cmds:
- curl -sSL https://install.python-poetry.org | python -
- poetry config virtualenvs.create false

poetry-remove:
cmd: curl -sSL https://install.python-poetry.org | python - --uninstall

poetry-update-dev:
cmd: poetry add pytest@latest pytest-html@latest hypothesis@latest coverage@latest pytest-cov@latest pytest-benchmark@latest coverage-badge@latest ruff@latest pre-commit@latest black@latest pyright@latest typing-extensions@latest bandit@latest safety@latest numpy@latest torch@latest jax@latest -G dev

poetry-use:
cmds:
- |
$CONDA activate {{.PROJECT}}
poetry env use system
pre-commit-install:
cmd: poetry run pre-commit install

Expand All @@ -29,26 +39,26 @@ tasks:

format:
cmds:
- poetry run ruff check $PROJECT --fix
- poetry run black --config pyproject.toml $PROJECT $TESTS
- poetry run ruff check {{.PROJECT}} --fix
- poetry run black --config pyproject.toml {{.PROJECT}} {{.TESTS}}

test:
cmds:
- poetry run pytest -c pyproject.toml --cov-report=html --cov=$PROJECT $TESTS/
- poetry run pytest -c pyproject.toml --cov-report=html --cov={{.PROJECT}} {{.TESTS}}/
- poetry run coverage-badge -o assets/coverage.svg -f

lint:
cmds:
- poetry run ruff check $PROJECT
- poetry run black --diff --check --config pyproject.toml $PROJECT $TESTS
- poetry run ruff check {{.PROJECT}}
- poetry run black --diff --check --config pyproject.toml {{.PROJECT}} {{.TESTS}}

typing:
cmd: poetry run pyright

safety:
cmds:
- poetry run safety check --full-report
- poetry run bandit -ll --recursive $PROJECT $TESTS
- poetry run bandit -ll --recursive {{.PROJECT}} {{.TESTS}}

check:
cmds:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "safecheck"
version = "0.0.3"
version = "0.1.0"
description = "Utilities for typechecking, shapechecking and dispatch."
readme = "README.md"
authors = ["David Muhr <muhrdavid+github@gmail.com>"]
Expand Down Expand Up @@ -99,6 +99,8 @@ force-exclude = true
ignore = [
"D203", # one blank line required before class docstring
"D213", # multi line summary should start at second line
"ANN101", # missing type annotation for `self` in method
"B905", # `zip()` without an explicit `strict=` parameter
]

[tool.ruff.isort]
Expand Down
21 changes: 9 additions & 12 deletions safecheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
packages. For example, it should be easily possible to switch from beartype to
typeguard for runtime type checking.
"""
# re-export everything necessary from beartype, never use beartype itself
from beartype import (
beartype as typecheck,
)
from beartype._data.hint.datahinttyping import BeartypeableT, BeartypeReturn
# re-export everything necessary from beartype, never use beartype itself.
from beartype.door import (
die_if_unbearable as assert_instance,
is_bearable as is_instance,
Expand Down Expand Up @@ -49,7 +45,6 @@
UInt16, # type: ignore[reportGeneralTypeIssues]
UInt32, # type: ignore[reportGeneralTypeIssues]
UInt64, # type: ignore[reportGeneralTypeIssues]
jaxtyped as _shapecheck,
)

# re-export everything necessary from plum, never use plum itself.
Expand All @@ -65,10 +60,17 @@
promote,
)

from ._protocol import (
implements,
protocol,
)
from ._typecheck import typecheck

__all__ = [
# decorators (runtime type-checking)
"typecheck",
"shapecheck",
"implements",
"protocol",
# introspection
"is_instance", # like "isinstance(...)"
"assert_instance", # like "assert isinstance(...)"
Expand Down Expand Up @@ -136,11 +138,6 @@
...


def shapecheck(fn: BeartypeableT) -> BeartypeReturn:
"""``shapecheck`` implies typecheck."""
return _shapecheck(typecheck(fn))


def get_version() -> str:
"""Return the package version or "unknown" if no version can be found."""
from importlib import metadata
Expand Down
125 changes: 125 additions & 0 deletions safecheck/_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from collections.abc import Callable
from inspect import Parameter, _empty, signature # type: ignore[reportPrivateUsage]
from typing import Any

from ._typecheck import typecheck

__all__ = [
"implements",
"protocol",
]

CallableAny = Callable[..., Any]


class FunctionProtocol:
def __init__(
self,
return_annotation: type,
parameters: list[Parameter],
) -> None:
super().__init__()
self.return_annotation = return_annotation
self.parameters = parameters


class InvalidProtocolError(Exception):
def __init__(self, msg: str) -> None:
super().__init__(msg)


class ProtocolImplementationError(Exception):
def __init__(self, msg: str) -> None:
super().__init__(msg)


def protocol(func: CallableAny) -> FunctionProtocol:
sig = signature(func)
params = list(sig.parameters.values())
if sig.return_annotation is _empty:
msg = "Cannot construct a protocol with missing return type annotation."
raise InvalidProtocolError(msg)

for parameter in params:
if parameter.annotation is _empty:
msg = f"Cannot construct a protocol with missing type annotation, found {parameter}."
raise InvalidProtocolError(msg)

if parameter.default is not _empty:
msg = f"Unexpected default value found in protocol definition, found {parameter}."
raise InvalidProtocolError(msg)

return FunctionProtocol(sig.return_annotation, params)


def implements(protocol: FunctionProtocol) -> Callable[[CallableAny], CallableAny]:
if not isinstance(protocol, FunctionProtocol): # type: ignore[reportUnnecessaryIsInstance]
msg = (
f"A protocol implementation using `implements` expects a FunctionProtocol parameter, "
f"but found {type(protocol)}. Did you use `@implements` without parameters? Use "
f"@implements(protocol) instead."
)
raise ProtocolImplementationError(msg)

def decorator(func: CallableAny) -> CallableAny:
sig = signature(func)
size = len(protocol.parameters)

# check if the updated return annotation matches the protocol return annotation
return_annotation = protocol.return_annotation if sig.return_annotation is _empty else sig.return_annotation
if return_annotation != (proto_return := protocol.return_annotation):
msg = (
f"Cannot implement a protocol without matching return types, but found return type "
f"{return_annotation} for a protocol with return type {proto_return}."
)
raise ProtocolImplementationError(msg)

# check if the updated shared parameters exactly match the protocol parameters
sig_params = list(sig.parameters.values())
shared_params = update_annotations(protocol.parameters, sig_params)
if strip_defaults(shared_params[:size]) != (proto_params := protocol.parameters):
msg = (
f"Cannot implement a protocol without matching parameter types, but found parameters "
f"{sig_params} for a protocol with parameters {proto_params}."
)
raise ProtocolImplementationError(msg)

# check if the other parameters all have default values
other_params = sig_params[size:]
if any(p.default is _empty for p in other_params):
msg = (
f"Cannot implement a protocol that requires substitution, if any parameters not "
f"included in the protocol do not have a default value, found: {other_params}."
)
raise ProtocolImplementationError(msg)

# replace the function signature
final_parameters = shared_params + other_params
func.__signature__ = sig.replace( # type: ignore[reportFunctionMemberAccess]
parameters=final_parameters,
return_annotation=return_annotation,
)

# replace the function annotations (used by runtime type checker)
param_annotations = {p.name: p.annotation for p in final_parameters if p.annotation is not _empty}
return_annotation = {} if return_annotation is _empty else {"return": return_annotation}
func.__annotations__ = param_annotations | return_annotation
return typecheck(func)

return decorator


def strip_defaults(params: list[Parameter]) -> list[Parameter]:
params = params.copy()
"""Strip the default values for the parameters in the list, which are irrelevant for the comparison."""
for param in params:
setattr(param, "_default", _empty) # noqa[B010]

return params


def update_annotations(reference: list[Parameter], params: list[Parameter]) -> list[Parameter]:
for ref, param in zip(reference, params):
if param.annotation is _empty:
setattr(param, "_annotation", ref.annotation) # noqa[B010]
return params
25 changes: 25 additions & 0 deletions safecheck/_typecheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from beartype import beartype as _typecheck
from beartype._data.hint.datahinttyping import BeartypeableT, BeartypeReturn
from jaxtyping import jaxtyped as _shapecheck


def typecheck(fn: BeartypeableT) -> BeartypeReturn:
"""Typecheck a function without jaxtyping annotations, otherwise additionally shapecheck the function.
:param fn: Any function or method.
:return: Typechecked function or method.
:raises: BeartypeException if a call to the function does not satisfy the typecheck.
"""
# check if there is any annotation requiring a shapecheck, i.e. any jaxtyping annotation that is not "..."
# this check is significantly slower than the string-based check implemented below (~+50%), but this should
# only be relevant in tight loops.
# for annotation in fn.__annotations__.values():
# if getattr(annotation, "dim_str", "") != "...":

# simply check if there is any mention of jaxtyping in the annotations, this adds barely any overhead to
# a base call of beartype's @beartype
if "jaxtyping" in str(fn.__annotations__):
# shapecheck implies typecheck
return _shapecheck(_typecheck(fn))

return _typecheck(fn)
68 changes: 68 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy
from beartype import beartype

from safecheck import *

args = list(range(10))
args_shaped = numpy.random.randn(10, 100) # dim0=number of args, dim1=size of arg


def decorate(f):
return f


def f(*_: int) -> None:
...


def f_shaped(*_: Shaped[NumpyArray, "n"]) -> None:
...


def test_no_overhead(benchmark):
benchmark(f, *args)


def test_no_overhead_shaped(benchmark):
benchmark(f_shaped, *args_shaped)


def test_minimal_overhead(benchmark):
benchmark(decorate(f), *args)


def test_minimal_overhead_shaped(benchmark):
benchmark(decorate(f_shaped), *args_shaped)


def test_beartype(benchmark):
benchmark(beartype(f), *args)


def test_beartype_shaped(benchmark):
benchmark(beartype(f_shaped), *args_shaped)


def test_typecheck(benchmark):
benchmark(typecheck(f), *args)


def test_typecheck_shaped(benchmark):
benchmark(typecheck(f_shaped), *args_shaped)


def test_dispatch(benchmark):
dispatch = Dispatcher()
benchmark(dispatch(f), *args)


def test_dispatch_shaped(benchmark):
benchmark(dispatch(f_shaped), *args_shaped)


def test_protocol(benchmark):
benchmark(implements(protocol(f))(f), *args)


def test_protocol_shaped(benchmark):
benchmark(implements(protocol(f_shaped))(f_shaped), *args_shaped)
Loading

0 comments on commit 3fb98d5

Please sign in to comment.