diff --git a/test/test_gmm.py b/test/test_gmm.py
index 08483b0dd84..141c66ca342 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -23,23 +23,15 @@ class MegabloxTest(unittest.TestCase):
 
-  def _reference_gmm(
-      self,
-      lhs: np.array,
-      rhs: np.array,
-      group_sizes: np.array,
-      preferred_element_type: np.dtype = np.float32,
-  ) -> np.array:
-
+  def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
+                     group_sizes: torch.Tensor) -> torch.Tensor:
     start = 0
     out = []
     for i, size in enumerate(group_sizes):
-      result = np.dot(lhs[start:start + size, :], rhs[i, :, :])
-
-      result = result.astype(preferred_element_type)
+      result = lhs[start:start + size, :] @ rhs[i, :, :]
       out.append(result)
       start += group_sizes[i]
-    return np.array(np.concatenate(out, axis=0))
+    return torch.cat(out)
 
   def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
     # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
@@ -57,15 +49,6 @@ def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
     starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
     return torch.from_numpy(ends - starts).to(torch.int32)
 
-  def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
-                  out_dtype: torch.dtype) -> tuple[float, float]:
-    if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or
-        out_dtype == torch.bfloat16):
-      return 1e-3, 1e-2  # atol, rtol
-    return 1e-4, 1e-2  # atol, rtol
-
-  LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
-
   def _init_test_cases(self):
     self.tests_cases = []
     self.tests_cases.append({
@@ -100,6 +83,7 @@ def _init_test_cases(self):
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm(self):
     met.clear_all()
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
 
     self._init_test_cases()
     for test_case in self.tests_cases:
@@ -108,21 +92,39 @@ def test_gmm(self):
       m = test_case['m']
       n = test_case['n']
       lhs_dtype = rhs_dtype = test_case['dtype']
-      out_dtype = torch.float32
 
-      lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
-      group_sizes = self._group_sizes_strategy(
-          m=m, num_groups=num_groups).to('xla')
-      out = gmm(lhs, rhs, group_sizes)
+      lhs = torch.rand(m, k, dtype=lhs_dtype)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+
+    # Make sure gmm doesn't fallback.
+    self.assertNotIn("aten::", met.short_metrics_report())
+    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_bf16(self):
+    met.clear_all()
+
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      lhs = torch.rand(m, k, dtype=lhs_dtype)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
 
-      ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
-                                    rhs.cpu().float().numpy(),
-                                    group_sizes.cpu().numpy())
+      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
 
-      atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
-      np.testing.assert_allclose(
-          ref_out, np.array(out.cpu()), rtol=rtol, atol=atol)
+      self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
@@ -183,6 +185,7 @@ def test_make_group_metadata(self):
           group_sizes=torch.tensor(test_grid['group_sizes']).to("xla"),
           m=test_grid['m'],
           tm=test_grid['tm'],
+          visit_empty_groups=True,
       )
 
       for i in range(len(jax_meta)):
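The test-side change above replaces the NumPy reference and the dtype-dependent _tolerances helper with a pure-torch reference that is compared directly against the Pallas kernel, and adds a dedicated bf16 test. Below is a minimal standalone sketch of that flow, not part of the patch: it assumes a TPU host, that gmm is imported from torch_xla.experimental.custom_kernel (the import is outside this hunk), and that the shapes and group sizes are illustrative.

# Sketch only -- mirrors the new test_gmm_bf16 flow above, not part of the patch.
import torch
import torch_xla  # needed so tensors can be moved to the 'xla' device
from torch_xla.experimental.custom_kernel import gmm  # assumed import path


def reference_gmm(lhs: torch.Tensor, rhs: torch.Tensor,
                  group_sizes: torch.Tensor) -> torch.Tensor:
  # Grouped matmul: rows of lhs are split along m according to group_sizes,
  # and the i-th slice is multiplied by rhs[i].
  start, out = 0, []
  for i, size in enumerate(group_sizes):
    out.append(lhs[start:start + size, :] @ rhs[i, :, :])
    start += size
  return torch.cat(out)


m, k, n, num_groups = 128, 128, 128, 4  # illustrative shapes
lhs = torch.rand(m, k, dtype=torch.bfloat16)
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16)
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32)

ref = reference_gmm(lhs, rhs, group_sizes)
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
assert torch.allclose(ref, out.cpu())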
+ self.assertNotIn("aten::", met.short_metrics_report()) + jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT) + + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + def test_gmm_bf16(self): + met.clear_all() + + self._init_test_cases() + for test_case in self.tests_cases: + num_groups = test_case['num_groups'] + k = test_case['k'] + m = test_case['m'] + n = test_case['n'] + lhs_dtype = rhs_dtype = torch.bfloat16 + + lhs = torch.rand(m, k, dtype=lhs_dtype) + rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype) + group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) + ref_out = self._reference_gmm(lhs, rhs, group_sizes) - ref_out = self._reference_gmm(lhs.cpu().float().numpy(), - rhs.cpu().float().numpy(), - group_sizes.cpu().numpy()) + out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) - atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype) - np.testing.assert_allclose( - ref_out, np.array(out.cpu()), rtol=rtol, atol=atol) + self.assertTrue(torch.allclose(ref_out, out.cpu())) # Make sure gmm doesn't fallback. self.assertNotIn("aten::", met.short_metrics_report()) @@ -183,6 +185,7 @@ def test_make_group_metadata(self): group_sizes=torch.tensor(test_grid['group_sizes']).to("xla"), m=test_grid['m'], tm=test_grid['tm'], + visit_empty_groups=True, ) for i in range(len(jax_meta)): diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 2f528bafae5..c28870e1812 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -56,38 +56,43 @@ def jax_import_guard(): torch_xla._XLAC._init_computation_client() -def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct": +def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype": # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() - import jax import jax.numpy as jnp - def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: - if dtype == torch.float32: - if _XLA_USE_BF16: - return jnp.bfloat16 - return jnp.float32 - elif dtype == torch.float64: - if _XLA_USE_BF16: - return jnp.bfloat16 - return jnp.float64 - elif dtype == torch.float16: - return jnp.float16 - elif dtype == torch.bfloat16: + if dtype == torch.float32: + if _XLA_USE_BF16: return jnp.bfloat16 - elif dtype == torch.int32: - return jnp.int32 - elif dtype == torch.int64: - return jnp.int64 - elif dtype == torch.int16: - return jnp.int16 - elif dtype == torch.int8: - return jnp.int8 - elif dtype == torch.uint8: - return jnp.uint8 - else: - raise ValueError(f"Unsupported dtype: {dtype}") + return jnp.float32 + elif dtype == torch.float64: + if _XLA_USE_BF16: + return jnp.bfloat16 + return jnp.float64 + elif dtype == torch.float16: + return jnp.float16 + elif dtype == torch.bfloat16: + return jnp.bfloat16 + elif dtype == torch.int32: + return jnp.int32 + elif dtype == torch.int64: + return jnp.int64 + elif dtype == torch.int16: + return jnp.int16 + elif dtype == torch.int8: + return jnp.int8 + elif dtype == torch.uint8: + return jnp.uint8 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct": + # Import JAX within the function such that we don't need to call the jax_import_guard() + # in the global scope which could cause problems for xmp.spawn. 
@@ -518,7 +523,7 @@ def _make_group_metadata(
     group_sizes: torch.Tensor,
     m: int,
     tm: int,
-    visit_empty_groups: bool = True,
+    visit_empty_groups: bool,
 ) -> Any:
   """Create the metadata needed for grouped matmul computation.
 
@@ -734,31 +739,31 @@ def gmm(
   m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
   tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
 
+  preferred_element_type = lhs.dtype
+
   payload, _ = trace_pallas(
       gmm,
       lhs,
       rhs,
       group_sizes,
-      static_argnames=["tiling"],
+      static_argnames=["tiling", "preferred_element_type"],
+      preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
       tiling=(tm, tk, tn))
 
-  # Create the metadata we need for computation.
-  # TODO (alanwaketan): The following assuumes groups_sizes is a cpu tensor.
-  # That means we need to materialize this input in order to use this gmm
-  # kernel, and that will introduce graph breaks in the computation.
+  # Create the metadata we need for computation, and that's why we need to
+  # separate the tracing and execution parts.
   group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
      group_sizes=group_sizes,
      m=m,
      tm=tm,
+      visit_empty_groups=False,
   )
-  group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")
+  group_offset_torch = torch.tensor([0], dtype=torch.int32).to(lhs.device)
 
   return torch_xla._XLAC._xla_tpu_custom_call([
-      num_tiles.to("xla"),
-      group_offsets.to("xla"),
-      group_ids.to("xla"),
-      m_tile_ids.to("xla"), group_offset_torch, lhs, rhs
-  ], payload, [torch.Size([m, n])], [lhs.dtype])[0]
+      num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch, lhs,
+      rhs
+  ], payload, [torch.Size([m, n])], [preferred_element_type])[0]
 
 
 def non_xla_attetion(q, k, v, attention_type):
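For callers, the kernel-side changes boil down to: convert_torch_dtype_to_jax is promoted to a module-level helper, gmm forwards the torch dtype to the Pallas kernel as preferred_element_type, group metadata is built with visit_empty_groups=False (the unit test passes True to match the JAX reference), and gmm no longer calls .to("xla") on each metadata tensor. A rough sketch under the same assumptions as the earlier one (TPU host, assumed import paths, illustrative shapes; not part of the patch):

# Sketch only -- illustrates the dtype plumbing after this change.
import torch
from torch_xla.experimental.custom_kernel import convert_torch_dtype_to_jax, gmm

# The promoted helper maps torch dtypes to jnp dtypes (bfloat16 stays bfloat16;
# float32 maps to bfloat16 only when _XLA_USE_BF16 is set).
jnp_dtype = convert_torch_dtype_to_jax(torch.bfloat16)  # jnp.bfloat16

# With bf16 inputs, gmm traces the Pallas kernel with
# preferred_element_type=jnp.bfloat16, so the computed value matches the bf16
# output dtype the custom call declares via lhs.dtype.
lhs = torch.rand(128, 128, dtype=torch.bfloat16).to("xla")
rhs = torch.rand(4, 128, 128, dtype=torch.bfloat16).to("xla")
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32).to("xla")
out = gmm(lhs, rhs, group_sizes)
assert out.dtype == torch.bfloat16 and out.shape == (128, 128)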