From 5bbe5c82bdf5b72577039744d5b21d73ac59d197 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 8 May 2024 12:09:22 -0700 Subject: [PATCH] [Pallas] Improve segment_ids API UX (#7037) Summary: Adds comment for the shape of the segment_ids. Test Plan: Skip CI. --- torch_xla/experimental/custom_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 43e1ab1c128..7e4f2b4c3b3 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -421,8 +421,8 @@ def flash_attention( k, # [batch_size, num_heads, kv_seq_len, d_model] v, # [batch_size, num_heads, kv_seq_len, d_model] causal=False, - q_segment_ids=None, - kv_segment_ids=None, + q_segment_ids=None, # [batch_size, q_seq_len] + kv_segment_ids=None, # [batch_size, kv_seq_len] sm_scale=1.0, *, partition_spec=None,