mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: less E kernels (#16517)
This commit is contained in:
parent
12f4cf0e49
commit
4d34590b7d
3 changed files with 14 additions and 2 deletions
|
|
@ -186,7 +186,7 @@ class FlatTransformer:
|
|||
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
|
||||
if getenv("HK_FLASH_ATTENTION"):
|
||||
from extra.thunder.amd.fa import flash_attention
|
||||
attn, *save = flash_attention(xq, xk, xv, is_causal=True)
|
||||
attn, *save = flash_attention(xq, xk, xv, is_causal=True, write_flat=True)
|
||||
saves.extend(save)
|
||||
else:
|
||||
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
|
|||
return None, None, dq.uop, dk.uop, dv.uop
|
||||
return grad
|
||||
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
|
||||
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
|
||||
assert attn_mask is None, "attn_mask not supported"
|
||||
assert is_causal, "only causal attention supported"
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
arch = Device[single_device].renderer.target.arch
|
||||
|
||||
attn = _sharded_empty_like(xq, axis=shard_axis)
|
||||
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
|
||||
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
|
||||
|
||||
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)
|
||||
|
|
|
|||
|
|
@ -277,6 +277,16 @@ class TestCustomKernel(unittest.TestCase):
|
|||
out.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 5)
|
||||
|
||||
def test_simple_reshape(self):
|
||||
a = Tensor.ones(2,3,4).realize()
|
||||
b = Tensor.custom_kernel(Tensor.empty_like(a), a, fxn=custom_add_one_kernel)[0]
|
||||
b2 = b.reshape(2,12)
|
||||
c = Tensor.custom_kernel(Tensor.empty_like(b2), b2, fxn=custom_add_one_kernel)[0]
|
||||
GlobalCounters.reset()
|
||||
c.realize()
|
||||
assert all(i == 3. for i in c.flatten().tolist()), f"all 3 {c.tolist()}"
|
||||
self.assertEqual(GlobalCounters.kernel_count, 3)
|
||||
|
||||
def test_multi_after_schedule_order(self):
|
||||
"""Test correct scheduling order when custom_kernel has multiple outputs.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue