Add 1-layer gradient accumulation test to check aliasing (#7692)
jeffhataws committed Aug 1, 2024
1 parent bc9b606 commit a5520e5
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions test/test_input_output_aliases.py
@@ -1,10 +1,13 @@
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
import copy


# TODO(alanwaketan): add test for views.
@@ -82,6 +85,52 @@ def test_aliasing_with_multiple_inplace_update(self):
    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
    torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu())

  def test_grad_accum(self):

    class MLP(nn.Module):

      def __init__(self, input_size=28 * 28, output_size=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, output_size, bias=False)

      def forward(self, x):
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

    def try_grad_accum(model, device, train_x, train_label, accum_steps):
      loss_fn = nn.NLLLoss()
      train_x = train_x.to(device)
      train_label = train_label.to(device)
      model.zero_grad()
      for i in range(accum_steps):
        output = model(train_x)
        t_loss = loss_fn(output, train_label)
        t_loss.backward()
        xm.mark_step()
      return [p.grad.to('cpu').numpy() for p in model.parameters()]

    dev = xm.xla_device()
    train_x_sample = torch.rand((1, 28 * 28))
    train_label_sample = torch.tensor([5])
    c_model = MLP().to('cpu')
    t_model = copy.deepcopy(c_model).to(dev)
    t_model.train()
    c_model.train()
    accum_steps = 4
    c_grads_5 = try_grad_accum(c_model, 'cpu', train_x_sample,
                               train_label_sample, accum_steps)
    met.clear_metrics()
    t_grads_5 = try_grad_accum(t_model, dev, train_x_sample, train_label_sample,
                               accum_steps)
    torch.testing.assert_close(t_grads_5, c_grads_5, rtol=3e-2, atol=1e-3)
    graph_count, alias_count, _ = met.metric_data("InputOutputAliasCount")
    assert (
        graph_count == 2
    ), f"Expect 2 graphs for gradient accumulation test, got {graph_count}"
    assert (
        alias_count == 1.0
    ), f"Expect 1 input-output alias pair for gradient accumulation, got {alias_count}"


if __name__ == '__main__':
  test = unittest.main()
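Context on what the new test exercises: during gradient accumulation the same .grad buffers are updated in place on every step, which is what lets the XLA runtime alias them as both an input and an output of the step graph once the first mark_step has materialized them; the test's asserts expect 2 compiled graphs and exactly 1 such alias pair in the InputOutputAliasCount metric. Below is a minimal standalone sketch, not part of the commit, for inspecting that metric with the same torch_xla APIs the test uses; the model shape and step count are arbitrary choices, and the printed values depend on the backend.

# Minimal sketch (not from the commit): observe InputOutputAliasCount after a
# few accumulation steps. Assumes a working torch_xla installation.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = nn.Linear(8, 2, bias=False).to(device)
data = torch.rand(4, 8).to(device)

met.clear_metrics()
for _ in range(2):
  model(data).sum().backward()  # .grad buffers accumulate in place across steps
  xm.mark_step()  # cut the lazy graph; later steps may alias the .grad buffers

# metric_data returns a 3-tuple; the test reads the first two entries as the
# graph count and the alias count.
print(met.metric_data("InputOutputAliasCount"))

To run only the new test from a checkout, unittest's -k filter should work: python test/test_input_output_aliases.py -k test_grad_accum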
