gemm: fix mxfp8 on more shapes (#16677)

This commit is contained in:
wozeparrot 2026-06-19 16:28:53 -04:00 committed by GitHub
commit bba611bb59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View file

@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
int sa_idx = block_row, sb_idx = block_col;
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");