From 8cebf5b484352c4c80f80e24d62afc1b7f91d275 Mon Sep 17 00:00:00 2001
From: Tzu-Yu Lee
Date: Wed, 29 May 2024 22:11:24 +0800
Subject: [PATCH] Use torch.testing.assert_close instead of approx

The tests output `RuntimeError: Can't call numpy() on Tensor that
requires grad. Use tensor.detach().numpy() instead.` for torch versions
above v1.4. The error seems to come from trying to convert a tensor to
a scalar. Replace pytest's approx() with `torch.testing.assert_close()`
to let torch handle the conversion. Use `assert_allclose()` for torch
versions before 1.9.0.

---
 requirements-test.txt |  1 +
 tests/test_crf.py     | 24 +++++++++++++++++++-----
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/requirements-test.txt b/requirements-test.txt
index 2ace388..ca6a507 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -4,3 +4,4 @@
 torch
 pytest>=8,<9
 pytest-cov>=2,<3
+packaging
diff --git a/tests/test_crf.py b/tests/test_crf.py
index dc70340..60c93ed 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 import torch.nn as nn
+from packaging.version import Version
 
 from torchcrf import CRF
 
@@ -14,6 +15,8 @@
 random.seed(RANDOM_SEED)
 torch.manual_seed(RANDOM_SEED)
 
+HAS_ASSERT_CLOSE = Version(torch.__version__) >= Version("1.9.0")
+
 
 def compute_score(crf, emission, tag):
     # emission: (seq_length, num_tags)
@@ -116,7 +119,10 @@ def test_works_with_mask(self):
             denominator = math.log(sum(math.exp(s) for s in all_scores))
             manual_llh += numerator - denominator
 
-        assert llh.item() == approx(manual_llh)
+        if HAS_ASSERT_CLOSE:
+            torch.testing.assert_close(llh, manual_llh, atol=1e-12, rtol=1e-6)
+        else:
+            torch.testing.assert_allclose(llh, manual_llh, atol=1e-12, rtol=1e-6)
         llh.backward()  # ensure gradients can be computed
 
     def test_works_without_mask(self):
@@ -186,8 +192,10 @@ def test_reduction_none(self):
             denominator = math.log(sum(math.exp(s) for s in all_scores))
             manual_llh.append(numerator - denominator)
 
-        for llh_, manual_llh_ in zip(llh, manual_llh):
-            assert llh_.item() == approx(manual_llh_)
+        if HAS_ASSERT_CLOSE:
+            torch.testing.assert_close(llh, torch.tensor(manual_llh), atol=1e-12, rtol=1e-6)
+        else:
+            torch.testing.assert_allclose(llh, torch.tensor(manual_llh), atol=1e-12, rtol=1e-6)
 
     def test_reduction_mean(self):
         crf = make_crf()
@@ -219,7 +227,10 @@ def test_reduction_mean(self):
             denominator = math.log(sum(math.exp(s) for s in all_scores))
             manual_llh += numerator - denominator
 
-        assert llh.item() == approx(manual_llh / batch_size)
+        if HAS_ASSERT_CLOSE:
+            torch.testing.assert_close(llh, manual_llh / batch_size, atol=1e-12, rtol=1e-6)
+        else:
+            torch.testing.assert_allclose(llh, manual_llh / batch_size, atol=1e-12, rtol=1e-6)
 
     def test_reduction_token_mean(self):
         crf = make_crf()
@@ -258,7 +269,10 @@ def test_reduction_token_mean(self):
             manual_llh += numerator - denominator
             n_tokens += seq_len
 
-        assert llh.item() == approx(manual_llh / n_tokens)
+        if HAS_ASSERT_CLOSE:
+            torch.testing.assert_close(llh, manual_llh / n_tokens, atol=1e-12, rtol=1e-6)
+        else:
+            torch.testing.assert_allclose(llh, manual_llh / n_tokens, atol=1e-12, rtol=1e-6)
 
     def test_batch_first(self):
         crf = make_crf()
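
Note: below is a minimal standalone sketch of the version-gated comparison
pattern this patch introduces, assuming only torch and packaging are
installed. The helper name assert_tensor_close and the toy tensors are
illustrative only and are not part of the patch or of torchcrf.

    import torch
    from packaging.version import Version

    # torch.testing.assert_close was added in torch 1.9.0.
    HAS_ASSERT_CLOSE = Version(torch.__version__) >= Version("1.9.0")


    def assert_tensor_close(actual, expected, atol=1e-12, rtol=1e-6):
        """Compare a (possibly grad-requiring) tensor against an expected value."""
        # Wrap the expected value so both sides are tensors of the same dtype.
        expected = torch.as_tensor(expected, dtype=actual.dtype)
        if HAS_ASSERT_CLOSE:
            # No .item()/.numpy() call is involved, so a tensor that requires
            # grad can be compared without detaching it first.
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
        else:
            # Older API that predates assert_close; takes the same tolerances.
            torch.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)


    weights = torch.tensor([0.5, -1.0], requires_grad=True)
    llh = weights.sum()             # scalar tensor that requires grad
    assert_tensor_close(llh, -0.5)  # passes without detach()/item()/numpy()
    llh.backward()                  # gradients can still be computed afterwards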