From aeed61a94cbe32099ee4e896ca6afd37d4ac9456 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 29 May 2024 22:20:42 -0700
Subject: [PATCH] [Pallas] Introduce GMM(torch.autograd.Function) (#7152)

Summary:
This pull request makes GMM a torch.autograd.Function so that we can use
torch.autograd.backward instead of manually calling gmm_backward.

Test Plan:
python test/test_gmm.py
---
 test/test_gmm.py                        | 77 ++++++++++++++++++++++++-
 torch_xla/experimental/custom_kernel.py | 16 +++++
 2 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/test/test_gmm.py b/test/test_gmm.py
index bf8fdebe24c..b594a85c065 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -374,6 +374,81 @@ def test_gmm_backward(self):
 
     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward_2(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      torch.manual_seed(42)
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      torch.manual_seed(42)
+      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(
+          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      lhs_xla.retain_grad()
+      rhs_xla.retain_grad()
+
+      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
+      out.sum().backward()
+
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
+
+      # Make sure gmm doesn't fallback.
+      self.assertNotIn("aten::", met.short_metrics_report())
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward_3(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      torch.manual_seed(42)
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      torch.manual_seed(42)
+      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
+      rhs_xla = torch.rand(
+          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
+      lhs_xla.retain_grad()
+      rhs_xla.retain_grad()
+
+      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
+      grad_out = torch.ones_like(out)
+      torch.autograd.backward([out], [grad_out])
+
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))
+
+      # Make sure gmm doesn't fallback.
+      self.assertNotIn("aten::", met.short_metrics_report())
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 9bc32a3fc3f..1a8a8cd3852 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -825,6 +825,22 @@ def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
   return grad_lhs, grad_rhs
 
 
+class GMM(torch.autograd.Function):
+
+  @staticmethod
+  def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
+    ctx.save_for_backward(lhs, rhs, group_sizes)
+    ctx.tiling = tiling
+    return gmm(lhs, rhs, group_sizes, tiling)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    lhs, rhs, group_sizes = ctx.saved_tensors
+    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes,
+                                      ctx.tiling)
+    return grad_lhs, grad_rhs, None, None
+
+
 def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
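
Reviewer note (not part of the patch): below is a minimal usage sketch of the new
GMM autograd Function. It is illustrative only; it assumes a TPU host with
torch_xla installed, and the shapes, dtype, and group sizes are made up for the
example rather than taken from the test cases above.

# Illustrative sketch: run GMM forward and backward through autograd.
# Assumes a TPU device with torch_xla installed; sizes below are arbitrary.
import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import GMM

m, k, n, num_groups = 128, 128, 128, 4

lhs = torch.rand(m, k, dtype=torch.bfloat16, requires_grad=True).to("xla")
rhs = torch.rand(
    num_groups, k, n, dtype=torch.bfloat16, requires_grad=True).to("xla")
lhs.retain_grad()  # .to("xla") returns non-leaf tensors, so keep their grads
rhs.retain_grad()

# One row count per group; the counts must sum to m (values here are arbitrary).
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32).to("xla")

out = GMM.apply(lhs, rhs, group_sizes)  # forward runs the Pallas gmm kernel
out.sum().backward()                    # backward dispatches to gmm_backward

print(out.shape, lhs.grad.shape, rhs.grad.shape)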