[Pallas] Introduce GMM(torch.autograd.Function) (#7152)
Summary:
This pull request makes GMM a torch.autograd.Function so that we can use torch.autograd.backward instead of manual backpropagation.

Test Plan:
python test/test_gmm.py
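
For orientation, a minimal usage sketch of the new interface. The shapes, the even group split, and the int32 dtype for group_sizes are illustrative assumptions; the tests below are the authoritative usage.

import torch
import torch_xla  # registers the XLA device
from torch_xla.experimental.custom_kernel import GMM

# Hypothetical sizes: 4 groups, 512 lhs rows split evenly across them.
lhs = torch.rand(512, 128, dtype=torch.bfloat16, requires_grad=True).to("xla")
rhs = torch.rand(4, 128, 256, dtype=torch.bfloat16, requires_grad=True).to("xla")
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32).to("xla")
lhs.retain_grad()  # .to("xla") yields non-leaf tensors, so keep their grads
rhs.retain_grad()

out = GMM.apply(lhs, rhs, group_sizes)  # (512, 256); forward runs the Pallas gmm kernel
out.sum().backward()                    # autograd routes the backward pass through gmm_backward
print(lhs.grad.shape, rhs.grad.shape)   # torch.Size([512, 128]) torch.Size([4, 128, 256])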
alanwaketan authored May 30, 2024
1 parent af51f06 commit aeed61a
Showing 2 changed files with 92 additions and 1 deletion.
77 changes: 76 additions & 1 deletion test/test_gmm.py
@@ -7,7 +7,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -374,6 +374,81 @@ def test_gmm_backward(self):
    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm_backward_2(self):
    self._init_test_cases()
    for test_case in self.tests_cases:
      num_groups = test_case['num_groups']
      k = test_case['k']
      m = test_case['m']
      n = test_case['n']
      lhs_dtype = rhs_dtype = torch.bfloat16

      torch.manual_seed(42)
      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
      lhs.retain_grad()
      rhs.retain_grad()

      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
      ref_out.sum().backward()

      torch.manual_seed(42)
      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
      rhs_xla = torch.rand(
          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
      lhs_xla.retain_grad()
      rhs_xla.retain_grad()

      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
      out.sum().backward()

      self.assertTrue(torch.allclose(ref_out, out.cpu()))
      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm_backward_3(self):
    self._init_test_cases()
    for test_case in self.tests_cases:
      num_groups = test_case['num_groups']
      k = test_case['k']
      m = test_case['m']
      n = test_case['n']
      lhs_dtype = rhs_dtype = torch.bfloat16

      torch.manual_seed(42)
      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
      lhs.retain_grad()
      rhs.retain_grad()

      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
      ref_out.sum().backward()

      torch.manual_seed(42)
      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
      rhs_xla = torch.rand(
          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
      lhs_xla.retain_grad()
      rhs_xla.retain_grad()

      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
      grad_out = torch.ones_like(out)
      torch.autograd.backward([out], [grad_out, lhs_xla, rhs_xla])

      self.assertTrue(torch.allclose(ref_out, out.cpu()))
      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
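
The tests compare against self._reference_gmm, whose body is outside this diff. Assuming the usual grouped-matmul convention (lhs of shape (m, k), rhs of shape (num_groups, k, n), group_sizes summing to m), a plain-PyTorch equivalent could be sketched as below; this is a reading of the semantics, not the helper the tests actually use.

import torch


def reference_gmm(lhs, rhs, group_sizes):
  # lhs: (m, k), rhs: (num_groups, k, n), group_sizes: (num_groups,) ints
  # summing to m. Contiguous row group g of lhs is multiplied by rhs[g],
  # giving an output of shape (m, n).
  outs = []
  start = 0
  for g, size in enumerate(group_sizes.tolist()):
    outs.append(lhs[start:start + size] @ rhs[g])
    start += size
  return torch.cat(outs, dim=0)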
16 changes: 16 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -825,6 +825,22 @@ def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
  return grad_lhs, grad_rhs


class GMM(torch.autograd.Function):

  @staticmethod
  def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
    ctx.save_for_backward(lhs, rhs, group_sizes)
    ctx.tiling = tiling
    return gmm(lhs, rhs, group_sizes, tiling)

  @staticmethod
  def backward(ctx, grad_output):
    lhs, rhs, group_sizes = ctx.saved_tensors
    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes,
                                      ctx.tiling)
    return grad_lhs, grad_rhs, None, None


def non_xla_attetion(q, k, v, attention_type):
  # This will be called when dynamo use fake tensor to construct the fake output.
  # We need to make sure output tensor's shape is correct.
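
Since the kernel is now exposed as a torch.autograd.Function, backward returns gradients for lhs and rhs and None for the non-differentiable group_sizes and tiling arguments, and the op composes with ordinary module code. A hypothetical sketch of such usage (the GroupedLinear layer below is illustrative only and not part of this commit):

import torch
import torch.nn as nn
from torch_xla.experimental.custom_kernel import GMM


class GroupedLinear(nn.Module):
  """Applies a per-group weight to contiguous row groups of the input."""

  def __init__(self, num_groups, in_features, out_features):
    super().__init__()
    weight = torch.randn(num_groups, in_features, out_features) * 0.02
    self.weight = nn.Parameter(weight.to(torch.bfloat16))

  def forward(self, x, group_sizes):
    # Gradients for both x and self.weight flow through GMM.backward.
    return GMM.apply(x, self.weight, group_sizes)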
