Support megacore_mode in paged_attention (#7060)
wonjoolee95 committed May 14, 2024
1 parent f26c35c commit cbb9e21
Showing 2 changed files with 96 additions and 11 deletions.
test/test_pallas.py (69 additions, 0 deletions)
@@ -558,6 +558,75 @@ def test_paged_attention_wrapper(self):
atol=1e-5,
rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper_with_megacore_modes(self):
from torch_xla.experimental.custom_kernel import paged_attention
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

max_kv_len = 2048
block_size = 512
page_size = 64
num_kv_heads = 8
q_kv_head_ratio = 8
head_dim = 256
dtype = torch.float32
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
seq_lens,
page_size,
max_kv_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
)

q_xla = q.to("xla")
k_pages_xla = k_pages.to("xla")
v_pages_xla = v_pages.to("xla")
seq_lens_xla = seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")

outputs = []
for megacore_mode in ['kv_head', 'batch', None]:
outputs.append(
paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
seq_lens_xla,
page_indices_xla,
pages_per_compute_block=block_size // page_size,
megacore_mode=megacore_mode))

q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
expected_outputs = []
for megacore_mode in ['kv_head', 'batch', None]:
expected_outputs.append(
torch.from_numpy(
np.array(
jax_paged_attention(
q_jax,
k_pages_jax,
v_pages_jax,
seq_lens_jax,
page_indices_jax,
pages_per_compute_block=block_size // page_size,
megacore_mode=megacore_mode))))

for output, expected_output in zip(outputs, expected_outputs):
self.assertTrue(
torch.allclose(
output.cpu()[seq_lens > 0],
expected_output.cpu()[seq_lens > 0],
atol=1e-5,
rtol=1e-5))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_paged_attention_wrapper_with_dynamo(self):
torch_xla/experimental/custom_kernel.py (27 additions, 11 deletions)
@@ -432,13 +432,22 @@ def flash_attention(
sm_scale, partition_spec, mesh)


-def paged_attention(q, k_pages, v_pages, lengths, page_indices,
-                    pages_per_compute_block):
+def paged_attention(q,
+                    k_pages,
+                    v_pages,
+                    lengths,
+                    page_indices,
+                    pages_per_compute_block,
+                    megacore_mode: str = None):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention

+  assert megacore_mode in [
+      "kv_head", "batch", None
+  ], "megacore_mode must be one of ['kv_head', 'batch', None]."
+
payload, tensor_args = trace_pallas(
paged_attention,
q,
@@ -447,7 +456,8 @@ def paged_attention(q, k_pages, v_pages, lengths, page_indices,
lengths,
page_indices,
pages_per_compute_block=pages_per_compute_block,
-      static_argnames=["pages_per_compute_block"],
+      megacore_mode=megacore_mode,
+      static_argnames=["pages_per_compute_block", "megacore_mode"],
)

batch_size, num_heads, head_dim = q.shape
@@ -512,22 +522,28 @@ def flash_attention_non_xla(q: torch.Tensor,


XLA_LIB.define(
"paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block) -> Tensor",
"paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block, str megacore_mode=None) -> Tensor",
)


@impl(XLA_LIB, "paged_attention", "XLA")
-def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
-                        v_pages: torch.Tensor, lengths: torch.Tensor,
+def paged_attention_xla(q: torch.Tensor,
+                        k_pages: torch.Tensor,
+                        v_pages: torch.Tensor,
+                        lengths: torch.Tensor,
                         page_indices: torch.Tensor,
-                        pages_per_compute_block: int):
+                        pages_per_compute_block: int,
+                        megacore_mode: str = None):
return paged_attention(q, k_pages, v_pages, lengths, page_indices,
-                         pages_per_compute_block)
+                         pages_per_compute_block, megacore_mode)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
-def paged_attention_non_xla(q: torch.Tensor, k_pages: torch.Tensor,
-                            v_pages: torch.Tensor, lengths: torch.Tensor,
+def paged_attention_non_xla(q: torch.Tensor,
+                            k_pages: torch.Tensor,
+                            v_pages: torch.Tensor,
+                            lengths: torch.Tensor,
                             page_indices: torch.Tensor,
-                            pages_per_compute_block: int):
+                            pages_per_compute_block: int,
+                            megacore_mode: str = None):
return non_xla_attetion(q, k_pages, v_pages, "paged")
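
For reference, a minimal usage sketch of the updated wrapper (not part of the commit). The (batch_size, num_heads, head_dim) layout of q is taken from the wrapper above; the k_pages/v_pages, lengths, and page_indices layouts and all tensor sizes below are illustrative assumptions modeled on the new test. In the underlying JAX kernel, megacore_mode appears to choose whether work is split across KV heads ('kv_head') or across the batch ('batch') between the two TensorCores of a TPU v4+ megacore, with None leaving the split disabled.

# Minimal sketch, assuming a TPU v4+ host with torch_xla installed.
import torch
import torch_xla  # registers the 'xla' device
from torch_xla.experimental.custom_kernel import paged_attention

# Illustrative sizes (assumptions, mirroring the test above).
batch_size, num_kv_heads, q_kv_head_ratio, head_dim = 4, 8, 8, 256
page_size, total_pages, pages_per_sequence = 64, 128, 32
block_size = 512  # pages_per_compute_block = block_size // page_size = 8

q = torch.randn(batch_size, num_kv_heads * q_kv_head_ratio, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
lengths = torch.randint(
    1, page_size * pages_per_sequence, (batch_size,), dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to("xla")

# megacore_mode must be 'kv_head', 'batch', or None (the default); any other
# value trips the assertion added in this commit.
output = paged_attention(
    q,
    k_pages,
    v_pages,
    lengths,
    page_indices,
    pages_per_compute_block=block_size // page_size,
    megacore_mode="kv_head")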
