diff --git a/test/test_pallas.py b/test/test_pallas.py
index f686816034f..25c487912cf 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -558,6 +558,75 @@ def test_paged_attention_wrapper(self):
             atol=1e-5,
             rtol=1e-5))
 
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_paged_attention_wrapper_with_megacore_modes(self):
+    from torch_xla.experimental.custom_kernel import paged_attention
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+
+    max_kv_len = 2048
+    block_size = 512
+    page_size = 64
+    num_kv_heads = 8
+    q_kv_head_ratio = 8
+    head_dim = 256
+    dtype = torch.float32
+    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+    )
+
+    q_xla = q.to("xla")
+    k_pages_xla = k_pages.to("xla")
+    v_pages_xla = v_pages.to("xla")
+    seq_lens_xla = seq_lens.to("xla")
+    page_indices_xla = page_indices.to("xla")
+
+    outputs = []
+    for megacore_mode in ['kv_head', 'batch', None]:
+      outputs.append(
+          paged_attention(
+              q_xla,
+              k_pages_xla,
+              v_pages_xla,
+              seq_lens_xla,
+              page_indices_xla,
+              pages_per_compute_block=block_size // page_size,
+              megacore_mode=megacore_mode))
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_outputs = []
+    for megacore_mode in ['kv_head', 'batch', None]:
+      expected_outputs.append(
+          torch.from_numpy(
+              np.array(
+                  jax_paged_attention(
+                      q_jax,
+                      k_pages_jax,
+                      v_pages_jax,
+                      seq_lens_jax,
+                      page_indices_jax,
+                      pages_per_compute_block=block_size // page_size,
+                      megacore_mode=megacore_mode))))
+
+    for output, expected_output in zip(outputs, expected_outputs):
+      self.assertTrue(
+          torch.allclose(
+              output.cpu()[seq_lens > 0],
+              expected_output.cpu()[seq_lens > 0],
+              atol=1e-5,
+              rtol=1e-5))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper_with_dynamo(self):
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 7e4f2b4c3b3..3436ff18d07 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -432,13 +432,22 @@ def flash_attention(
                                sm_scale, partition_spec, mesh)
 
 
-def paged_attention(q, k_pages, v_pages, lengths, page_indices,
-                    pages_per_compute_block):
+def paged_attention(q,
+                    k_pages,
+                    v_pages,
+                    lengths,
+                    page_indices,
+                    pages_per_compute_block,
+                    megacore_mode: str = None):
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   jax_import_guard()
   from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
 
+  assert megacore_mode in [
+      "kv_head", "batch", None
+  ], "megacore_mode must be one of ['kv_head', 'batch', None]."
+
   payload, tensor_args = trace_pallas(
       paged_attention,
       q,
@@ -447,7 +456,8 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
       lengths,
       page_indices,
       pages_per_compute_block=pages_per_compute_block,
-      static_argnames=["pages_per_compute_block"],
+      megacore_mode=megacore_mode,
+      static_argnames=["pages_per_compute_block", "megacore_mode"],
   )
 
   batch_size, num_heads, head_dim = q.shape
@@ -512,22 +522,28 @@ def flash_attention_non_xla(q: torch.Tensor,
 
 
 XLA_LIB.define(
-    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor",
+    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block, str megacore_mode=None) -> Tensor",
 )
 
 
 @impl(XLA_LIB, "paged_attention", "XLA")
-def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
-                        v_pages: torch.Tensor, lengths: torch.Tensor,
+def paged_attention_xla(q: torch.Tensor,
+                        k_pages: torch.Tensor,
+                        v_pages: torch.Tensor,
+                        lengths: torch.Tensor,
                         page_indices: torch.Tensor,
-                        pages_per_compute_block: int):
+                        pages_per_compute_block: int,
+                        megacore_mode: str = None):
   return paged_attention(q, k_pages, v_pages, lengths, page_indices,
-                         pages_per_compute_block)
+                         pages_per_compute_block, megacore_mode)
 
 
 @impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
-def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
-                            v_pages: torch.Tensor, lengths: torch.Tensor,
+def paged_attention_non_xla(q: torch.Tensor,
+                            k_pages: torch.Tensor,
+                            v_pages: torch.Tensor,
+                            lengths: torch.Tensor,
                             page_indices: torch.Tensor,
-                            pages_per_compute_block: int):
+                            pages_per_compute_block: int,
+                            megacore_mode: str = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
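Not part of the patch: a minimal usage sketch of the wrapper with the new `megacore_mode` argument, for reviewers who want to try it by hand on a TPU v4+ host. The argument names come from the diff above; the kv-page layout `[num_kv_heads, num_pages, page_size, head_dim]` and all concrete sizes and random inputs are illustrative assumptions, not values taken from the PR.

```python
# Hedged usage sketch; requires a TPU v4+ device, like the test above.
import torch
import torch_xla  # registers the 'xla' device
from torch_xla.experimental.custom_kernel import paged_attention

batch_size = 4
num_kv_heads, q_kv_head_ratio, head_dim = 8, 8, 256
page_size, total_pages, pages_per_seq = 64, 1024, 32
block_size = 512  # so pages_per_compute_block = block_size // page_size = 8

# Assumed shapes: q is [batch, num_heads, head_dim]; k_pages/v_pages are
# [num_kv_heads, num_pages, page_size, head_dim]; lengths and page_indices are int32.
q = torch.randn(batch_size, num_kv_heads * q_kv_head_ratio, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
lengths = torch.randint(
    1, page_size * pages_per_seq, (batch_size,), dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_pages, (batch_size, pages_per_seq), dtype=torch.int32).to("xla")

# megacore_mode must be one of 'kv_head', 'batch', or None (the default),
# matching the assert added in custom_kernel.py.
output = paged_attention(
    q,
    k_pages,
    v_pages,
    lengths,
    page_indices,
    pages_per_compute_block=block_size // page_size,
    megacore_mode="kv_head")
```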