mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fa multi fix 2 (#14314)
This commit is contained in:
parent
d9f0ad1d87
commit
d74587f16d
3 changed files with 13 additions and 12 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
|
|
@ -12,10 +13,10 @@ Q_BLOCK_SIZE = 16
|
|||
KV_BLOCK_SIZE = 16
|
||||
|
||||
def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None) -> Tensor:
|
||||
if not isinstance(ref.device, tuple): return Tensor.empty(*shape, device=ref.device)
|
||||
if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=ref.dtype, device=ref.device)
|
||||
shape = tuple(s // len(ref.device) if i == ref.uop.axis else s for i, s in enumerate(shape))
|
||||
axis = ref.uop.axis if axis is None else axis
|
||||
return Tensor(Tensor.empty(*shape, device=ref.device).uop.multi(axis), device=ref.device)
|
||||
return Tensor(Tensor.empty(*shape, dtype=ref.dtype, device=ref.device).uop.multi(axis), dtype=ref.dtype, device=ref.device)
|
||||
|
||||
def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
|
||||
return _sharded_empty(ref.shape, ref, axis)
|
||||
|
|
@ -38,10 +39,12 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
B, N, H, D = xq.shape
|
||||
H_KV = xk.shape[2]
|
||||
GROUP_SIZE = H // H_KV
|
||||
print(f"Flash Attention {B=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
|
||||
num_devices = len(xq.device) if isinstance(xq.device, tuple) else 1
|
||||
B_local = B // num_devices
|
||||
if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
|
||||
|
||||
def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
|
||||
with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker), GL(l_vecu, ker)
|
||||
|
|
@ -139,7 +142,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
return ker.finish()
|
||||
|
||||
def custom_backward_q(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
with Kernel("fa_custom_backward_q", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
with Kernel("fa_custom_backward_q", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dq, do, q, k, v, mask = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
|
|
@ -229,7 +232,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
return ker.finish()
|
||||
|
||||
def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp):
|
||||
with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
|
|
|
|||
|
|
@ -189,13 +189,11 @@ class Group:
|
|||
self.ker.push_store(c_store, c)
|
||||
return c.after(c_store).reshape(c.shape)
|
||||
|
||||
map_rid = 400
|
||||
def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
|
||||
a = cast(UOp, a)
|
||||
assert self.warps == 1
|
||||
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
|
||||
Group.map_rid += len(a.shape)
|
||||
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in a.shape)
|
||||
|
||||
if op.__code__.co_argcount == 1:
|
||||
to_store = op(a[*rngs_for_shape]) # type: ignore
|
||||
|
|
|
|||
|
|
@ -807,7 +807,7 @@ class TestTK(unittest.TestCase):
|
|||
|
||||
Tensor.manual_seed(42)
|
||||
|
||||
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
|
||||
B, N, H, H_KV, D = 1, 8192, 32, 32, 128
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
|
|
@ -831,7 +831,7 @@ class TestTK(unittest.TestCase):
|
|||
Tensor.realize(q_ref, k_ref, v_ref)
|
||||
|
||||
q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
|
||||
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
|
||||
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True, enable_gqa=True)
|
||||
ref = ref.float().transpose(1, 2)
|
||||
ref.backward(do)
|
||||
Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
|
||||
|
|
@ -845,7 +845,7 @@ class TestTK(unittest.TestCase):
|
|||
|
||||
Tensor.manual_seed(42)
|
||||
|
||||
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
|
||||
B, N, H, H_KV, D = 1, 8192, 32, 32, 128
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue