diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index d315c94a0b3..ed015f73374 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -129,8 +129,10 @@ def trace_pallas(kernel: Callable,
     global trace_pallas_arg_to_payload
     # implcit assumption here that everything in kwargs is hashable and not a tensor,
     # which is true for the gmm and tgmm.
-    hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
-                repr(sorted(kwargs.items())).encode())
+    hash_key = (jax.config.jax_default_matmul_precision, kernel, static_argnums,
+                tuple(static_argnames)
+                if static_argnames is not None else static_argnames,
+                tuple(jax_args), repr(sorted(kwargs.items())).encode())
     if hash_key in trace_pallas_arg_to_payload:
       torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
       return trace_pallas_arg_to_payload[hash_key], tensor_args
@@ -287,7 +289,9 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
         min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
         False,
-        static_argnums=range(5, 13))
+        static_argnums=range(5, 13),
+        use_cache=True,
+    )
 
     args = [q, k, v]
     if ab is not None:
@@ -386,7 +390,9 @@ def backward(ctx, grad_output):
           static_argnames=[
               "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
               "mask_value", "debug"
-          ])
+          ],
+          use_cache=True,
+      )
 
       args = [q, k, v]
       if ab is not None:
@@ -435,7 +441,8 @@ def backward(ctx, grad_output):
           static_argnames=[
               "block_q_major", "block_k_major", "block_k", "block_q",
               "sm_scale", "causal", "mask_value", "debug"
-          ])
+          ],
+          use_cache=True)
 
       grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                    [k.shape, v.shape],
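
Note on the cache-key change: `trace_pallas` memoizes lowered payloads, and the lowering of a kernel depends on `jax.config.jax_default_matmul_precision`, so a key that ignores that flag could serve a payload traced under a stale precision setting. Below is a minimal sketch of the keying idea only; `cached_trace` and `_payload_cache` are hypothetical stand-ins for the real `trace_pallas` / `trace_pallas_arg_to_payload` machinery.

```python
import jax

# Hypothetical stand-in for trace_pallas_arg_to_payload, used only to
# illustrate the keying scheme; not the torch_xla implementation.
_payload_cache = {}


def cached_trace(kernel, static_args):
  # The global precision flag is read when the key is built, so flipping it via
  # jax.config.update("jax_default_matmul_precision", ...) produces a fresh
  # cache entry instead of reusing a payload lowered under the old setting.
  key = (jax.config.jax_default_matmul_precision, kernel, tuple(static_args))
  if key not in _payload_cache:
    # Placeholder for the expensive trace/lowering step performed by trace_pallas.
    _payload_cache[key] = (kernel.__name__, key[0], tuple(static_args))
  return _payload_cache[key]
```

With the old key, changing `jax_default_matmul_precision` between calls would have returned the payload traced under the previous setting; including the flag (and guarding `tuple(static_argnames)` against `None`) keeps the memoization sound for the flash attention call sites that now pass `use_cache=True`.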