From 408d6e16f935a8d824aef8f6227427a67dcf2065 Mon Sep 17 00:00:00 2001 From: Kemal Maulana Date: Fri, 29 Dec 2017 20:48:39 +0700 Subject: [PATCH] Backport from 0.4.1 * Initialize parameters ini `__init__` (fixes #1) * Refactor tests * Deprecate `summed` in favor of `reduce` (fixes #2) * Rename setup.cfg to .flake8 --- setup.cfg => .flake8 | 0 README.rst | 5 -- setup.py | 2 +- src/torchcrf/__init__.py | 30 +++++++- tests/test_crf.py | 153 ++++++++++++++++++++------------------- 5 files changed, 108 insertions(+), 82 deletions(-) rename setup.cfg => .flake8 (100%) diff --git a/setup.cfg b/.flake8 similarity index 100% rename from setup.cfg rename to .flake8 diff --git a/README.rst b/README.rst index 5c47f8f..b5036bf 100644 --- a/README.rst +++ b/README.rst @@ -38,11 +38,6 @@ In the examples below, we will assume that these lines have been executed :: >>> emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), requires_grad=True) >>> tags = torch.autograd.Variable(torch.LongTensor([[0, 1], [2, 4], [3, 1]])) # (seq_length, batch_size) >>> model = CRF(num_tags) - >>> # Initialize model parameters - ... for p in model.parameters(): - ... _ = torch.nn.init.uniform(p, -1, 1) - ... - >>> Forward computation ------------------- diff --git a/setup.py b/setup.py index e93be27..c8ae135 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup(name='pytorch-crf', - version='0.3.1', + version='0.3.2', description='Conditional random field in PyTorch', long_description=readme, url='https://github.com/kmkurn/pytorch-crf', diff --git a/src/torchcrf/__init__.py b/src/torchcrf/__init__.py index 82886cb..a9ffbe3 100644 --- a/src/torchcrf/__init__.py +++ b/src/torchcrf/__init__.py @@ -1,4 +1,5 @@ from typing import List, Optional, Union +import warnings from torch.autograd import Variable import torch @@ -47,11 +48,24 @@ def __init__(self, num_tags: int) -> None: self.end_transitions = nn.Parameter(torch.Tensor(num_tags)) self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags)) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize the transition parameters. + + The parameters will be initialized randomly from a uniform distribution + between -0.1 and 0.1. + """ + nn.init.uniform(self.start_transitions, -0.1, 0.1) + nn.init.uniform(self.end_transitions, -0.1, 0.1) + nn.init.uniform(self.transitions, -0.1, 0.1) + def forward(self, emissions: Variable, tags: Variable, mask: Optional[Variable] = None, - summed: bool = True) -> Variable: + reduce: bool = True, + **kwargs) -> Variable: """Compute the log likelihood of the given sequence of tags and emission score. Arguments @@ -62,7 +76,7 @@ def forward(self, Sequence of tags as ``LongTensor`` of size ``(seq_length, batch_size)``. mask : :class:`~torch.autograd.Variable`, optional Mask tensor as ``ByteTensor`` of size ``(seq_length, batch_size)``. - summed : bool + reduce : bool Whether to sum the log likelihood over the batch. 
Returns @@ -94,13 +108,23 @@ def forward(self, if not all(mask[0].data): raise ValueError('mask of the first timestep must all be on') + if 'summed' in kwargs: + msg = "keyword argument 'summed' is deprecated and will be removed in "\ + "future versions, please use 'reduce' instead" + warnings.warn(msg, DeprecationWarning, stacklevel=3) + reduce = kwargs.pop('summed') + + if kwargs: + raise TypeError( + f"'{kwargs.popitem()[0]}' is an invalid keyword argument for this function") + if mask is None: mask = Variable(self._new(*tags.size()).fill_(1)) numerator = self._compute_joint_llh(emissions, tags, mask) denominator = self._compute_log_partition_function(emissions, mask) llh = numerator - denominator - return torch.sum(llh) if summed else llh + return llh if not reduce else torch.sum(llh) def decode(self, emissions: Union[Variable, torch.FloatTensor], diff --git a/tests/test_crf.py b/tests/test_crf.py index a45c67e..17128e8 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -17,12 +17,6 @@ torch.manual_seed(RANDOM_SEED) -def initialize(crf): - nn.init.uniform(crf.start_transitions.data, -5, 5) - nn.init.uniform(crf.end_transitions.data, -5, 5) - nn.init.uniform(crf.transitions.data, -5, 5) - - def compute_score(crf, emission, tag): assert len(emission) == len(tag) @@ -36,6 +30,22 @@ def compute_score(crf, emission, tag): return score +def make_crf(num_tags=5): + return CRF(num_tags) + + +def make_emissions(seq_length=3, batch_size=2, num_tags=5): + return torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), + requires_grad=True) + + +def make_tags(seq_length=3, batch_size=2, num_tags=5): + return torch.autograd.Variable(torch.LongTensor([ + [random.randrange(num_tags) for b in range(batch_size)] + for _ in range(seq_length) + ])) + + class TestInit(object): def test_minimal(self): num_tags = 10 @@ -57,15 +67,10 @@ def test_nonpositive_num_tags(self): class TestForward(object): def test_batched_loss_is_correct(self): - seq_length, batch_size, num_tags = 3, 10, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), - requires_grad=True) - tags = torch.autograd.Variable(torch.LongTensor([ - [random.randrange(num_tags) for b in range(batch_size)] - for _ in range(seq_length) - ])) - crf = CRF(num_tags) - initialize(crf) + crf = make_crf() + batch_size = 10 + emissions = make_emissions(batch_size=batch_size, num_tags=crf.num_tags) + tags = make_tags(batch_size=batch_size, num_tags=crf.num_tags) llh = crf(emissions, tags) @@ -80,21 +85,16 @@ def test_batched_loss_is_correct(self): assert llh.data[0] == approx(total_llh.data[0]) def test_works_with_mask(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), - requires_grad=True) - tags = torch.autograd.Variable(torch.LongTensor([ - [random.randrange(num_tags) for b in range(batch_size)] - for _ in range(seq_length) - ])) + crf = make_crf() + seq_length, batch_size = 3, 2 + emissions = make_emissions(seq_length, batch_size, crf.num_tags) + tags = make_tags(seq_length, batch_size, crf.num_tags) # mask should be (seq_length, batch_size) mask = torch.autograd.Variable(torch.ByteTensor([ [1, 1], [1, 1], [1, 0], ])) - crf = CRF(num_tags) - initialize(crf) llh = crf(emissions, tags, mask=mask) @@ -109,7 +109,7 @@ def test_works_with_mask(self): emission, tag = emission[:seq_len], tag[:seq_len] numerator = compute_score(crf, emission, tag) all_scores = [compute_score(crf, emission, t) - for t in 
itertools.product(range(num_tags), repeat=seq_len)] + for t in itertools.product(range(crf.num_tags), repeat=seq_len)] denominator = math.log(sum(math.exp(s) for s in all_scores)) manual_llh += numerator - denominator # Assert equal to manual log likelihood @@ -118,15 +118,10 @@ def test_works_with_mask(self): llh.backward() def test_works_without_mask(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), - requires_grad=True) - tags = torch.autograd.Variable(torch.LongTensor([ - [random.randrange(num_tags) for b in range(batch_size)] - for _ in range(seq_length) - ])) - crf = CRF(num_tags) - initialize(crf) + crf = make_crf() + emissions = make_emissions(num_tags=crf.num_tags) + tags = make_tags(num_tags=crf.num_tags) + seq_length, batch_size = tags.size() llh_no_mask = crf(emissions, tags) # No mask means the mask is all ones @@ -136,17 +131,12 @@ def test_works_without_mask(self): assert llh_no_mask.data[0] == approx(llh_mask.data[0]) def test_not_summed_over_batch(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), - requires_grad=True) - tags = torch.autograd.Variable(torch.LongTensor([ - [random.randrange(num_tags) for b in range(batch_size)] - for _ in range(seq_length) - ])) - crf = CRF(num_tags) - initialize(crf) + crf = make_crf() + emissions = make_emissions(num_tags=crf.num_tags) + tags = make_tags(num_tags=crf.num_tags) + seq_length, batch_size = tags.size() - llh = crf(emissions, tags, summed=False) + llh = crf(emissions, tags, reduce=False) assert isinstance(llh, torch.autograd.Variable) assert llh.size() == (batch_size,) @@ -158,7 +148,7 @@ def test_not_summed_over_batch(self): for emission, tag in zip(emissions.data, tags.data): numerator = compute_score(crf, emission, tag) all_scores = [compute_score(crf, emission, t) - for t in itertools.product(range(num_tags), repeat=seq_length)] + for t in itertools.product(range(crf.num_tags), repeat=seq_length)] denominator = math.log(sum(math.exp(s) for s in all_scores)) manual_llh.append(numerator - denominator) @@ -168,7 +158,7 @@ def test_not_summed_over_batch(self): def test_emissions_has_bad_number_of_dimension(self): emissions = torch.autograd.Variable(torch.randn(1, 2), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(2, 2)) - crf = CRF(10) + crf = make_crf() with pytest.raises(ValueError) as excinfo: crf(emissions, tags) @@ -177,7 +167,7 @@ def test_emissions_has_bad_number_of_dimension(self): def test_tags_has_bad_number_of_dimension(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(2, 2, 2)) - crf = CRF(3) + crf = make_crf(3) with pytest.raises(ValueError) as excinfo: crf(emissions, tags) @@ -186,7 +176,7 @@ def test_tags_has_bad_number_of_dimension(self): def test_emissions_and_tags_size_mismatch(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(2, 2)) - crf = CRF(3) + crf = make_crf(3) with pytest.raises(ValueError) as excinfo: crf(emissions, tags) @@ -196,7 +186,7 @@ def test_emissions_and_tags_size_mismatch(self): def test_emissions_last_dimension_not_equal_to_number_of_tags(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(1, 2)) - crf = CRF(10) + crf = make_crf(10) with 
pytest.raises(ValueError) as excinfo: crf(emissions, tags) @@ -206,7 +196,7 @@ def test_mask_and_tags_size_mismatch(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(1, 2)) mask = torch.autograd.Variable(torch.ByteTensor([[1], [1]])) - crf = CRF(3) + crf = make_crf(3) with pytest.raises(ValueError) as excinfo: crf(emissions, tags, mask=mask) @@ -218,19 +208,39 @@ def test_first_timestep_mask_is_not_all_on(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3), requires_grad=True) tags = torch.autograd.Variable(torch.LongTensor(1, 2)) mask = torch.autograd.Variable(torch.ByteTensor([[0, 1]])) - crf = CRF(3) + crf = make_crf(3) with pytest.raises(ValueError) as excinfo: crf(emissions, tags, mask=mask) assert 'mask of the first timestep must all be on' in str(excinfo.value) + def test_warning_when_kwarg_summed_is_used(self, recwarn): + crf = make_crf() + emissions = make_emissions(num_tags=crf.num_tags) + tags = make_tags(num_tags=crf.num_tags) + + crf(emissions, tags, summed=False) + + w = recwarn.pop(DeprecationWarning) + msg = "keyword argument 'summed' is deprecated and will be removed in "\ + "future versions, please use 'reduce' instead" + assert msg in str(w.message) + + def test_unknown_kwargs(self): + crf = make_crf() + emissions = make_emissions(num_tags=crf.num_tags) + tags = make_tags(num_tags=crf.num_tags) + + with pytest.raises(TypeError) as excinfo: + crf(emissions, tags, foo='foo') + assert "'foo' is an invalid keyword argument for this function" in str(excinfo.value) + class TestDecode(object): def test_batched_decode_is_correct(self): - seq_length, batch_size, num_tags = 3, 10, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags)) - crf = CRF(num_tags) - initialize(crf) + crf = make_crf() + batch_size = 10 + emissions = make_emissions(batch_size=batch_size, num_tags=crf.num_tags) best_tags = crf.decode(emissions) @@ -240,10 +250,9 @@ def test_batched_decode_is_correct(self): assert best_tags[i] == crf.decode(emissions_)[0] def test_works_without_mask(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags)) - crf = CRF(num_tags) - initialize(crf) + crf = make_crf() + emissions = make_emissions(num_tags=crf.num_tags) + seq_length = emissions.size(0) best_tags = crf.decode(emissions) @@ -251,21 +260,20 @@ def test_works_without_mask(self): emissions = emissions.transpose(0, 1) # Compute best tag manually for emission, best_tag in zip(emissions.data, best_tags): - manual_best_tag = max(itertools.product(range(num_tags), repeat=seq_length), + manual_best_tag = max(itertools.product(range(crf.num_tags), repeat=seq_length), key=lambda t: compute_score(crf, emission, t)) assert tuple(best_tag) == manual_best_tag def test_works_with_mask(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags)) + crf = make_crf() + seq_length, batch_size = 3, 2 + emissions = make_emissions(seq_length, batch_size, crf.num_tags) # mask should be (seq_length, batch_size) mask = torch.autograd.Variable(torch.ByteTensor([ [1, 1], [1, 1], [1, 0], ])) - crf = CRF(num_tags) - initialize(crf) best_tags = crf.decode(emissions, mask=mask) @@ -277,14 +285,15 @@ def test_works_with_mask(self): seq_len = mask_.sum() assert len(best_tag) == seq_len emission = emission[:seq_len] - manual_best_tag = max(itertools.product(range(num_tags), 
repeat=seq_len), + manual_best_tag = max(itertools.product(range(crf.num_tags), repeat=seq_len), key=lambda t: compute_score(crf, emission, t)) assert tuple(best_tag) == manual_best_tag def test_works_with_tensor(self): - seq_length, batch_size, num_tags = 3, 2, 5 - emissions = torch.randn(seq_length, batch_size, num_tags) - emissions_var = torch.autograd.Variable(emissions) + crf = make_crf() + seq_length, batch_size = 3, 2 + emissions_var = make_emissions(seq_length, batch_size, crf.num_tags) + emissions = emissions_var.data # mask should be (seq_length, batch_size) mask = torch.ByteTensor([ [1, 1], @@ -292,8 +301,6 @@ def test_works_with_tensor(self): [1, 0], ]) mask_var = torch.autograd.Variable(mask) - crf = CRF(num_tags) - initialize(crf) best_tags = crf.decode(emissions, mask=mask) best_tags_var = crf.decode(emissions_var, mask=mask_var) @@ -302,7 +309,7 @@ def test_works_with_tensor(self): def test_emissions_has_bad_number_of_dimension(self): emissions = torch.autograd.Variable(torch.randn(1, 2)) - crf = CRF(3) + crf = make_crf() with pytest.raises(ValueError) as excinfo: crf.decode(emissions) @@ -310,7 +317,7 @@ def test_emissions_has_bad_number_of_dimension(self): def test_emissions_last_dimension_not_equal_to_number_of_tags(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3)) - crf = CRF(10) + crf = make_crf(10) with pytest.raises(ValueError) as excinfo: crf.decode(emissions) @@ -319,7 +326,7 @@ def test_emissions_last_dimension_not_equal_to_number_of_tags(self): def test_emissions_and_mask_size_mismatch(self): emissions = torch.autograd.Variable(torch.randn(1, 2, 3)) mask = torch.autograd.Variable(torch.ByteTensor([[1, 1], [1, 0]])) - crf = CRF(3) + crf = make_crf(3) with pytest.raises(ValueError) as excinfo: crf.decode(emissions, mask=mask)