Skip to content

Commit

Permalink
Fix mask for exceptional values
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Apr 26, 2024
1 parent ca3a2bc commit d22eedb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 6 additions & 3 deletions src/tad_dftd3/model/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def weight_references(
)

# back to real dtype
# gw_temp = (storch.divide(weights, norm, eps=small)).type(cn.dtype)
gw_temp = storch.divide(weights, norm, eps=small).type(cn.dtype)
assert torch.isnan(gw_temp).sum() == 0

# The following section handles cases with large CNs that lead to zeros in
# after the exponential in the weighting function. If this happens all
Expand All @@ -154,8 +154,11 @@ def weight_references(
# maximum reference CN for each atom
maxcn = torch.max(refcn, dim=-1, keepdim=True)[0]

# prevent division by 0 and small values
exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max)
# Here, we catch the potential NaN's from `gw_temp`. We cannot use `gw_temp`
# directly, because we have to use safe divide to not get NaN's in the
# backward. But `norm == 0` is equivalent. Additionally, we catch very
# large values occurring because of division by small values.
exceptional = (norm == 0) | (gw_temp > torch.finfo(cn.dtype).max)

gw = torch.where(
exceptional,
Expand Down
2 changes: 0 additions & 2 deletions test/test_disp/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,5 @@ def test_batch(dtype: torch.dtype) -> None:
}

energy = dftd3(numbers, positions, param)
print(energy.sum(-1))
print(ref.sum(-1))
assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()

0 comments on commit d22eedb

Please sign in to comment.