llama: less E kernels (#16517)

This commit is contained in:
qazal 2026-06-12 18:49:25 +08:00 committed by GitHub
commit 4d34590b7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 14 additions and 2 deletions

View file

@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
return None, None, dq.uop, dk.uop, dv.uop
return grad
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
assert attn_mask is None, "attn_mask not supported"
assert is_causal, "only causal attention supported"
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
arch = Device[single_device].renderer.target.arch
attn = _sharded_empty_like(xq, axis=shard_axis)
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)