mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gemm: fix mxfp8 on more shapes (#16677)
This commit is contained in:
parent
67c3e589a1
commit
bba611bb59
2 changed files with 12 additions and 4 deletions
|
|
@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int sa_idx = block_row, sb_idx = block_col;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
|
|
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue