Make functorch optional (#20)
marvinfriede committed May 23, 2023
1 parent 22df391 commit f26ec4d
Showing 4 changed files with 53 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = tad_dftd3
version = 0.1.1
version = attr: tad_dftd3.__version__.__version__
description = Torch autodiff DFT-D3 implementation
long_description = file: README.rst
long_description_content_type = text/x-rst
10 changes: 10 additions & 0 deletions src/tad_dftd3/__version__.py
@@ -0,0 +1,10 @@
"""
Version module for tad_dftd3.
"""
import torch

__version__ = "0.1.1"

__torch_version__ = tuple(
int(x) for x in torch.__version__.split("+", maxsplit=1)[0].split(".")
)
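For context, the version gate introduced elsewhere in this commit relies on the parsing above stripping any local build suffix (e.g. "+cu117") from torch.__version__ before comparing tuples. A minimal sketch of that behavior, using a made-up version string:

# Illustrative only; "1.13.1+cu117" is a hypothetical torch.__version__ value.
version_string = "1.13.1+cu117"
parsed = tuple(int(x) for x in version_string.split("+", maxsplit=1)[0].split("."))
assert parsed == (1, 13, 1)
assert parsed < (2, 0, 0)  # true for any pre-2.0 release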
41 changes: 29 additions & 12 deletions src/tad_dftd3/util.py
@@ -21,6 +21,7 @@
"""
import torch

from .__version__ import __torch_version__
from .typing import (
Any,
Callable,
@@ -33,6 +34,14 @@
Union,
)

if __torch_version__ < (2, 0, 0): # pragma: no cover
try:
from functorch import jacrev # type: ignore
except ModuleNotFoundError:
jacrev = None
else: # pragma: no cover
from torch.func import jacrev # type: ignore


def real_atoms(numbers: Tensor) -> Tensor:
return numbers != 0
@@ -62,11 +71,14 @@ def euclidean_dist_quadratic_expansion(x: Tensor, y: Tensor) -> Tensor:
While this is significantly faster than the "direct expansion" or
"broadcast" approach, it only works for euclidean (p=2) distances.
Additionally, it has issues with numerical stability (the diagonal slightly
deviates from zero for `x=y`). The numerical stability should not pose
deviates from zero for ``x=y``). The numerical stability should not pose
problems, since we must remove zeros anyway for batched calculations.
For more information, see `https://github.com/eth-cscs/PythonHPC/blob/master/numpy/03-euclidean-distance-matrix-numpy.ipynb`__ or
`https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065`__.
For more information, see \
`this Jupyter notebook <https://github.com/eth-cscs/PythonHPC/blob/master/\
numpy/03-euclidean-distance-matrix-numpy.ipynb>`__ or \
`this discussion thread on PyTorch forum <https://discuss.pytorch.org/t/\
efficient-distance-matrix-computation/9065>`__.
Parameters
----------
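As background (an illustrative sketch, not part of this commit): the quadratic expansion referenced in the docstring above rewrites the squared Euclidean distance as ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2, which is where the slight deviation from zero on the diagonal for identical inputs comes from. A self-contained example of the idea:

import torch

x = torch.randn(5, 3)
norms = (x * x).sum(-1)
# ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j>
dist_sq = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * x @ x.transpose(-1, -2)
dist = torch.clamp(dist_sq, min=0.0).sqrt()  # clamp guards against tiny negative values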
@@ -103,8 +115,8 @@ def cdist_direct_expansion(x: Tensor, y: Tensor, p: int = 2) -> Tensor:
"""
Computation of the Cartesian distance matrix.
This currently replaces the use of `torch.cdist`, which does not handle
zeros well and produces NaNs in the backward pass.
Contrary to `euclidean_dist_quadratic_expansion`, this function allows
arbitrary powers but is considerably slower.
Parameters
----------
@@ -231,22 +243,27 @@ def to_number(symbols: List[str]) -> Tensor:
)


def jacobian(f: Callable[..., Tensor], argnums: int) -> Any:
def jacobian(f: Callable[..., Tensor], argnums: int = 0) -> Any:
"""
Wrapper for Jacobian calculation.
Note
----
Only reverse mode AD is given through the custom autograd classes. Forward
mode requires implementation of `jvp`.
Parameters
----------
f : Callable[[Any], Tensor]
The function whose result is differentiated.
argnums : int, optional
The variable w.r.t. which the function is differentiated. Defaults to 0.
"""
return torch.func.jacrev(f, argnums=argnums) # type: ignore
if jacrev is None: # pragma: no cover
raise ModuleNotFoundError("PyTorch's `functorch` is not installed.")

return jacrev(f, argnums=argnums) # type: ignore
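A minimal usage sketch of this wrapper (not part of the diff), assuming either functorch or torch >= 2.0 is installed; the function f below is made up for illustration:

import torch
from tad_dftd3.util import jacobian

def f(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])
jac = jacobian(f, argnums=0)(x)  # reverse-mode Jacobian, here equal to 2 * x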


def hessian(
f: Callable[..., Tensor],
inputs: Tuple[Any, ...],
argnums: int = 0,
argnums: int,
is_batched: bool = False,
) -> Tensor:
"""
13 changes: 13 additions & 0 deletions tests/test_grad/test_hessian.py
@@ -7,16 +7,26 @@
import torch

from tad_dftd3 import dftd3, util
from tad_dftd3.__version__ import __torch_version__
from tad_dftd3.typing import Tensor

from ..samples import samples
from ..utils import reshape_fortran

if __torch_version__ < (2, 0, 0):
try:
from functorch import jacrev # type: ignore
except ModuleNotFoundError:
jacrev = None
else:
from torch.func import jacrev # type: ignore

sample_list = ["LiH", "SiH4", "PbH4-BiH3", "MB16_43_01"]

tol = 1e-8


@pytest.mark.skipif(jacrev is None, reason="Hessian tests require functorch")
def test_fail() -> None:
sample = samples["LiH"]
numbers = sample["numbers"]
@@ -28,6 +38,7 @@ def test_fail() -> None:
util.hessian(dftd3, (numbers, positions, param), argnums=2)


@pytest.mark.skipif(jacrev is None, reason="Hessian tests require functorch")
def test_zeros() -> None:
d = torch.randn(2, 3, requires_grad=True)

@@ -38,6 +49,7 @@ def dummy(x: Tensor) -> Tensor:
assert pytest.approx(torch.zeros([*d.shape, *d.shape])) == hess.detach()


@pytest.mark.skipif(jacrev is None, reason="Hessian tests require functorch")
@pytest.mark.parametrize("dtype", [torch.double])
@pytest.mark.parametrize("name", sample_list)
def test_single(dtype: torch.dtype, name: str) -> None:
@@ -69,6 +81,7 @@ def test_single(dtype: torch.dtype, name: str) -> None:


# TODO: Figure out batched Hessian computation
@pytest.mark.skipif(jacrev is None, reason="Hessian tests require functorch")
@pytest.mark.parametrize("dtype", [torch.double])
@pytest.mark.parametrize("name1", ["LiH"])
@pytest.mark.parametrize("name2", sample_list)
