[Pallas] Make gmm support bf16 (#7133)
Summary:
This pull request:
1. makes gmm support bf16,
2. stops visiting empty groups (visit_empty_groups=False) when building gmm's group metadata, and
3. rewrites the reference gmm in the tests to be torch-native instead of numpy-based.

Test Plan:
python test/test_gmm.py
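
For context, here is a minimal sketch of the bf16 path that the new test_gmm_bf16 case exercises. It assumes a TPU-backed XLA device, and the sizes and group split are illustrative values rather than the test's actual cases (which come from _init_test_cases).

import torch
import torch_xla  # makes the "xla" device available
from torch_xla.experimental.custom_kernel import gmm

m, k, n, num_groups = 128, 128, 128, 4
lhs = torch.rand(m, k, dtype=torch.bfloat16)              # (m, k)
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16)  # (num_groups, k, n)
group_sizes = torch.tensor([32, 16, 48, 32], dtype=torch.int32)  # must sum to m

# Torch-native reference, mirroring the updated _reference_gmm below.
out_ref, start = [], 0
for i, size in enumerate(group_sizes):
  out_ref.append(lhs[start:start + size, :] @ rhs[i, :, :])
  start += size
out_ref = torch.cat(out_ref)

# The Pallas kernel runs on the XLA device; inputs are moved there explicitly.
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
assert torch.allclose(out_ref, out.cpu())

Keeping the reference in torch (change 3) avoids the old float()/numpy round trip, so the bf16 result can be compared directly with torch.allclose, as the updated tests below do.
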
alanwaketan authored May 29, 2024
1 parent 8531d1c commit fb37312
Showing 2 changed files with 79 additions and 71 deletions.
test/test_gmm.py: 69 changes (36 additions, 33 deletions)
@@ -23,23 +23,15 @@

class MegabloxTest(unittest.TestCase):

def _reference_gmm(
self,
lhs: np.array,
rhs: np.array,
group_sizes: np.array,
preferred_element_type: np.dtype = np.float32,
) -> np.array:

def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
group_sizes: torch.Tensor) -> np.array:
start = 0
out = []
for i, size in enumerate(group_sizes):
result = np.dot(lhs[start:start + size, :], rhs[i, :, :])

result = result.astype(preferred_element_type)
result = lhs[start:start + size, :] @ rhs[i, :, :]
out.append(result)
start += group_sizes[i]
return np.array(np.concatenate(out, axis=0))
return torch.cat(out)

def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
# Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
@@ -57,15 +49,6 @@ def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
return torch.from_numpy(ends - starts).to(torch.int32)

def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
out_dtype: torch.dtype) -> tuple[float, float]:
if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or
out_dtype == torch.bfloat16):
return 1e-3, 1e-2 # atol, rtol
return 1e-4, 1e-2 # atol, rtol

LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]

def _init_test_cases(self):
self.tests_cases = []
self.tests_cases.append({
@@ -100,6 +83,7 @@ def _init_test_cases(self):
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm(self):
met.clear_all()
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)

self._init_test_cases()
for test_case in self.tests_cases:
@@ -108,21 +92,39 @@ def test_gmm(self):
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']
out_dtype = torch.float32

lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(
m=m, num_groups=num_groups).to('xla')
out = gmm(lhs, rhs, group_sizes)
lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_bf16(self):
met.clear_all()

self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = torch.bfloat16

lhs = torch.rand(m, k, dtype=lhs_dtype)
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
ref_out = self._reference_gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
group_sizes.cpu().numpy())
out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))

atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
np.testing.assert_allclose(
ref_out, np.array(out.cpu()), rtol=rtol, atol=atol)
self.assertTrue(torch.allclose(ref_out, out.cpu()))

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
@@ -183,6 +185,7 @@ def test_make_group_metadata(self):
group_sizes=torch.tensor(test_grid['group_sizes']).to("xla"),
m=test_grid['m'],
tm=test_grid['tm'],
visit_empty_groups=True,
)

for i in range(len(jax_meta)):
torch_xla/experimental/custom_kernel.py: 81 changes (43 additions, 38 deletions)
@@ -56,38 +56,43 @@ def jax_import_guard():
torch_xla._XLAC._init_computation_client()


def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
if _XLA_USE_BF16:
return jnp.bfloat16
return jnp.float32
elif dtype == torch.float64:
if _XLA_USE_BF16:
return jnp.bfloat16
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
if dtype == torch.float32:
if _XLA_USE_BF16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")
return jnp.float32
elif dtype == torch.float64:
if _XLA_USE_BF16:
return jnp.bfloat16
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")


def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax

return jax.ShapeDtypeStruct(tensor.shape,
convert_torch_dtype_to_jax(tensor.dtype))
@@ -518,7 +523,7 @@ def _make_group_metadata(
group_sizes: torch.Tensor,
m: int,
tm: int,
visit_empty_groups: bool = True,
visit_empty_groups: bool,
) -> Any:
"""Create the metadata needed for grouped matmul computation.
@@ -734,31 +739,31 @@ def gmm(

m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
preferred_element_type = lhs.dtype

payload, _ = trace_pallas(
gmm,
lhs,
rhs,
group_sizes,
static_argnames=["tiling"],
static_argnames=["tiling", "preferred_element_type"],
preferred_element_type=convert_torch_dtype_to_jax(preferred_element_type),
tiling=(tm, tk, tn))

# Create the metadata we need for computation.
# TODO (alanwaketan): The following assumes group_sizes is a cpu tensor.
# That means we need to materialize this input in order to use this gmm
# kernel, and that will introduce graph breaks in the computation.
# Create the metadata we need for computation, and that's why we need to separate
# the tracing and execution parts.
group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
group_sizes=group_sizes,
m=m,
tm=tm,
visit_empty_groups=False,
)
group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")
group_offset_torch = torch.tensor([0], dtype=torch.int32).to(lhs.device)

return torch_xla._XLAC._xla_tpu_custom_call([
num_tiles.to("xla"),
group_offsets.to("xla"),
group_ids.to("xla"),
m_tile_ids.to("xla"), group_offset_torch, lhs, rhs
], payload, [torch.Size([m, n])], [lhs.dtype])[0]
num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch, lhs,
rhs
], payload, [torch.Size([m, n])], [preferred_element_type])[0]


def non_xla_attetion(q, k, v, attention_type):
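
A side effect of the refactor above is that convert_torch_dtype_to_jax is now a module-level helper rather than a closure inside to_jax_shape_dtype_struct, which is what lets gmm pass preferred_element_type through to the Pallas kernel. A hypothetical spot-check, not part of this commit (the import path is the file shown above; the environment must be one where jax_import_guard() can initialize the XLA computation client):

import torch
from torch_xla.experimental.custom_kernel import convert_torch_dtype_to_jax

# The helper runs jax_import_guard() itself, so JAX is only imported after the
# XLA computation client has been initialized.
jax_dtype = convert_torch_dtype_to_jax(torch.bfloat16)

import jax.numpy as jnp
assert jax_dtype is jnp.bfloat16  # torch.bfloat16 maps to jnp.bfloat16
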
