llama: less E kernels (#16517)

This commit is contained in:
qazal 2026-06-12 18:49:25 +08:00 committed by GitHub
commit 4d34590b7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 14 additions and 2 deletions

View file

@ -186,7 +186,7 @@ class FlatTransformer:
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
if getenv("HK_FLASH_ATTENTION"):
from extra.thunder.amd.fa import flash_attention
attn, *save = flash_attention(xq, xk, xv, is_causal=True)
attn, *save = flash_attention(xq, xk, xv, is_causal=True, write_flat=True)
saves.extend(save)
else:
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

View file

@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
return None, None, dq.uop, dk.uop, dv.uop
return grad
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
assert attn_mask is None, "attn_mask not supported"
assert is_causal, "only causal attention supported"
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
arch = Device[single_device].renderer.target.arch
attn = _sharded_empty_like(xq, axis=shard_axis)
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)

View file

@ -277,6 +277,16 @@ class TestCustomKernel(unittest.TestCase):
out.realize()
self.assertEqual(GlobalCounters.kernel_count, 5)
def test_simple_reshape(self):
a = Tensor.ones(2,3,4).realize()
b = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_add_one_kernel)[0]
b2 = b.reshape(2,12)
c = Tensor.custom_kernel(Tensor.empty_like(b2), b2, fxn=custom_add_one_kernel)[0]
GlobalCounters.reset()
c.realize()
assert all(i == 3. for i in c.flatten().tolist()), f"all 3 {c.tolist()}"
self.assertEqual(GlobalCounters.kernel_count, 3)
def test_multi_after_schedule_order(self):
"""Test correct scheduling order when custom_kernel has multiple outputs.