mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gemm: fix mxfp8 on more shapes (#16677)
This commit is contained in:
parent
67c3e589a1
commit
bba611bb59
2 changed files with 12 additions and 4 deletions
|
|
@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int sa_idx = block_row, sb_idx = block_col;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
|
|
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
|
|||
|
|
@ -338,11 +338,15 @@ class TestGemmMXFP8(unittest.TestCase):
|
|||
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
|
||||
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
|
||||
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
|
||||
def test_general_n_fw(self):
|
||||
for N in (256, 1792, 2048, 8192): run_mxfp8_gemm(8192, N, 4096)
|
||||
# backward needs all dims tile-aligned (dgrad reduces N, wgrad reduces M)
|
||||
def test_bw_simple(self): run_mx_gemm_bw(256, 256, 256)
|
||||
def test_bw_rect(self): run_mx_gemm_bw(512, 256, 512)
|
||||
def test_bw_w_post(self): run_mx_gemm_bw(256, 256, 256, w_post=True)
|
||||
def test_bw_llama_qkv(self): run_mx_gemm_bw(8192, 4096, 4096)
|
||||
def test_general_n_bw(self):
|
||||
for N in (2048, 8192, 14336): run_mx_gemm_bw(8192, N, 4096)
|
||||
# MP sharding: col-parallel (w on out axis), row-parallel (x,w on in axis)
|
||||
@needs_second_gpu
|
||||
def test_multi_col_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=None, w_shard=0, g_shard=1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue