mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
341 lines
16 KiB
C++
341 lines
16 KiB
C++
#include "kittens.cuh"
|
|
|
|
using namespace kittens;
|
|
|
|
#ifndef GEMM_M
|
|
constexpr int GEMM_M = 8192;
|
|
#endif
|
|
#ifndef GEMM_N
|
|
constexpr int GEMM_N = 8192;
|
|
#endif
|
|
#ifndef GEMM_K
|
|
constexpr int GEMM_K = 8192;
|
|
#endif
|
|
|
|
// Kernel
|
|
constexpr int NUM_WARPS = 8;
|
|
constexpr int WARPS_ROW = 2;
|
|
constexpr int WARPS_COL = 4;
|
|
constexpr int BLOCK_ROW = 256;
|
|
constexpr int BLOCK_COL = 256;
|
|
constexpr int BLOCK_K = 128;
|
|
constexpr int HALF_ROW = BLOCK_ROW / 2;
|
|
constexpr int HALF_COL = BLOCK_COL / 2;
|
|
constexpr int REG_M = BLOCK_ROW / WARPS_ROW / 2;
|
|
constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
|
|
|
using G = kittens::group<NUM_WARPS>;
|
|
|
|
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
|
|
const uint8_t *__restrict__ a_e8_unused,
|
|
const uint8_t *__restrict__ b_e8_unused) {
|
|
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
|
|
|
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
kittens::gl<fp8e4m3, 1, 1, N, K> B{B_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
kittens::gl<bf16, 1, 1, M, N> C{C_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
|
|
constexpr int k_iters = K / BLOCK_K;
|
|
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
|
|
|
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
|
|
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
|
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
|
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
|
|
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
|
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
|
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
|
|
|
__shared__ ST_A As[2][2];
|
|
__shared__ ST_B Bs[2][2];
|
|
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
|
|
|
|
RT_A a;
|
|
RT_B b0, b1;
|
|
RT_C cA, cB, cC, cD;
|
|
zero(cA); zero(cB); zero(cC); zero(cD);
|
|
|
|
constexpr int tiles_M = M / BLOCK_ROW;
|
|
constexpr int tiles_N = N / BLOCK_COL;
|
|
const int NUM_XCDS = 8;
|
|
const int WGM = 8;
|
|
int wgid = chiplet_transform_chunked(blockIdx.x, gridDim.x, NUM_XCDS, WGM * WGM);
|
|
int num_wgid_in_group = WGM * tiles_N;
|
|
int group_id = wgid / num_wgid_in_group;
|
|
int first_pid_m = group_id * WGM;
|
|
int group_size_m = min(tiles_M - first_pid_m, WGM);
|
|
int block_row = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m);
|
|
int block_col = (wgid % num_wgid_in_group) / group_size_m;
|
|
int block_m = block_row * BLOCK_ROW;
|
|
int block_n = block_col * BLOCK_COL;
|
|
|
|
int warp_m = warpid() / WARPS_COL;
|
|
int warp_n = warpid() % WARPS_COL;
|
|
|
|
using T = fp8e4m3;
|
|
constexpr int bpt = ST_A::underlying_subtile_bytes_per_thread;
|
|
constexpr int bpm = bpt * NUM_THREADS;
|
|
constexpr int copies_A = HALF_ROW * BLOCK_K * sizeof(T) / bpm;
|
|
constexpr int copies_B = HALF_COL * BLOCK_K * sizeof(T) / bpm;
|
|
uint32_t sw_A[copies_A], sw_B[copies_B];
|
|
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
|
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
|
|
|
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
|
|
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
|
|
const int a_row_stride = A.template stride<2>() * sizeof(T);
|
|
const int b_row_stride = B.template stride<2>() * sizeof(T);
|
|
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
|
|
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
|
|
|
|
const int wid = warpid() % NUM_WARPS;
|
|
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
|
|
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
|
|
|
int a_row_h0 = warp_m * REG_M;
|
|
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
|
int b_row_h0 = warp_n * REG_N;
|
|
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
|
|
|
int tic = 0, toc = 1;
|
|
int tic_scales = 0, toc_scales = 1;
|
|
|
|
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
|
|
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
|
|
|
|
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
|
|
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
|
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
|
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
|
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
|
|
|
|
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
|
asm volatile("s_waitcnt vmcnt(4)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
|
|
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
|
|
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
|
|
asm volatile("s_waitcnt vmcnt(6)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
|
|
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
|
|
asm volatile("s_waitcnt vmcnt(0)");
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
int sa_idx = block_row, sb_idx = block_col;
|
|
|
|
#pragma unroll 2
|
|
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
|
if (k + 1 < k_iters) {
|
|
sa_idx += tiles_M; sb_idx += tiles_N;
|
|
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
|
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
|
}
|
|
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
|
load(b0, bs0);
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
|
asm volatile("s_waitcnt lgkmcnt(8)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
|
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
|
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
|
asm volatile("s_waitcnt vmcnt(6)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
|
|
{ // Epilogue k = k_iters - 2
|
|
int k = k_iters - 2;
|
|
if (k + 1 < k_iters) {
|
|
sa_idx += tiles_M; sb_idx += tiles_N;
|
|
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
|
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
|
}
|
|
asm volatile("s_waitcnt vmcnt(0)");
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
|
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
|
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
|
|
|
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
|
load(b0, bs0);
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto bs0_next = subtile_inplace<REG_N, BLOCK_K>(Bs[toc][0], {warp_n, 0});
|
|
load(b0, bs0_next);
|
|
asm volatile("s_waitcnt vmcnt(4)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
tic ^= 1; toc ^= 1;
|
|
tic_scales ^= 1; toc_scales ^= 1;
|
|
}
|
|
|
|
{ // Epilogue k = k_iters - 1 (final)
|
|
asm volatile("s_waitcnt vmcnt(0)");
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
|
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
|
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
|
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
asm volatile("s_waitcnt vmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
|
|
} else {
|
|
#pragma unroll 1
|
|
for (int kk = 0; kk < k_iters; kk++) {
|
|
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
|
|
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
|
|
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
|
|
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
|
|
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
|
|
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
|
|
asm volatile("s_waitcnt vmcnt(0)");
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
|
|
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
|
|
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
|
|
|
|
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
}
|
|
|
|
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
|
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
|
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
|
store(C, cD, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
|
}
|