diff --git a/examples/single.py b/examples/single.py index d2fb93c..24c2c3f 100644 --- a/examples/single.py +++ b/examples/single.py @@ -5,7 +5,7 @@ import tad_dftd3 as d3 numbers = mctc.convert.symbol_to_number(symbols="C C C C N C S H H H H H".split()) -positions = torch.Tensor( +positions = torch.tensor( [ [-2.56745685564671, -0.02509985979910, 0.00000000000000], [-1.39177582455797, +2.27696188880014, 0.00000000000000], diff --git a/src/tad_dftd3/disp.py b/src/tad_dftd3/disp.py index c928a37..e686267 100644 --- a/src/tad_dftd3/disp.py +++ b/src/tad_dftd3/disp.py @@ -52,7 +52,7 @@ >>> print(torch.sum(energy[0] - energy[1] - energy[2])) # energy in Hartree tensor(-0.0003964, dtype=torch.float64) """ -from typing import Dict, Optional +from __future__ import annotations import torch from tad_mctc import storch @@ -77,16 +77,17 @@ def dftd3( numbers: Tensor, positions: Tensor, - param: Dict[str, Tensor], + param: dict[str, Tensor], *, - ref: Optional[Reference] = None, - rcov: Optional[Tensor] = None, - rvdw: Optional[Tensor] = None, - r4r2: Optional[Tensor] = None, - cutoff: Optional[Tensor] = None, + ref: Reference | None = None, + rcov: Tensor | None = None, + rvdw: Tensor | None = None, + r4r2: Tensor | None = None, + cutoff: Tensor | None = None, counting_function: CountingFunction = ncoord.exp_count, weighting_function: WeightingFunction = model.gaussian_weight, damping_function: DampingFunction = rational_damping, + chunk_size: int | None = None, ) -> Tensor: """ Evaluate DFT-D3 dispersion energy for a batch of geometries. @@ -113,6 +114,9 @@ def dftd3( Function to calculate weight of individual reference systems. counting_function : Callable, optional Calculates counting value in range 0 to 1 for each atom pair. + chunk_size : int, optional + Chunk size for chunked computation of huge tensors that otherwise + create memory bottlenecks. 
Returns ------- @@ -142,7 +146,7 @@ def dftd3( numbers, positions, counting_function=counting_function, rcov=rcov ) weights = model.weight_references(numbers, cn, ref, weighting_function) - c6 = model.atomic_c6(numbers, weights, ref) + c6 = model.atomic_c6(numbers, weights, ref, chunk_size=chunk_size) return dispersion( numbers, @@ -159,12 +163,12 @@ def dftd3( def dispersion( numbers: Tensor, positions: Tensor, - param: Dict[str, Tensor], + param: dict[str, Tensor], c6: Tensor, - rvdw: Optional[Tensor] = None, - r4r2: Optional[Tensor] = None, + rvdw: Tensor | None = None, + r4r2: Tensor | None = None, damping_function: DampingFunction = rational_damping, - cutoff: Optional[Tensor] = None, + cutoff: Tensor | None = None, **kwargs: Any, ) -> Tensor: """ @@ -210,7 +214,7 @@ def dispersion( ) if torch.max(numbers) >= defaults.MAX_ELEMENT: raise ValueError( - f"No D3 parameters available for Z > {defaults.MAX_ELEMENT-1} " + f"No D3 parameters available for Z > {defaults.MAX_ELEMENT - 1} " f"({pse.Z2S[defaults.MAX_ELEMENT]})." ) @@ -232,7 +236,7 @@ def dispersion( def dispersion2( numbers: Tensor, positions: Tensor, - param: Dict[str, Tensor], + param: dict[str, Tensor], c6: Tensor, r4r2: Tensor, damping_function: DampingFunction, @@ -292,7 +296,7 @@ def dispersion2( def dispersion3( numbers: Tensor, positions: Tensor, - param: Dict[str, Tensor], + param: dict[str, Tensor], c6: Tensor, rvdw: Tensor, cutoff: Tensor, diff --git a/src/tad_dftd3/model/__init__.py b/src/tad_dftd3/model/__init__.py new file mode 100644 index 0000000..4a8ee3d --- /dev/null +++ b/src/tad_dftd3/model/__init__.py @@ -0,0 +1,45 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model: Dispersion model +======================= + +Implementation of D3 model to obtain atomic C6 coefficients for a given +geometry. + +Examples +-------- +>>> import torch +>>> import tad_dftd3 as d3 +>>> import tad_mctc as mctc +>>> numbers = mctc.convert.symbol_to_number(["O", "H", "H"]) +>>> positions = torch.Tensor([ +... [+0.00000000000000, +0.00000000000000, -0.73578586109551], +... [+1.44183152868459, +0.00000000000000, +0.36789293054775], +... [-1.44183152868459, +0.00000000000000, +0.36789293054775], +... ]) +>>> ref = d3.reference.Reference() +>>> rcov = d3.data.covalent_rad_d3[numbers] +>>> cn = mctc.ncoord.cn_d3(numbers, positions, rcov=rcov, counting_function=d3.ncoord.exp_count) +>>> weights = d3.model.weight_references(numbers, cn, ref, d3.model.gaussian_weight) +>>> c6 = d3.model.atomic_c6(numbers, weights, ref) +>>> torch.set_printoptions(precision=7) +>>> print(c6) +tensor([[10.4130471, 5.4368822, 5.4368822], + [ 5.4368822, 3.0930154, 3.0930154], + [ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64) +""" +from .c6 import * +from .weights import * diff --git a/src/tad_dftd3/model/c6.py b/src/tad_dftd3/model/c6.py new file mode 100644 index 0000000..e55c3de --- /dev/null +++ b/src/tad_dftd3/model/c6.py @@ -0,0 +1,393 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model: Atomic C6 +================ + +Computation of atomic C6 dispersion coefficients. + +Since this part can be the most memory-intensive, we provide a custom backward +function (i.e., analytical gradient) and options for chunking. +""" +from __future__ import annotations + +import torch +from tad_mctc._version import __tversion__ +from tad_mctc.math import einsum +from tad_mctc.tools import memory + +from ..reference import Reference +from ..typing import Callable, Protocol, Tensor + +__all__ = ["atomic_c6"] + + +# main entry point + + +def atomic_c6( + numbers: Tensor, + weights: Tensor, + reference: Reference, + chunk_size: None | int = None, +) -> Tensor: + """ + Calculate atomic dispersion coefficients. + + .. warning:: + + This function is the most memory intensive part of the calculation and + may require chunking for large systems. Without chunking, for example, + 2000 atoms (`numbers`) require the construction of a 1.5 GB tensor. + + Parameters + ---------- + numbers : Tensor + The atomic numbers of the atoms in the system of shape `(..., nat)`. + weights : Tensor + Weights of all reference systems of shape `(..., nat, 7)`. + reference : Reference + Reference systems for D3 model. Contains the reference C6 coefficients + of shape `(..., nelements, nelements, 7, 7)`. + + Returns + ------- + Tensor + Atomic dispersion coefficients of shape `(..., nat, nat)`. 
+ """ + _check_memory(numbers, weights, chunk_size) + + AtomicC6 = AtomicC6_V1 if __tversion__ < (2, 0, 0) else AtomicC6_V2 + res = AtomicC6.apply(numbers, weights, reference, chunk_size) + assert res is not None + return res + + +# helpers + + +def _check_memory( + numbers: Tensor, weights: Tensor, chunk_size: None | int = None +) -> None: + """ + Check memory usage for the construction of the C6 tensor. + Throw an error or warning for potential memory issues. + + Parameters + ---------- + numbers : Tensor + Atomic numbers of the atoms in the system. + weights : Tensor + Weights of all reference systems. + chunk_size : None | int, optional + Chunk size for the calculation of the C6 tensor. Defaults to `None`. + + Raises + ------ + MemoryError + If the estimated memory usage exceeds the total available memory. + """ + # Required memory for the C6 tensor + if chunk_size is None: + size = (numbers.shape[-1], numbers.shape[-1], 7, 7) + else: + size = (numbers.shape[-1], chunk_size, 7, 7) + mem = memory.memory_tensor(size, weights.dtype) + + # actual memory usage + free, total = memory.memory_device(numbers.device) + + if mem > total: + raise MemoryError( + f"Estimated memory usage exceeds total available memory: {mem:.2f} " + f"MB > {total:.2f} MB. During the construction of the C6 " + f"dispersion coefficients, a 4D tensor of shape {size} is required " + "for efficient tensor operations. To fit the tensor into memory, " + "try using a chunk size or reduce the chunk size via the optional " + "`chunk_size` argument." + ) + + if mem > free: + # pylint: disable=import-outside-toplevel + from warnings import warn + + warn( + "Estimated memory usage appears to exceed the available memory: " + f"{mem:.2f} MB > {free:.2f} MB. 
If the calculation fails due to " + "memory issues, consider reducing the chunk size via the optional " + "`chunk_size` argument.", + ResourceWarning, + ) + + +def _einsum(rc6: Tensor, weights_i: Tensor, weights_j: Tensor) -> Tensor: + """ + Perform an einsum operation for the atomic C6 coefficients. + + Parameters + ---------- + rc6 : Tensor + Reference C6 coefficients. + weights_i : Tensor + Weights of all reference systems. + weights_j : Tensor + Weights of all reference systems. + + Returns + ------- + Tensor + Atomic C6 dispersion coefficients. + """ + # The default einsum path is fastest if the large tensor comes first. + # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2) + return einsum( + "...ijab,...ia,...jb->...ij", + *(rc6, weights_i, weights_j), + optimize=[(0, 1), (0, 1)], + ) + + +# full and chunked versions + + +def _atomic_c6_full( + numbers: Tensor, + weights: Tensor, + reference: Reference, +) -> Tensor: + """ + Calculation of atomic dispersion coefficients without chunking. Might cause + memory issues for very large systems. + + Parameters + ---------- + numbers : Tensor + The atomic numbers of the atoms in the system of shape `(..., nat)`. + weights : Tensor + Weights of all reference systems of shape `(..., nat, 7)`. + reference : Reference + Reference systems for D3 model. Contains the reference C6 coefficients + of shape `(..., nelements, nelements, 7, 7)`. + + Returns + ------- + Tensor + Atomic dispersion coefficients of shape `(..., nat, nat)`. + """ + # NOTE: This old version creates large intermediate tensors and builds the + # full matrix before the sum reduction, which requires a lot of memory.
+ # + # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4) + # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1) + + rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] + return _einsum(rc6, weights, weights) + + +def _atomic_c6_chunked( + numbers: Tensor, + weights: Tensor, + reference: Reference, + chunk_size: int, +) -> Tensor: + """ + Chunked version of the calculation of atomic dispersion coefficients. + + Parameters + ---------- + numbers : Tensor + The atomic numbers of the atoms in the system of shape `(..., nat)`. + weights : Tensor + Weights of all reference systems of shape `(..., nat, 7)`. + reference : Reference + Reference systems for D3 model. Contains the reference C6 coefficients + of shape `(..., nelements, nelements, 7, 7)`. + chunk_size : int + Chunk size for the calculation of the C6 tensor. + + Returns + ------- + Tensor + Atomic dispersion coefficients of shape `(..., nat, nat)`. + """ + + nat = numbers.shape[-1] + c6_output = torch.zeros( + (*numbers.shape, nat), device=numbers.device, dtype=weights.dtype + ) + + for start in range(0, nat, chunk_size): + end = min(start + chunk_size, nat) + num_chunk = numbers[..., start:end] # (..., chunk_size) + + # Chunked indexing into reference.c6: (..., chunk_size, nat, 7, 7) + rc6_chunk = reference.c6[num_chunk.unsqueeze(-1), numbers.unsqueeze(-2)] + + # Also chunk the weights: (..., chunk_size, 7) + weights_chunk = weights[..., start:end, :] + + # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2) + contribution = _einsum(rc6_chunk, weights_chunk, weights) + + # Add contributions to the correct slice of the output tensor + c6_output[..., start:end, :] += contribution + + return c6_output + + +# custom autograd functions + + +class CTX(Protocol): + save_for_backward: Callable[[Tensor, Tensor], None] + saved_tensors: tuple[Tensor, Tensor] + chunk_size: None | int + reference: Reference + + +class AtomicC6Base(torch.autograd.Function): + """ + 
Base class for the version-specific autograd function for atomic C6. + Different PyTorch versions only require different `forward()` signatures. + """ + + @staticmethod + def backward(ctx: CTX, grad_out: Tensor) -> tuple[None, Tensor, None, None]: + numbers, weights = ctx.saved_tensors + chunk_size = ctx.chunk_size + ref = ctx.reference + + # We need the derivatives of the following expression: + # c_ij = ∑_{a,b} w_ia * w_jb * c_ijab + + ########################### + ### Non-chunked version ### + ########################### + + if chunk_size is None: + # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7) + rc6 = ref.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] + + # ∂c_ij/∂w_jb = ∑a w_ia * c_ijab + # (..., n1, n2, r1, r2) * (..., n1, r1) -> (..., n1, n2, r2) + g_jb = einsum("...ijab,...ia->...ijb", rc6, weights) + + # vjp: (..., n1, n2) * (..., n1, n2, r2) -> (..., n2, r2) + _gj = einsum("...ij,...ijb->...jb", grad_out, g_jb) + + # ∂c_ij/∂w_ia = ∑b w_jb * c_ijab + # (..., n1, n2, r1, r2) * (..., n2, r2) -> (..., n1, n2, r1) + g_ia = einsum("...ijab,...jb->...ija", rc6, weights) + + # vjp: (..., n1, n2) * (..., n1, n2, r1) -> (..., n1, r1) + _gi = einsum("...ij,...ija->...ia", grad_out, g_ia) + + weights_bar = _gi + _gj + + return None, weights_bar, None, None + + ####################### + ### Chunked version ### + ####################### + + nat = weights.shape[-2] + weights_bar = torch.zeros_like(weights) + + for start in range(0, nat, chunk_size): + end = min(start + chunk_size, nat) + + # Numbers and derivatives for this chunk + grad_chunk = grad_out[..., start:end, :] # (..., chunk_size, nat) + num_chunk = numbers[..., start:end] # (..., chunk_size) + + # Chunked indexing into reference.c6: (..., chunk_size, nat, 7, 7) + # -> Only the "i" index is chunked!
+ rc6_chunk = ref.c6[num_chunk.unsqueeze(-1), numbers.unsqueeze(-2)] + + # Also chunk the weights: (..., chunk_size, 7) + weights_chunk = weights[..., start:end, :] + + # _gi derivative is chunked (sum over non-chunked "j" index) + g_ia = einsum("...ijab,...jb->...ija", rc6_chunk, weights) + _gi = einsum("...ij,...ija->...ia", grad_chunk, g_ia) + + # _gj derivative is NOT chunked (sum over chunked "i" index) + g_jb = einsum("...ijab,...ia->...ijb", rc6_chunk, weights_chunk) + _gj = einsum("...ij,...ijb->...jb", grad_chunk, g_jb) + + # Accumulate gradients for the current chunk + weights_bar[..., start:end, :] += _gi + weights_bar += _gj + + return None, weights_bar, None, None + + +class AtomicC6_V1(AtomicC6Base): + """ + Custom autograd function for atomic C6 coefficients. + This is supposed to reduce memory usage. + """ + + @staticmethod + def forward( + ctx: CTX, + numbers: Tensor, + weights: Tensor, + reference: Reference, + chunk_size: None | int = None, + ) -> Tensor: + ctx.save_for_backward(numbers, weights) + ctx.chunk_size = chunk_size + ctx.reference = reference + + if chunk_size is None: + return _atomic_c6_full(numbers, weights, reference) + + return _atomic_c6_chunked(numbers, weights, reference, chunk_size) + + +class AtomicC6_V2(AtomicC6Base): + """ + Custom autograd function for atomic C6 coefficients. + This is supposed to reduce memory usage. 
+ """ + + generate_vmap_rule = True + # https://pytorch.org/docs/master/notes/extending.func.html#automatically-generate-a-vmap-rule + # should work since we only use PyTorch operations + + @staticmethod + def forward( + numbers: Tensor, + weights: Tensor, + reference: Reference, + chunk_size: None | int = None, + ) -> Tensor: + if chunk_size is None: + return _atomic_c6_full(numbers, weights, reference) + + return _atomic_c6_chunked(numbers, weights, reference, chunk_size) + + @staticmethod + def setup_context( + ctx: CTX, + inputs: tuple[Tensor, Tensor, Reference, int | None], + output: Tensor, + ) -> None: + numbers, weights, reference, chunk_size = inputs + + ctx.save_for_backward(numbers, weights) + ctx.chunk_size = chunk_size + ctx.reference = reference diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model/weights.py similarity index 75% rename from src/tad_dftd3/model.py rename to src/tad_dftd3/model/weights.py index e8f1559..c6b7b83 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model/weights.py @@ -40,50 +40,15 @@ [ 5.4368822, 3.0930154, 3.0930154], [ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64) """ -import torch -from tad_mctc.math import einsum - -from .reference import Reference -from .typing import Any, Tensor, WeightingFunction - -__all__ = ["atomic_c6", "gaussian_weight", "weight_references"] +from __future__ import annotations +import torch +from tad_mctc import storch -def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: - """ - Calculate atomic dispersion coefficients. - - Parameters - ---------- - numbers : Tensor - The atomic numbers of the atoms in the system of shape `(..., nat)`. - weights : Tensor - Weights of all reference systems of shape `(..., nat, 7)`. - reference : Reference - Reference systems for D3 model. Contains the reference C6 coefficients - of shape `(..., nelements, nelements, 7, 7)`. - - Returns - ------- - Tensor - Atomic dispersion coefficients of shape `(..., nat, nat)`. 
- """ - # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7) - rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] - - # The default einsum path is fastest if the large tensors comes first. - # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2) - return einsum( - "...ijab,...ia,...jb->...ij", - *(rc6, weights, weights), - optimize=[(0, 1), (0, 1)], - ) +from ..reference import Reference +from ..typing import Any, Tensor, WeightingFunction - # NOTE: This old version creates large intermediate tensors and builds the - # full matrix before the sum reduction, which requires a lot of memory. - # - # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4) - # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1) +__all__ = ["gaussian_weight", "weight_references"] def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor: @@ -165,15 +130,18 @@ def weight_references( # We solve this by running in double precision, adding a very small number # and using multiple masks. + small = torch.tensor(1e-300, device=cn.device, dtype=torch.double) + # normalize weights norm = torch.where( mask, torch.sum(weights, dim=-1, keepdim=True), - torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double! + small, # double! ) # back to real dtype - gw_temp = (weights / norm).type(cn.dtype) + gw_temp = storch.divide(weights, norm, eps=small).type(cn.dtype) + assert torch.isnan(gw_temp).sum() == 0 # The following section handles cases with large CNs that lead to zeros in # after the exponential in the weighting function. If this happens all @@ -186,8 +154,11 @@ def weight_references( # maximum reference CN for each atom maxcn = torch.max(refcn, dim=-1, keepdim=True)[0] - # prevent division by 0 and small values - exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max) + # Here, we catch the potential NaN's from `gw_temp`. 
We cannot use `gw_temp` + # directly, because we have to use safe divide to not get NaN's in the + # backward. But `norm == 0` is equivalent. Additionally, we catch very + # large values occurring because of division by small values. + exceptional = (norm == 0) | (gw_temp > torch.finfo(cn.dtype).max) gw = torch.where( exceptional, diff --git a/src/tad_dftd3/typing/builtin.py b/src/tad_dftd3/typing/builtin.py index dd86c75..aab741b 100644 --- a/src/tad_dftd3/typing/builtin.py +++ b/src/tad_dftd3/typing/builtin.py @@ -19,6 +19,6 @@ Built-in type annotations are imported from the *tad-mctc* library, which handles some version checking. """ -from tad_mctc.typing import Any, Callable, NoReturn, TypedDict +from tad_mctc.typing import Any, Callable, NoReturn, Protocol, TypedDict -__all__ = ["Any", "Callable", "NoReturn", "TypedDict"] +__all__ = ["Any", "Callable", "NoReturn", "Protocol", "TypedDict"] diff --git a/src/tad_dftd3/typing/pytorch.py b/src/tad_dftd3/typing/pytorch.py index 7411476..c682b5a 100644 --- a/src/tad_dftd3/typing/pytorch.py +++ b/src/tad_dftd3/typing/pytorch.py @@ -23,6 +23,7 @@ CountingFunction, DampingFunction, Molecule, + Size, Tensor, TensorOrTensors, get_default_device, @@ -34,6 +35,7 @@ "CountingFunction", "DampingFunction", "Molecule", + "Size", "Tensor", "TensorOrTensors", "get_default_device", diff --git a/test/test_disp/test_special.py b/test/test_disp/test_special.py index 2cee39b..7fc697e 100644 --- a/test/test_disp/test_special.py +++ b/test/test_disp/test_special.py @@ -112,7 +112,5 @@ def test_batch(dtype: torch.dtype) -> None: } energy = dftd3(numbers, positions, param) - print(energy.sum(-1)) - print(ref.sum(-1)) assert energy.dtype == dtype assert pytest.approx(ref.cpu()) == energy.cpu() diff --git a/test/test_model/test_c6.py b/test/test_model/test_c6.py index 3e7b8a6..f27092d 100644 --- a/test/test_model/test_c6.py +++ b/test/test_model/test_c6.py @@ -15,18 +15,23 @@ """ Test C6 coefficients.
""" +from __future__ import annotations + import pytest import torch +from tad_mctc.autograd import dgradcheck, dgradgradcheck from tad_mctc.batch import pack -from tad_dftd3 import model, reference -from tad_dftd3.typing import DD +from tad_dftd3 import model, ncoord, reference +from tad_dftd3.typing import DD, Callable, Protocol, Tensor -from ..conftest import DEVICE +from ..conftest import DEVICE, FAST_MODE from .samples import samples sample_list = ["SiH4", "PbH4-BiH3", "C6H5I-CH3SH", "MB16_43_01"] +tol = 1e-8 + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("name", sample_list) @@ -81,3 +86,121 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: assert c6.dtype == dtype assert pytest.approx(refc6.cpu(), abs=tol, rel=tol) == c6.cpu() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("size", [100, 200, 500]) +@pytest.mark.parametrize("chunk_size", [10, 100]) +def test_chunked(dtype: torch.dtype, size: int, chunk_size: int) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 + + ref = reference.Reference(**dd) + numbers = torch.randint(1, 86, (size,), device=DEVICE) + positions = torch.rand((size, 3), **dd) * 10 + + cn = ncoord.cn_d3(numbers, positions) + weights = model.weight_references(numbers, cn, ref) + + c6 = model.atomic_c6(numbers, weights, ref) + c6_chunked = model.atomic_c6(numbers, weights, ref, chunk_size=chunk_size) + + assert c6.dtype == c6_chunked.dtype == dtype + assert pytest.approx(c6.cpu(), abs=tol, rel=tol) == c6_chunked.cpu() + + +############################################################################### + + +class C6Func(Protocol): + """ + Type annotation for a function that calculates C6 coefficients. + """ + + def __call__( + self, + numbers: Tensor, + weights: Tensor, + ref: reference.Reference, + chunk_size: int | None = None, + ) -> Tensor: ... 
+ + +def gradchecker( + dtype: torch.dtype, + name: str, + f: C6Func, + chunk_size: int | None = None, +) -> tuple[ + Callable[[Tensor], Tensor], # autograd function + Tensor, # differentiable variables +]: + dd: DD = {"device": DEVICE, "dtype": dtype} + + sample = samples[name] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + + ref = reference.Reference(**dd) + cn = ncoord.cn_d3(numbers, positions) + w = model.weight_references(numbers, cn, ref) + + # variables to be differentiated + w = w.detach().clone().requires_grad_(True) + + def func(weights: Tensor) -> Tensor: + if chunk_size is None: + return f(numbers, weights, ref) + return f(numbers, weights, ref, chunk_size) + + return func, w + + +@pytest.mark.grad +@pytest.mark.parametrize("dtype", [torch.double]) +@pytest.mark.parametrize("name", ["LiH"] + sample_list) +@pytest.mark.parametrize( + "f, chunk_size", + [ + (model.c6._atomic_c6_full, None), + (model.c6._atomic_c6_chunked, 2), + (model.atomic_c6, None), + (model.atomic_c6, 2), + (model.c6.AtomicC6_V1.apply, None), + (model.c6.AtomicC6_V1.apply, 2), + ], +) +def test_gradcheck( + dtype: torch.dtype, name: str, f: C6Func, chunk_size: int | None +) -> None: + """ + Check a single analytical gradient of parameters against numerical + gradient from `torch.autograd.gradcheck`. 
+ """ + func, diffvars = gradchecker(dtype, name, f, chunk_size) + assert dgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE) + + +@pytest.mark.grad +@pytest.mark.parametrize("dtype", [torch.double]) +@pytest.mark.parametrize("name", sample_list) +@pytest.mark.parametrize( + "f, chunk_size", + [ + (model.c6._atomic_c6_full, None), + (model.c6._atomic_c6_chunked, 2), + (model.atomic_c6, None), + (model.atomic_c6, 2), + (model.c6.AtomicC6_V1.apply, None), + (model.c6.AtomicC6_V1.apply, 2), + ], +) +def test_gradgradcheck( + dtype: torch.dtype, name: str, f: C6Func, chunk_size: int | None +) -> None: + """ + Check a single analytical gradient of parameters against numerical + gradient from `torch.autograd.gradgradcheck`. + """ + func, diffvars = gradchecker(dtype, name, f, chunk_size=chunk_size) + assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE) diff --git a/test/test_model/test_util.py b/test/test_model/test_util.py new file mode 100644 index 0000000..b00fac3 --- /dev/null +++ b/test/test_model/test_util.py @@ -0,0 +1,45 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test model utility. 
+""" +from unittest.mock import patch + +import pytest +import torch + +from tad_dftd3.model.c6 import _check_memory + +from ..conftest import DEVICE + +tol = 1e-8 + + +@patch("tad_mctc.tools.memory.memory_device") +def test_memory_total(mock_memory) -> None: + mock_memory.return_value = (0, 0) + + x = torch.randn((1000000,), device=DEVICE, dtype=torch.double) + with pytest.raises(MemoryError): + _check_memory(x, x) + + +@patch("tad_mctc.tools.memory.memory_device") +def test_memory_free(mock_memory) -> None: + mock_memory.return_value = (0, 1e10) + + x = torch.randn((10000,), device=DEVICE, dtype=torch.double) + with pytest.warns(ResourceWarning): + _check_memory(x, x)