gemm: fix bf16 atb for mp sharding (#16637)

This commit is contained in:
wozeparrot 2026-06-16 18:58:47 -04:00 committed by GitHub
commit 36f6d1b064
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 3 deletions

View file

@ -2788,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
is_multi = isinstance(a.device, tuple)
reduce_out = False
if is_multi:
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
ndev = len(a.device)
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
out = Tensor(inv.uop.multi(out_axis), device=a.device)
dname = a.device[0]
else:
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
dname = a.device
dname = dname.split(":")[0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
if is_multi: out = out.sum(0)
if reduce_out: out = out.sum(0)
return out.squeeze(0) if out.ndim == 3 else out