mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
mxfp8 part 2 (#16561)
This commit is contained in:
parent
83971860d8
commit
c38d6a7e3a
2 changed files with 129 additions and 3 deletions
|
|
@ -2652,7 +2652,7 @@ def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int
|
|||
# ** MXFP8 GEMM custom kernel
|
||||
|
||||
@functools.cache
|
||||
def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *, dname:str) -> UOp:
|
||||
def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:UOp, dname:str) -> UOp:
|
||||
# mxfp8 block-scaled gemm: A(M,K) @ B(N,K).T, e8m0 1x32 microscales packed (k_iters,dim) uint32
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
|
|
@ -2670,6 +2670,26 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *, dname
|
|||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
|
||||
UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
rows, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
|
||||
x_scaled = x.float() * qscale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
||||
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
|
||||
|
||||
def mx_pack(e8:Tensor) -> Tensor:
|
||||
rows, scale_K = e8.shape
|
||||
return e8.reshape(rows, scale_K // 4, 4).bitcast(dtypes.uint32).reshape(rows, scale_K // 4).permute(1, 0).contiguous()
|
||||
|
||||
def _mx_block_scale(e8:Tensor) -> Tensor:
|
||||
# dequant scale 2^(e8-127) broadcast back to element shape
|
||||
rows, scale_K = e8.shape
|
||||
return (e8.cast(dtypes.float32) - 127.0).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, scale_K*32)
|
||||
|
||||
counters = {"used":0, "todos":[]}
|
||||
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
|
||||
def _asm_gemm_report():
|
||||
|
|
@ -2854,10 +2874,30 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
# hk_bf16 uses b.T, writes gradients only for a and b
|
||||
return (None, grad_a, None, grad_b) if hk_bf16 else (None, grad_a, grad_b)
|
||||
|
||||
# ** mxfp8 gemm backward
|
||||
|
||||
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
||||
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
|
||||
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
|
||||
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
|
||||
wp = Tensor(inputs[7], device=inputs[7].device) if has_w_post else None
|
||||
|
||||
a_phys = (aq.reshape(-1, aq.shape[-1]).cast(dtypes.bfloat16) * _mx_block_scale(ae8)).cast(dtypes.bfloat16)
|
||||
b_phys = (bq.cast(dtypes.bfloat16) * _mx_block_scale(be8)).cast(dtypes.bfloat16)
|
||||
|
||||
g = Tensor(gradient, device=aq.device)[:aq.shape[0]].reshape(aq.shape[0]*aq.shape[1], bq.shape[0]).cast(dtypes.bfloat16)
|
||||
grad_a = asm_gemm(g, b_phys, mx=True)
|
||||
grad_b = asm_gemm(g.T, a_phys, mx=True)
|
||||
|
||||
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
|
||||
grad_b = grad_b * _mx_block_scale(be8)
|
||||
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
|
||||
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
|
||||
|
||||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
w_post_scale:Tensor|None=None) -> Tensor:
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@ -2889,8 +2929,21 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
|
||||
dname, arch = dname.split(":")[0], renderer.target.arch
|
||||
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
|
||||
if mx:
|
||||
# mxfp8 1x32 block scaling
|
||||
if mx_scales is not None:
|
||||
a_si, a_e8, b_si, b_e8 = mx_scales
|
||||
a_q, b_q = a.reshape(-1, a.shape[-1]), b.T
|
||||
else:
|
||||
a_q, a_e8, a_si = quantize_mxfp8(a.reshape(-1, a.shape[-1]))
|
||||
b_q, b_e8, b_si = quantize_mxfp8(b.T)
|
||||
has_w_post = w_post_scale is not None
|
||||
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
|
||||
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
|
||||
extra = [w_post_scale] if w_post_scale is not None else []
|
||||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
if a.dtype == FP8_DTYPE:
|
||||
elif a.dtype == FP8_DTYPE:
|
||||
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
|
||||
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
|
||||
|
|
|
|||
|
|
@ -268,15 +268,88 @@ def run_mxfp8_gemm(M:int, N:int, K:int) -> None:
|
|||
assert err_mx < 1e-2, f"kernel vs mxfp8 reference rel err {err_mx}"
|
||||
assert err < 6e-2, f"kernel vs fp32 rel err {err}"
|
||||
|
||||
def run_mx_gemm_bw(M:int, N:int, K:int, w_post:bool=False) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
a_rand = (Tensor.randn(M, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
b_rand = (Tensor.randn(N, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
w_post_scale = (Tensor.rand(N, dtype=dtypes.float) + 0.5).realize() if w_post else None
|
||||
a, b, a_ref, b_ref = a_rand.clone(), b_rand.clone(), a_rand.clone(), b_rand.clone()
|
||||
tst = asm_gemm(a, b.T, mx=True, w_post_scale=w_post_scale)
|
||||
tst.sum().backward()
|
||||
Tensor.realize(tst, a.grad, b.grad)
|
||||
a_grad, b_grad = a.grad.float().contiguous().realize(), b.grad.float().contiguous().realize()
|
||||
ref = a_ref.float() @ b_ref.float().T
|
||||
if w_post is not None and w_post_scale is not None: ref = ref * w_post_scale.reshape(1, -1)
|
||||
ref.sum().backward()
|
||||
ref_b_grad = b_ref.grad / w_post_scale.reshape(-1, 1) if w_post_scale is not None else b_ref.grad
|
||||
Tensor.realize(ref, a_ref.grad, b_ref.grad)
|
||||
if a.device.startswith("NULL"): return
|
||||
for name, t, r in [("fw", tst, ref), ("grad_a", a_grad, a_ref.grad), ("grad_b", b_grad, ref_b_grad)]:
|
||||
err = ((t.float() - r.float()).abs().mean() / (r.float().abs().mean() + 1e-8)).item()
|
||||
assert err < 6e-2, f"{name} rel err {err}"
|
||||
|
||||
def run_mx_gemm_multi(M:int, N:int, K:int, x_shard, w_shard, g_shard, gpus:int=2) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus))
|
||||
x_r = (Tensor.randn(M, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
w_r = (Tensor.randn(N, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
def run(shard):
|
||||
x = (x_r.shard(devs, axis=x_shard) if shard else x_r.clone())
|
||||
w = (w_r.shard(devs, axis=w_shard) if shard else w_r.clone())
|
||||
out = asm_gemm(x, w.T, mx=True)
|
||||
gmul = Tensor.ones(M, N).cast(dtypes.bfloat16)
|
||||
(out.float() * (gmul.shard(devs, axis=g_shard) if shard else gmul).float()).sum().backward()
|
||||
Tensor.realize(out, x.grad, w.grad)
|
||||
to = (lambda t: t.to(Device.DEFAULT)) if shard else (lambda t: t)
|
||||
return to(out).float().numpy(), to(x.grad).float().numpy(), to(w.grad).float().numpy()
|
||||
ref = run(False)
|
||||
if Device.DEFAULT.startswith("NULL"): return
|
||||
got = run(True)
|
||||
for name, g, r in zip(("fw", "grad_x", "grad_w"), got, ref):
|
||||
err = ((abs(g - r)).mean() / (abs(r).mean() + 1e-8))
|
||||
assert err < 2e-2, f"{name} sharded vs single rel err {err}"
|
||||
|
||||
def run_mx_prequant(M:int, N:int, K:int) -> None:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
Tensor.manual_seed(0)
|
||||
x_rand = (Tensor.randn(M, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
w_rand = (Tensor.randn(N, K, dtype=dtypes.float) * 0.5).cast(dtypes.bfloat16).realize()
|
||||
x, w = x_rand.clone(), w_rand.clone()
|
||||
x_q, x_e8, x_si = quantize_mxfp8(x)
|
||||
w_q, w_e8, w_si = quantize_mxfp8(w)
|
||||
out = asm_gemm(x_q, w_q.T, mx=True, mx_scales=(x_si, x_e8, w_si, w_e8))
|
||||
out.sum().backward()
|
||||
Tensor.realize(out, x.grad, w.grad)
|
||||
if Device.DEFAULT.startswith("NULL"): return
|
||||
ref_out, gx = x_rand.float() @ w_rand.float().T, w_rand.float().sum(0)
|
||||
gw = x_rand.float().sum(0).reshape(1, K).expand(N, K)
|
||||
for name, t, r in [("fw", out, ref_out), ("grad_x", x.grad, gx), ("grad_w", w.grad, gw)]:
|
||||
err = ((t.float() - r.float()).abs().mean() / (r.float().abs().mean() + 1e-8)).item()
|
||||
assert err < 6e-2, f"{name} prequant vs analytic rel err {err}"
|
||||
|
||||
@unittest.skipUnless(has_hipcc(), "MXFP8 gemm requires hipcc to compile")
|
||||
class TestGemmMXFP8(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if not is_cdna4() or DEV.interface.startswith("MOCK"): self.skipTest("mxfp8 gemm is only for cdna4")
|
||||
def test_prequant_simple(self): run_mx_prequant(256, 256, 256)
|
||||
def test_prequant_rect(self): run_mx_prequant(512, 256, 512)
|
||||
def test_simple(self): run_mxfp8_gemm(N:=getenv("N", 256), N, 2*128)
|
||||
def test_rect(self): run_mxfp8_gemm(512, 256, 512)
|
||||
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
|
||||
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
|
||||
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
|
||||
# backward needs all dims tile-aligned (dgrad reduces N, wgrad reduces M)
|
||||
def test_bw_simple(self): run_mx_gemm_bw(256, 256, 256)
|
||||
def test_bw_rect(self): run_mx_gemm_bw(512, 256, 512)
|
||||
def test_bw_w_post(self): run_mx_gemm_bw(256, 256, 256, w_post=True)
|
||||
def test_bw_llama_qkv(self): run_mx_gemm_bw(8192, 4096, 4096)
|
||||
# MP sharding: col-parallel (w on out axis), row-parallel (x,w on in axis)
|
||||
@needs_second_gpu
|
||||
def test_multi_col_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=None, w_shard=0, g_shard=1)
|
||||
@needs_second_gpu
|
||||
def test_multi_row_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=1, w_shard=1, g_shard=None)
|
||||
@needs_second_gpu
|
||||
def test_multi_data_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=0, w_shard=None, g_shard=0)
|
||||
|
||||
class TestMagicGu(unittest.TestCase):
|
||||
def test_magicgu_matches_old(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue