mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: don't allocate grad_xw13 in bf16 (#16359)
This commit is contained in:
parent
0c385e31c6
commit
452c7d4230
2 changed files with 11 additions and 18 deletions
|
|
@ -11,13 +11,13 @@ from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, al
|
|||
_grad_fp8_mailbox:dict[UOp, tuple[UOp, UOp]] = {}
|
||||
|
||||
@functools.cache
|
||||
def _custom_fused_bwd_w13(grad_xw13:UOp, grad_xw13_fp8:UOp, grad_amax_buf:UOp,
|
||||
def _custom_fused_bwd_w13(grad_xw13_fp8:UOp, grad_amax_buf:UOp,
|
||||
xw13:UOp, grad_x2:UOp, amax_state:UOp, grad_amax_state:UOp, dname:str) -> UOp:
|
||||
hidden = xw13.shape[2] // 2
|
||||
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
|
||||
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
|
||||
mem = n_elems * 2 * 5 + n_elems * 2 + NUM_WG * 4 + 4
|
||||
sink = UOp.sink(grad_xw13.base, grad_xw13_fp8.base, grad_amax_buf.base,
|
||||
mem = n_elems * 2 * 3 + n_elems * 2 + NUM_WG * 4 + 4
|
||||
sink = UOp.sink(grad_xw13_fp8.base, grad_amax_buf.base,
|
||||
xw13.base, grad_x2.base, amax_state.base, grad_amax_state.base, threads, workgroups,
|
||||
arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=10*n_elems, mem=mem)))
|
||||
src, lib = compile_cpp(pathlib.Path(__file__).parent, "cast_amax_bwd_w13.cpp", n_elems, hidden)
|
||||
|
|
@ -41,23 +41,23 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
|
|||
_, _, xw13, amax_state, grad_amax_state = kernel.src[1:]
|
||||
device = xw13.device
|
||||
axis = xw13.axis if isinstance(device, tuple) else None
|
||||
grad_xw13 = alloc_like(xw13.shape, dtypes.bfloat16, device, axis)
|
||||
grad_xw13_fp8 = alloc_like(xw13.shape, dtypes.fp8e4m3, device, axis)
|
||||
grad_amax_buf = alloc_local((NUM_WG,), dtypes.float32, device, axis)
|
||||
grad_amax_state_t = Tensor(grad_amax_state, device=device)
|
||||
fxn = functools.partial(_custom_fused_bwd_w13, dname=dname_of(device))
|
||||
grad_xw13, grad_xw13_fp8, grad_amax_buf, *_ = Tensor.custom_kernel(
|
||||
grad_xw13, grad_xw13_fp8, grad_amax_buf,
|
||||
grad_xw13_fp8, grad_amax_buf, *_ = Tensor.custom_kernel(
|
||||
grad_xw13_fp8, grad_amax_buf,
|
||||
Tensor(xw13, device=device), Tensor(gradient, device=device).cast(dtypes.bfloat16),
|
||||
Tensor(amax_state, device=device), grad_amax_state_t, fxn=fxn)
|
||||
grad_xw13_uop = grad_xw13_fp8.uop.cast(dtypes.bfloat16)
|
||||
inv_scale = (grad_amax_state_t.float() + 1e-8) / FP8_MAX
|
||||
new_grad_amax = scalar_amax(grad_amax_buf)
|
||||
store_effect = grad_amax_state_t.uop.store(new_grad_amax.uop)
|
||||
assert grad_xw13_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {grad_xw13_fp8.uop.op}"
|
||||
grad_xw13_fp8_uop = grad_xw13_fp8.uop.replace(src=grad_xw13_fp8.uop.src + (store_effect,))
|
||||
# Stash fp8 companion for cdna_asm_gemm's bwd to attach to grad_a.
|
||||
_grad_fp8_mailbox[grad_xw13.uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
return (None, None, grad_xw13.uop, None, None)
|
||||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
return (None, None, grad_xw13_uop, None, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
|
||||
|
|
|
|||
|
|
@ -21,15 +21,13 @@ constexpr float FP8_MAX = 448.0f;
|
|||
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
|
||||
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");
|
||||
|
||||
// fused silu*mul backward, three outputs in a single HBM pass:
|
||||
// 1) bf16 grad_xw13 — consumed by downstream bf16 autograd chain
|
||||
// 2) fp8 grad_xw13_fp8 — delayed-scale quantize using grad_amax_state (mailbox to matmul bwd)
|
||||
// 3) fp32 grad_amax_buf — per-WG partial |grad_xw13|, reduced into next step's grad_amax_state
|
||||
// fused silu*mul backward, two outputs in a single HBM pass:
|
||||
// 1) fp8 grad_xw13_fp8 — delayed-scale quantize using grad_amax_state (mailbox to matmul bwd)
|
||||
// 2) fp32 grad_amax_buf — per-WG partial |grad_xw13|, reduced into next step's grad_amax_state
|
||||
// grad_amax_state is read for the fp8 scale. The store of new_grad_amax into grad_amax_state's
|
||||
// buffer is built in Python as a separate effect and threaded into grad_a via .after(store).
|
||||
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
|
||||
fused_silu_mul_bwd_w13(
|
||||
__hip_bfloat16* __restrict__ grad_xw13_out, // bf16, 2*N_ELEMS
|
||||
__hip_fp8_storage_t* __restrict__ grad_xw13_fp8_out, // fp8, 2*N_ELEMS
|
||||
float* __restrict__ grad_amax_buf, // fp32, NUM_WG per-WG partials
|
||||
const __hip_bfloat16* __restrict__ xw13, // bf16, 2*N_ELEMS
|
||||
|
|
@ -62,7 +60,6 @@ fused_silu_mul_bwd_w13(
|
|||
const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
|
||||
const __hip_bfloat16 *gv = reinterpret_cast<const __hip_bfloat16*>(&g_raw);
|
||||
|
||||
__hip_bfloat16 out1[VEC], out3[VEC];
|
||||
__hip_fp8_storage_t fp8_1[VEC], fp8_3[VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC; i++) {
|
||||
|
|
@ -75,15 +72,11 @@ fused_silu_mul_bwd_w13(
|
|||
const float gs = fg * scale;
|
||||
const float g1 = gs * silu_prime * f3;
|
||||
const float g3 = gs * silu;
|
||||
out1[i] = static_cast<__hip_bfloat16>(g1);
|
||||
out3[i] = static_cast<__hip_bfloat16>(g3);
|
||||
local_max = fmaxf(local_max, fmaxf(fabsf(g1), fabsf(g3)));
|
||||
fp8_1[i] = __hip_cvt_float_to_fp8(fmaxf(-FP8_MAX, fminf(FP8_MAX, g1 * g_scale)), __HIP_SATFINITE, __HIP_E4M3);
|
||||
fp8_3[i] = __hip_cvt_float_to_fp8(fmaxf(-FP8_MAX, fminf(FP8_MAX, g3 * g_scale)), __HIP_SATFINITE, __HIP_E4M3);
|
||||
}
|
||||
|
||||
*reinterpret_cast<float4*>(&grad_xw13_out[xw1_off]) = *reinterpret_cast<float4*>(out1);
|
||||
*reinterpret_cast<float4*>(&grad_xw13_out[xw3_off]) = *reinterpret_cast<float4*>(out3);
|
||||
*reinterpret_cast<uint64_t*>(&grad_xw13_fp8_out[xw1_off]) = *reinterpret_cast<uint64_t*>(fp8_1);
|
||||
*reinterpret_cast<uint64_t*>(&grad_xw13_fp8_out[xw3_off]) = *reinterpret_cast<uint64_t*>(fp8_3);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue