mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gemm: fix bf16 atb for mp sharding (#16637)
This commit is contained in:
parent
1cb6b88d37
commit
36f6d1b064
2 changed files with 34 additions and 3 deletions
|
|
@ -2788,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
|
|||
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
|
||||
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
reduce_out = False
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
ndev = len(a.device)
|
||||
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
|
||||
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
|
||||
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
out = Tensor(inv.uop.multi(out_axis), device=a.device)
|
||||
dname = a.device[0]
|
||||
else:
|
||||
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
|
||||
dname = a.device
|
||||
dname = dname.split(":")[0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
|
||||
if is_multi: out = out.sum(0)
|
||||
if reduce_out: out = out.sum(0)
|
||||
return out.squeeze(0) if out.ndim == 3 else out
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.helpers import getenv, system, DEV
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm, hk_bf16_atb_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8, FP8_MAX
|
||||
|
||||
|
|
@ -351,6 +351,31 @@ class TestGemmMXFP8(unittest.TestCase):
|
|||
@needs_second_gpu
|
||||
def test_multi_data_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=0, w_shard=None, g_shard=0)
|
||||
|
||||
def run_atb_gemm(rows, M, N, a_shard=None, b_shard=None, gpus=1, atol=1.0, rtol=3e-2) -> None:
|
||||
import numpy as np
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(1, rows, M, dtype=dtypes.float).cast(dtypes.bfloat16)
|
||||
b = Tensor.randn(1, rows, N, dtype=dtypes.float).cast(dtypes.bfloat16)
|
||||
with Context(DEBUG=0): Tensor.realize(a, b)
|
||||
ref = (a[0].float().transpose(0, 1) @ b[0].float()).realize() # [M, N]
|
||||
if gpus > 1:
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus))
|
||||
a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard)
|
||||
out = hk_bf16_atb_gemm(a, b)
|
||||
np.testing.assert_allclose(out.float().numpy(), ref.numpy(), atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipUnless(has_hipcc(), "MXFP8 gemm requires hipcc to compile")
|
||||
class TestHkBf16AtbGemm(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if not is_cdna4(): self.skipTest("hk bf16 atb gemm is cdna4 only")
|
||||
def test_single(self): run_atb_gemm(256, 256, 256)
|
||||
@needs_second_gpu
|
||||
def test_k_sharded(self): run_atb_gemm(512, 256, 256, a_shard=1, b_shard=1, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_n_sharded(self): run_atb_gemm(256, 256, 512, a_shard=None, b_shard=2, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_m_sharded(self): run_atb_gemm(256, 512, 256, a_shard=2, b_shard=None, gpus=2)
|
||||
|
||||
class TestMagicGu(unittest.TestCase):
|
||||
def test_magicgu_matches_old(self):
|
||||
from extra.gemm.cdna_asm_gemm import _magicgu_mulhi, TILE_M, TILE_N, TILE_K
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue