diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index b2c5fc50b21..e90c74201d6 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -1,10 +1,13 @@
 import sys
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
 import unittest
+import copy
 
 
 # TODO(alanwaketan): add test for views.
@@ -82,6 +85,63 @@ def test_aliasing_with_multiple_inplace_update(self):
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
     torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu())
 
+  def test_grad_accum(self):
+    """Checks gradient accumulation on an XLA device against a CPU run.
+
+    Runs the same accumulation loop on a CPU model and on an XLA copy of it,
+    compares the resulting gradients, then checks the values recorded under
+    the "InputOutputAliasCount" metric for the XLA run.
+    """
+
+    class MLP(nn.Module):
+      """Single bias-free linear layer followed by log-softmax."""
+
+      def __init__(self, input_size=28 * 28, output_size=10):
+        super().__init__()
+        self.fc1 = nn.Linear(input_size, output_size, bias=False)
+
+      def forward(self, x):
+        return F.log_softmax(self.fc1(x), dim=1)
+
+    def try_grad_accum(model, device, train_x, train_label, accum_steps):
+      """Runs `accum_steps` backward passes and returns the grads as numpy."""
+      loss_fn = nn.NLLLoss()
+      train_x = train_x.to(device)
+      train_label = train_label.to(device)
+      model.zero_grad()
+      for _ in range(accum_steps):
+        loss = loss_fn(model(train_x), train_label)
+        loss.backward()
+        # Called once per accumulation step, i.e. after every backward pass,
+        # rather than once after the whole loop.
+        xm.mark_step()
+      return [p.grad.to('cpu').numpy() for p in model.parameters()]
+
+    dev = xm.xla_device()
+    train_x_sample = torch.rand((1, 28 * 28))
+    train_label_sample = torch.tensor([5])
+    c_model = MLP().to('cpu')
+    # Deep-copy so the XLA model starts from the same weights as the CPU one.
+    t_model = copy.deepcopy(c_model).to(dev)
+    t_model.train()
+    c_model.train()
+    accum_steps = 4
+    c_grads = try_grad_accum(c_model, 'cpu', train_x_sample,
+                             train_label_sample, accum_steps)
+    # Clear metrics after the CPU run and before the XLA run whose metric is
+    # read below.
+    met.clear_metrics()
+    t_grads = try_grad_accum(t_model, dev, train_x_sample, train_label_sample,
+                             accum_steps)
+    torch.testing.assert_close(t_grads, c_grads, rtol=3e-2, atol=1e-3)
+    graph_count, alias_count, _ = met.metric_data("InputOutputAliasCount")
+    self.assertEqual(
+        graph_count, 2,
+        f"Expect 2 graphs for gradient accumulation test, got {graph_count}")
+    self.assertEqual(
+        alias_count, 1.0,
+        f"Expect 1 input-output alias pair for gradient accumulation, got {alias_count}"
+    )
+
 
 if __name__ == '__main__':
   test = unittest.main()