mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: less E kernels (#16517)
This commit is contained in:
parent
12f4cf0e49
commit
4d34590b7d
3 changed files with 14 additions and 2 deletions
|
|
@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
|
|||
return None, None, dq.uop, dk.uop, dv.uop
|
||||
return grad
|
||||
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
|
||||
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
|
||||
assert attn_mask is None, "attn_mask not supported"
|
||||
assert is_causal, "only causal attention supported"
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
arch = Device[single_device].renderer.target.arch
|
||||
|
||||
attn = _sharded_empty_like(xq, axis=shard_axis)
|
||||
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
|
||||
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
|
||||
|
||||
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue