mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gemm: fix bf16 atb for mp sharding (#16637)
This commit is contained in:
parent
1cb6b88d37
commit
36f6d1b064
2 changed files with 34 additions and 3 deletions
|
|
@@ -2788,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
|
|||
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
|
||||
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
reduce_out = False
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
ndev = len(a.device)
|
||||
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
|
||||
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
|
||||
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
out = Tensor(inv.uop.multi(out_axis), device=a.device)
|
||||
dname = a.device[0]
|
||||
else:
|
||||
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
|
||||
dname = a.device
|
||||
dname = dname.split(":")[0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
|
||||
if is_multi: out = out.sum(0)
|
||||
if reduce_out: out = out.sum(0)
|
||||
return out.squeeze(0) if out.ndim == 3 else out
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue