gemm: fix mxfp8 on more shapes (#16677)

This commit is contained in:
wozeparrot 2026-06-19 16:28:53 -04:00 committed by GitHub
commit bba611bb59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View file

@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
int sa_idx = block_row, sb_idx = block_col;
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");

View file

@ -338,11 +338,15 @@ class TestGemmMXFP8(unittest.TestCase):
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
def test_general_n_fw(self):
for N in (256, 1792, 2048, 8192): run_mxfp8_gemm(8192, N, 4096)
# backward needs all dims tile-aligned (dgrad reduces N, wgrad reduces M)
def test_bw_simple(self): run_mx_gemm_bw(256, 256, 256)
def test_bw_rect(self): run_mx_gemm_bw(512, 256, 512)
def test_bw_w_post(self): run_mx_gemm_bw(256, 256, 256, w_post=True)
def test_bw_llama_qkv(self): run_mx_gemm_bw(8192, 4096, 4096)
def test_general_n_bw(self):
for N in (2048, 8192, 14336): run_mx_gemm_bw(8192, N, 4096)
# MP sharding: col-parallel (w on out axis), row-parallel (x,w on in axis)
@needs_second_gpu
def test_multi_col_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=None, w_shard=0, g_shard=1)