Support logits_soft_cap parameter in paged_attention (#7704)
wonjoolee95 committed Jul 17, 2024
1 parent 44f88a9 · commit 2a20d1e
Showing 2 changed files with 83 additions and 6 deletions.
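
The commit threads a new optional `attn_logits_soft_cap` argument from the PyTorch/XLA `paged_attention` wrapper through to the JAX Pallas paged-attention kernel on TPU. A logits soft cap bounds the pre-softmax attention scores smoothly rather than clipping them. The snippet below is only a sketch of that idea in plain PyTorch, assuming the usual tanh-style cap; it is not the kernel's code, and `soft_cap_logits` is a hypothetical helper name.

from typing import Optional

import torch

def soft_cap_logits(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # cap=None mirrors the new parameter's default: no capping is applied.
    if cap is None:
        return logits
    # Smoothly squashes the logits into (-cap, cap) while keeping gradients finite.
    return cap * torch.tanh(logits / cap)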
70 changes: 70 additions & 0 deletions test/test_pallas.py
@@ -691,6 +691,76 @@ def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
               atol=1e-5,
               rtol=1e-5))

+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_paged_attention_wrapper_with_attn_logits_soft_cap(self):
+    # TODO: enable checking TPU accelerator types.
+    from torch_xla.experimental.custom_kernel import paged_attention
+    from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention
+
+    max_kv_len = 2048
+    block_size = 512
+    page_size = 64
+    num_kv_heads = 8
+    q_kv_head_ratio = 8
+    head_dim = 256
+    dtype = torch.float32
+    seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
+
+    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
+        seq_lens,
+        page_size,
+        max_kv_len,
+        num_kv_heads,
+        num_kv_heads * q_kv_head_ratio,
+        head_dim,
+    )
+
+    q_xla = q.to("xla")
+    k_pages_xla = k_pages.to("xla")
+    v_pages_xla = v_pages.to("xla")
+    seq_lens_xla = seq_lens.to("xla")
+    page_indices_xla = page_indices.to("xla")
+
+    outputs = []
+    for attn_logits_soft_cap in [1.0, None]:
+      outputs.append(
+          paged_attention(
+              q_xla,
+              k_pages_xla,
+              v_pages_xla,
+              seq_lens_xla,
+              page_indices_xla,
+              pages_per_compute_block=block_size // page_size,
+              attn_logits_soft_cap=attn_logits_soft_cap))
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    seq_lens_jax = jnp.array(seq_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    expected_outputs = []
+    for attn_logits_soft_cap in [1.0, None]:
+      expected_outputs.append(
+          torch.from_numpy(
+              np.array(
+                  jax_paged_attention(
+                      q_jax,
+                      k_pages_jax,
+                      v_pages_jax,
+                      seq_lens_jax,
+                      page_indices_jax,
+                      pages_per_compute_block=block_size // page_size,
+                      attn_logits_soft_cap=attn_logits_soft_cap))))
+
+    for output, expected_output in zip(outputs, expected_outputs):
+      self.assertTrue(
+          torch.allclose(
+              output.cpu()[seq_lens > 0],
+              expected_output.cpu()[seq_lens > 0],
+              atol=1e-5,
+              rtol=1e-5))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   def test_flash_attention_wrapper_segment_ids_1(self):
19 changes: 13 additions & 6 deletions torch_xla/experimental/custom_kernel.py
@@ -439,7 +439,8 @@ def paged_attention(q,
                     lengths,
                     page_indices,
                     pages_per_compute_block,
-                    megacore_mode: str = None):
+                    megacore_mode: str = None,
+                    attn_logits_soft_cap: float = None):
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   jax_import_guard()
@@ -458,7 +459,10 @@
       page_indices,
       pages_per_compute_block=pages_per_compute_block,
       megacore_mode=megacore_mode,
-      static_argnames=["pages_per_compute_block", "megacore_mode"],
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      static_argnames=[
+          "pages_per_compute_block", "megacore_mode", "attn_logits_soft_cap"
+      ],
   )

   batch_size, num_heads, head_dim = q.shape
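
A note on the hunk above, not part of the committed diff: `attn_logits_soft_cap` is a plain Python float or None rather than a traced tensor, so it is listed in `static_argnames` and baked into the traced Pallas kernel as a compile-time constant; each distinct value produces its own trace. The sketch below shows the same idea with stock `jax.jit`; `capped_scores` and `soft_cap` are illustrative names, not part of the kernel.

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=["soft_cap"])
def capped_scores(q, k, soft_cap=None):
    logits = q @ k.T
    # Because soft_cap is static, this Python-level branch is resolved at trace time.
    if soft_cap is not None:
        logits = soft_cap * jnp.tanh(logits / soft_cap)
    return jax.nn.softmax(logits, axis=-1)

# Each distinct soft_cap value (including None) compiles its own specialization.
probs = capped_scores(jnp.ones((4, 8)), jnp.ones((16, 8)), soft_cap=30.0)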
@@ -874,7 +878,7 @@ def flash_attention_non_xla(q: torch.Tensor,


 XLA_LIB.define(
-    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block, str megacore_mode=None) -> Tensor",
+    "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block, str megacore_mode=None, float attn_logits_soft_cap=None) -> Tensor",
 )
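
With the schema above now accepting `float attn_logits_soft_cap=None`, callers that dispatch through the registered op can pass the cap as a keyword argument. A minimal sketch, assuming the `xla` op namespace used by `XLA_LIB` and XLA tensors prepared as in the test (`q_xla`, `k_pages_xla`, etc. are placeholders):

# Hypothetical invocation through the registered op; tensor setup is omitted.
output = torch.ops.xla.paged_attention(
    q_xla,
    k_pages_xla,
    v_pages_xla,
    seq_lens_xla,
    page_indices_xla,
    pages_per_compute_block=8,
    megacore_mode=None,
    attn_logits_soft_cap=30.0,  # None (the default) disables capping
)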


Expand All @@ -885,9 +889,11 @@ def paged_attention_xla(q: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: str = None):
megacore_mode: str = None,
attn_logits_soft_cap: float = None):
return paged_attention(q, k_pages, v_pages, lengths, page_indices,
pages_per_compute_block, megacore_mode)
pages_per_compute_block, megacore_mode,
attn_logits_soft_cap)


@impl(XLA_LIB, "paged_attention", "CompositeExplicitAutograd")
Expand All @@ -897,7 +903,8 @@ def paged_attention_non_xla(q: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: str = None):
megacore_mode: str = None,
attn_logits_soft_cap: float = None):
return non_xla_attetion(q, k_pages, v_pages, "paged")


Expand Down
