From aa27ba292be7452af9c4f7d17ead14a62d088469 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Fri, 3 May 2024 17:06:00 +0200 Subject: [PATCH] Detach --- test/test_grad/test_pos.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_grad/test_pos.py b/test/test_grad/test_pos.py index fee4646..2183064 100644 --- a/test/test_grad/test_pos.py +++ b/test/test_grad/test_pos.py @@ -169,13 +169,11 @@ def test_autograd(dtype: torch.dtype, name: str) -> None: } # variable to be differentiated - positions.requires_grad_(True) + pos = positions.clone().requires_grad_(True) # automatic gradient - energy = torch.sum(dftd3(numbers, positions, param)) - (grad,) = torch.autograd.grad(energy, positions) - - positions.detach_() + energy = torch.sum(dftd3(numbers, pos, param)) + (grad,) = torch.autograd.grad(energy, pos) assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu() @@ -243,11 +241,14 @@ def test_functorch(dtype: torch.dtype, name: str) -> None: "a2": torch.tensor(5.00000000, **dd), } - def dftd3_func(pos: Tensor) -> Tensor: - return dftd3(numbers, pos, param).sum() + # variable to be differentiated + pos = positions.clone().requires_grad_(True) + + def dftd3_func(p: Tensor) -> Tensor: + return dftd3(numbers, p, param).sum() - grad = jacrev(dftd3_func)(positions) + grad = jacrev(dftd3_func)(pos) assert isinstance(grad, Tensor) assert grad.shape == ref.shape - assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu() + assert pytest.approx(ref.cpu(), abs=tol) == grad.detach().cpu()