llama: don't allocate grad_xw13 in bf16 (#16359)

This commit is contained in:
qazal 2026-05-27 22:33:07 +03:00 committed by GitHub
commit 452c7d4230
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 18 deletions

View file

@ -11,13 +11,13 @@ from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, al
_grad_fp8_mailbox:dict[UOp, tuple[UOp, UOp]] = {}
@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, grad_xw13_fp8:UOp, grad_amax_buf:UOp,
def _custom_fused_bwd_w13(grad_xw13_fp8:UOp, grad_amax_buf:UOp,
xw13:UOp, grad_x2:UOp, amax_state:UOp, grad_amax_state:UOp, dname:str) -> UOp:
hidden = xw13.shape[2] // 2
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
mem = n_elems * 2 * 5 + n_elems * 2 + NUM_WG * 4 + 4
sink = UOp.sink(grad_xw13.base, grad_xw13_fp8.base, grad_amax_buf.base,
mem = n_elems * 2 * 3 + n_elems * 2 + NUM_WG * 4 + 4
sink = UOp.sink(grad_xw13_fp8.base, grad_amax_buf.base,
xw13.base, grad_x2.base, amax_state.base, grad_amax_state.base, threads, workgroups,
arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=10*n_elems, mem=mem)))
src, lib = compile_cpp(pathlib.Path(__file__).parent, "cast_amax_bwd_w13.cpp", n_elems, hidden)
@ -41,23 +41,23 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
_, _, xw13, amax_state, grad_amax_state = kernel.src[1:]
device = xw13.device
axis = xw13.axis if isinstance(device, tuple) else None
grad_xw13 = alloc_like(xw13.shape, dtypes.bfloat16, device, axis)
grad_xw13_fp8 = alloc_like(xw13.shape, dtypes.fp8e4m3, device, axis)
grad_amax_buf = alloc_local((NUM_WG,), dtypes.float32, device, axis)
grad_amax_state_t = Tensor(grad_amax_state, device=device)
fxn = functools.partial(_custom_fused_bwd_w13, dname=dname_of(device))
grad_xw13, grad_xw13_fp8, grad_amax_buf, *_ = Tensor.custom_kernel(
grad_xw13, grad_xw13_fp8, grad_amax_buf,
grad_xw13_fp8, grad_amax_buf, *_ = Tensor.custom_kernel(
grad_xw13_fp8, grad_amax_buf,
Tensor(xw13, device=device), Tensor(gradient, device=device).cast(dtypes.bfloat16),
Tensor(amax_state, device=device), grad_amax_state_t, fxn=fxn)
grad_xw13_uop = grad_xw13_fp8.uop.cast(dtypes.bfloat16)
inv_scale = (grad_amax_state_t.float() + 1e-8) / FP8_MAX
new_grad_amax = scalar_amax(grad_amax_buf)
store_effect = grad_amax_state_t.uop.store(new_grad_amax.uop)
assert grad_xw13_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {grad_xw13_fp8.uop.op}"
grad_xw13_fp8_uop = grad_xw13_fp8.uop.replace(src=grad_xw13_fp8.uop.src + (store_effect,))
# Stash fp8 companion for cdna_asm_gemm's bwd to attach to grad_a.
_grad_fp8_mailbox[grad_xw13.uop] = (grad_xw13_fp8_uop, inv_scale.uop)
return (None, None, grad_xw13.uop, None, None)
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
return (None, None, grad_xw13_uop, None, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)

View file

@ -21,15 +21,13 @@ constexpr float FP8_MAX = 448.0f;
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");
// fused silu*mul backward, three outputs in a single HBM pass:
// 1) bf16 grad_xw13 — consumed by downstream bf16 autograd chain
// 2) fp8 grad_xw13_fp8 — delayed-scale quantize using grad_amax_state (mailbox to matmul bwd)
// 3) fp32 grad_amax_buf — per-WG partial |grad_xw13|, reduced into next step's grad_amax_state
// fused silu*mul backward, two outputs in a single HBM pass:
// 1) fp8 grad_xw13_fp8 — delayed-scale quantize using grad_amax_state (mailbox to matmul bwd)
// 2) fp32 grad_amax_buf — per-WG partial |grad_xw13|, reduced into next step's grad_amax_state
// grad_amax_state is read for the fp8 scale. The store of new_grad_amax into grad_amax_state's
// buffer is built in Python as a separate effect and threaded into grad_a via .after(store).
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_silu_mul_bwd_w13(
__hip_bfloat16* __restrict__ grad_xw13_out, // bf16, 2*N_ELEMS
__hip_fp8_storage_t* __restrict__ grad_xw13_fp8_out, // fp8, 2*N_ELEMS
float* __restrict__ grad_amax_buf, // fp32, NUM_WG per-WG partials
const __hip_bfloat16* __restrict__ xw13, // bf16, 2*N_ELEMS
@ -62,7 +60,6 @@ fused_silu_mul_bwd_w13(
const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
const __hip_bfloat16 *gv = reinterpret_cast<const __hip_bfloat16*>(&g_raw);
__hip_bfloat16 out1[VEC], out3[VEC];
__hip_fp8_storage_t fp8_1[VEC], fp8_3[VEC];
#pragma unroll
for (int i = 0; i < VEC; i++) {
@ -75,15 +72,11 @@ fused_silu_mul_bwd_w13(
const float gs = fg * scale;
const float g1 = gs * silu_prime * f3;
const float g3 = gs * silu;
out1[i] = static_cast<__hip_bfloat16>(g1);
out3[i] = static_cast<__hip_bfloat16>(g3);
local_max = fmaxf(local_max, fmaxf(fabsf(g1), fabsf(g3)));
fp8_1[i] = __hip_cvt_float_to_fp8(fmaxf(-FP8_MAX, fminf(FP8_MAX, g1 * g_scale)), __HIP_SATFINITE, __HIP_E4M3);
fp8_3[i] = __hip_cvt_float_to_fp8(fmaxf(-FP8_MAX, fminf(FP8_MAX, g3 * g_scale)), __HIP_SATFINITE, __HIP_E4M3);
}
*reinterpret_cast<float4*>(&grad_xw13_out[xw1_off]) = *reinterpret_cast<float4*>(out1);
*reinterpret_cast<float4*>(&grad_xw13_out[xw3_off]) = *reinterpret_cast<float4*>(out3);
*reinterpret_cast<uint64_t*>(&grad_xw13_fp8_out[xw1_off]) = *reinterpret_cast<uint64_t*>(fp8_1);
*reinterpret_cast<uint64_t*>(&grad_xw13_fp8_out[xw3_off]) = *reinterpret_cast<uint64_t*>(fp8_3);
}