Skip to content

Commit

Permalink
Detach
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed May 3, 2024
1 parent 9d38d06 commit aa27ba2
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions test/test_grad/test_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,11 @@ def test_autograd(dtype: torch.dtype, name: str) -> None:
}

# variable to be differentiated
positions.requires_grad_(True)
pos = positions.clone().requires_grad_(True)

# automatic gradient
energy = torch.sum(dftd3(numbers, positions, param))
(grad,) = torch.autograd.grad(energy, positions)

positions.detach_()
energy = torch.sum(dftd3(numbers, pos, param))
(grad,) = torch.autograd.grad(energy, pos)

assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu()

Expand Down Expand Up @@ -243,11 +241,14 @@ def test_functorch(dtype: torch.dtype, name: str) -> None:
"a2": torch.tensor(5.00000000, **dd),
}

def dftd3_func(pos: Tensor) -> Tensor:
return dftd3(numbers, pos, param).sum()
# variable to be differentiated
pos = positions.clone().requires_grad_(True)

def dftd3_func(p: Tensor) -> Tensor:
return dftd3(numbers, p, param).sum()

grad = jacrev(dftd3_func)(positions)
grad = jacrev(dftd3_func)(pos)
assert isinstance(grad, Tensor)

assert grad.shape == ref.shape
assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu()
assert pytest.approx(ref.cpu(), abs=tol) == grad.detach().cpu()

0 comments on commit aa27ba2

Please sign in to comment.