mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fp8 gemm tests with scale args (#16231)
* update atol
* update fp8 path
* more work
* update profile.sh
This commit is contained in:
parent
e575f778f9
commit
ebcb7b7cc0
2 changed files with 36 additions and 16 deletions
|
|
@@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
export BENCHMARK=5
|
||||
export EVAL_BS=0
|
||||
VIZ=${VIZ:--1} FULL_LAYERS=1 DEBUG=0 examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh
|
||||
VIZ=${VIZ:--1} FULL_LAYERS=1 DEBUG=0 examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama31_8b/implementations/tinybox_8xMI350X/dev_beam.sh
|
||||
SRC="AMD"; [[ $DEV == NULL* ]] && SRC="NULL"
|
||||
python -m tinygrad.viz.cli -s "$SRC" -t
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ from tinygrad.device import is_dtype_supported
|
|||
from tinygrad.helpers import getenv, system, DEV
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8, FP8_MAX
|
||||
|
||||
# On non CDNA4 it will only validate the Tensor.custom_kernel integration
|
||||
# Use DEV=NULL:HIP:gfx950 to also test the assembly
|
||||
|
|
@@ -12,33 +12,47 @@ def is_cdna4(): return Device[Device.DEFAULT].renderer.target.arch.startswith("g
|
|||
|
||||
def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=None, gpus:int=1) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
a_rand = Tensor.randn(a_shape, dtype=dtypes.float, requires_grad=False).sub(0.5).cast(dtype)
|
||||
b_rand = Tensor.randn(b_shape, dtype=dtypes.float, requires_grad=False).sub(0.5).cast(dtype)
|
||||
input_dtype = dtypes.bfloat16 if dtype == FP8_DTYPE else dtype
|
||||
a_rand = Tensor.randn(a_shape, dtype=dtypes.float, requires_grad=False).sub(0.5).cast(input_dtype)
|
||||
b_rand = Tensor.randn(b_shape, dtype=dtypes.float, requires_grad=False).sub(0.5).cast(input_dtype)
|
||||
with Context(DEBUG=0):
|
||||
Tensor.realize(a_rand, b_rand)
|
||||
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus)) if (multi:=gpus>1) else None
|
||||
|
||||
a, b = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
|
||||
if dtype == FP8_DTYPE:
|
||||
a_rand, x_scale, _ = quantize_fp8(a_rand)
|
||||
b_rand, w_scale, _ = quantize_fp8(b_rand)
|
||||
grad_amax_state = Tensor.full((), FP8_MAX, dtype=dtypes.float32, device=devs).contiguous()
|
||||
with Context(DEBUG=0):
|
||||
Tensor.realize(a_rand, x_scale, b_rand, w_scale, grad_amax_state)
|
||||
|
||||
a, b = a_rand.clone(), b_rand.clone()
|
||||
if multi: a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard)
|
||||
tst = asm_gemm(a, b)
|
||||
if dtype == FP8_DTYPE:
|
||||
tst = asm_gemm(a, b, x_scale=x_scale, w_scale=w_scale, grad_amax_state=grad_amax_state)
|
||||
else:
|
||||
tst = asm_gemm(a, b)
|
||||
tst.sum().backward()
|
||||
Tensor.realize(tst, a.grad, b.grad)
|
||||
|
||||
a_ref, b_ref = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
|
||||
# do reference gemm in bf16 for fp8, adjusting atol for quantization effects
|
||||
if a_ref.dtype == FP8_DTYPE:
|
||||
a_ref = a_ref.cast(dtypes.bfloat16)
|
||||
b_ref = b_ref.cast(dtypes.bfloat16)
|
||||
if dtype == FP8_DTYPE:
|
||||
a_ref, b_ref = a_rand.detach().cast(dtypes.bfloat16).requires_grad_(), b_rand.detach().cast(dtypes.bfloat16).requires_grad_()
|
||||
else:
|
||||
a_ref, b_ref = a_rand.clone(), b_rand.clone()
|
||||
if multi: a_ref, b_ref = a_ref.shard(devs, axis=a_shard), b_ref.shard(devs, axis=b_shard)
|
||||
ref = a_ref @ b_ref
|
||||
if dtype == FP8_DTYPE:
|
||||
ref = ((a_ref @ b_ref) * x_scale * w_scale).cast(dtypes.bfloat16)
|
||||
else:
|
||||
ref = a_ref @ b_ref
|
||||
ref.sum().backward()
|
||||
Tensor.realize(ref, a_ref.grad, b_ref.grad)
|
||||
|
||||
# no validation on the NULL device
|
||||
if a_rand.device.startswith("NULL"): return None
|
||||
atol, rtol = (2e-1, 1e-2) if dtype == dtypes.bfloat16 else (256, 1e-2) if dtype == FP8_DTYPE else (1e-2, 1e-3)
|
||||
grad_atol, grad_rtol = (16895, 0.125) if dtype == FP8_DTYPE else (atol, rtol)
|
||||
# allow more rtol for multi because of ALLREDUCE_CAST
|
||||
grad_atol, grad_rtol = (16895, 0.125) if dtype == FP8_DTYPE else (atol, 2e-2 if multi else rtol)
|
||||
with Context(DEBUG=0):
|
||||
# enable for debugging, slow for larger gemms
|
||||
if getenv("USE_NPY"):
|
||||
|
|
@@ -139,11 +153,17 @@ class TestGemmLlama(unittest.TestCase):
|
|||
def test_empty_bw(self):
|
||||
x = Tensor.empty(1, N:=getenv("N", 4096), N, dtype=self.dtype, requires_grad=True)
|
||||
y = Tensor.empty((N, N), dtype=self.dtype, requires_grad=True)
|
||||
z = asm_gemm(x, y)
|
||||
if self.dtype == FP8_DTYPE:
|
||||
x_scale = Tensor.empty((), dtype=dtypes.float32)
|
||||
w_scale = Tensor.empty((), dtype=dtypes.float32)
|
||||
grad_amax_state = Tensor.empty((), dtype=dtypes.float32).contiguous()
|
||||
z = asm_gemm(x, y, x_scale=x_scale, w_scale=w_scale, grad_amax_state=grad_amax_state)
|
||||
else:
|
||||
z = asm_gemm(x, y)
|
||||
z.sum().backward()
|
||||
Tensor.realize(z, x.grad, y.grad)
|
||||
# FP8 forward output is bf16, gradients use fp8e5m2 (aka bf8)
|
||||
grad_dtype = dtypes.fp8e5m2 if self.dtype == FP8_DTYPE else self.dtype
|
||||
# FP8 GEMM stores bf16 output and its backward produces bf16 gradients.
|
||||
grad_dtype = dtypes.bfloat16 if self.dtype == FP8_DTYPE else self.dtype
|
||||
assert z.dtype == dtypes.bfloat16
|
||||
assert x.grad.dtype == y.grad.dtype == grad_dtype
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue