gemm: fix bf16 atb for mp sharding (#16637)

This commit is contained in:
wozeparrot 2026-06-16 18:58:47 -04:00 committed by GitHub
commit 36f6d1b064
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 3 deletions

View file

@ -2788,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
is_multi = isinstance(a.device, tuple)
reduce_out = False
if is_multi:
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
ndev = len(a.device)
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
out = Tensor(inv.uop.multi(out_axis), device=a.device)
dname = a.device[0]
else:
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
dname = a.device
dname = dname.split(":")[0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
if is_multi: out = out.sum(0)
if reduce_out: out = out.sum(0)
return out.squeeze(0) if out.ndim == 3 else out

View file

@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv, system, DEV
from extra.gemm.cdna_asm_gemm import asm_gemm
from extra.gemm.cdna_asm_gemm import asm_gemm, hk_bf16_atb_gemm
from test.helpers import needs_second_gpu
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8, FP8_MAX
@ -351,6 +351,31 @@ class TestGemmMXFP8(unittest.TestCase):
# Multi-device case: x and g are sharded on axis 0 while w is left unsharded (w_shard=None).
@needs_second_gpu
def test_multi_data_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=0, w_shard=None, g_shard=0)
def run_atb_gemm(rows, M, N, a_shard=None, b_shard=None, gpus=1, atol=1.0, rtol=3e-2) -> None:
  """Compare hk_bf16_atb_gemm against a float reference of a^T @ b.

  Builds seeded random bf16 inputs `a` of shape (1, rows, M) and `b` of shape (1, rows, N),
  computes the reference from the unsharded inputs, optionally shards both inputs across
  `gpus` devices, and asserts that the kernel output is close to the reference.

  Args:
    rows: size of axis 1 of both inputs (the dimension contracted by the reference matmul).
    M: size of the last axis of `a`.
    N: size of the last axis of `b`.
    a_shard: axis passed to `a.shard` when gpus > 1 (None is forwarded unchanged).
    b_shard: axis passed to `b.shard` when gpus > 1 (None is forwarded unchanged).
    gpus: number of devices; sharding only happens when this is greater than 1.
    atol: absolute tolerance forwarded to assert_allclose.
    rtol: relative tolerance forwarded to assert_allclose.

  Raises:
    AssertionError: if the kernel output does not match the reference within tolerance.
  """
  import numpy as np

  # fixed seed so both operands are reproducible across runs
  Tensor.manual_seed(0)
  lhs = Tensor.randn(1, rows, M, dtype=dtypes.float).cast(dtypes.bfloat16)
  rhs = Tensor.randn(1, rows, N, dtype=dtypes.float).cast(dtypes.bfloat16)
  with Context(DEBUG=0):
    Tensor.realize(lhs, rhs)

  # reference is computed in float from the bf16-rounded, still unsharded inputs: [M, N]
  lhs_t = lhs[0].float().transpose(0, 1)
  expected = (lhs_t @ rhs[0].float()).realize()

  if gpus > 1:
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus))
    lhs = lhs.shard(devices, axis=a_shard)
    rhs = rhs.shard(devices, axis=b_shard)

  result = hk_bf16_atb_gemm(lhs, rhs)
  np.testing.assert_allclose(result.float().numpy(), expected.numpy(), atol=atol, rtol=rtol)
# The skip reason names the kernel under test here (it previously said "MXFP8 gemm").
@unittest.skipUnless(has_hipcc(), "hk bf16 atb gemm requires hipcc to compile")
class TestHkBf16AtbGemm(unittest.TestCase):
  """Tests for hk_bf16_atb_gemm on one device and with sharded inputs.

  Every case delegates to run_atb_gemm(rows, M, N, ...), which compares the kernel
  output with a float reference of a^T @ b. The whole class is skipped without hipcc,
  and each test is skipped on non-cdna4 devices.
  """

  def setUp(self):
    if not is_cdna4(): self.skipTest("hk bf16 atb gemm is cdna4 only")

  # single device, no sharding
  def test_single(self): run_atb_gemm(256, 256, 256)

  # both inputs sharded on axis 1 (the rows axis that the reference matmul contracts over)
  @needs_second_gpu
  def test_k_sharded(self): run_atb_gemm(512, 256, 256, a_shard=1, b_shard=1, gpus=2)

  # only b sharded, on its last axis (size N)
  @needs_second_gpu
  def test_n_sharded(self): run_atb_gemm(256, 256, 512, a_shard=None, b_shard=2, gpus=2)

  # only a sharded, on its last axis (size M)
  @needs_second_gpu
  def test_m_sharded(self): run_atb_gemm(256, 512, 256, a_shard=2, b_shard=None, gpus=2)
class TestMagicGu(unittest.TestCase):
def test_magicgu_matches_old(self):
from extra.gemm.cdna_asm_gemm import _magicgu_mulhi, TILE_M, TILE_N, TILE_K