Skip to content

Commit

Permalink
Use torch.testing.assert_close instead of approx
Browse files Browse the repository at this point in the history
The tests outputs `RuntimeError: Can't call numpy() on Tensor that
requires grad. Use tensor.detach().numpy() instead.` for torch above
v1.4. The error seems to came from trying to convert tensor to scalar.

Replace the pytest approx() with `torch.testing.assert_close()`
to let torch handle the conversion.
Use `assert_allclose()` for torch versions before 1.9.0.
  • Loading branch information
leejuyuu committed Jun 2, 2024
1 parent 9a3afd8 commit 8cebf5b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
torch
pytest>=8,<9
pytest-cov>=2,<3
packaging
24 changes: 19 additions & 5 deletions tests/test_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch
import torch.nn as nn
from packaging.version import Version

from torchcrf import CRF

Expand All @@ -14,6 +15,8 @@
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

HAS_ASSERT_CLOSE = Version(torch.__version__) >= Version("1.9.0")


def compute_score(crf, emission, tag):
# emission: (seq_length, num_tags)
Expand Down Expand Up @@ -116,7 +119,10 @@ def test_works_with_mask(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh += numerator - denominator

assert llh.item() == approx(manual_llh)
if HAS_ASSERT_CLOSE:
torch.testing.assert_close(llh, manual_llh, atol=1e-12, rtol=1e-6)
else:
torch.testing.assert_allclose(llh, manual_llh, atol=1e-12, rtol=1e-6)
llh.backward() # ensure gradients can be computed

def test_works_without_mask(self):
Expand Down Expand Up @@ -186,8 +192,10 @@ def test_reduction_none(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh.append(numerator - denominator)

for llh_, manual_llh_ in zip(llh, manual_llh):
assert llh_.item() == approx(manual_llh_)
if HAS_ASSERT_CLOSE:
torch.testing.assert_close(llh, torch.tensor(manual_llh), atol=1e-12, rtol=1e-6)
else:
torch.testing.assert_allclose(llh, torch.tensor(manual_llh), atol=1e-12, rtol=1e-6)

def test_reduction_mean(self):
crf = make_crf()
Expand Down Expand Up @@ -219,7 +227,10 @@ def test_reduction_mean(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh += numerator - denominator

assert llh.item() == approx(manual_llh / batch_size)
if HAS_ASSERT_CLOSE:
torch.testing.assert_close(llh, manual_llh / batch_size, atol=1e-12, rtol=1e-6)
else:
torch.testing.assert_allclose(llh, manual_llh / batch_size, atol=1e-12, rtol=1e-6)

def test_reduction_token_mean(self):
crf = make_crf()
Expand Down Expand Up @@ -258,7 +269,10 @@ def test_reduction_token_mean(self):
manual_llh += numerator - denominator
n_tokens += seq_len

assert llh.item() == approx(manual_llh / n_tokens)
if HAS_ASSERT_CLOSE:
torch.testing.assert_close(llh, manual_llh / n_tokens, atol=1e-12, rtol=1e-6)
else:
torch.testing.assert_allclose(llh, manual_llh / n_tokens, atol=1e-12, rtol=1e-6)

def test_batch_first(self):
crf = make_crf()
Expand Down

0 comments on commit 8cebf5b

Please sign in to comment.