mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
fix_fa_fwd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5241ee5a6c |
||
|
|
138a6b6c40 |
2 changed files with 20 additions and 19 deletions
|
|
@ -66,7 +66,7 @@ template<int D, typename T=bf16, typename L=row_l, typename S=rt_32x16_s> using
|
|||
template<int D, typename T=bf16, typename L=col_l, typename S=rt_16x32_s> using qo_tile_transposed = rt<T, D, Q_BLOCK_SIZE, L, S>;
|
||||
template<int D, typename T=bf16, typename L=row_l, typename S=rt_32x16_s> using kv_tile = rt<T, KV_BLOCK_SIZE, D, L, S>;
|
||||
template<int D, typename T=bf16, typename L=col_l, typename S=rt_16x32_s> using kv_tile_transposed = rt<T, D, KV_BLOCK_SIZE, L, S>;
|
||||
template<int D, typename T=float, typename L=col_l, typename S=rt_16x32_4_s> using attn_tile = rt<T, KV_BLOCK_SIZE, Q_BLOCK_SIZE, L, S>;
|
||||
template<typename T=float, typename L=col_l, typename S=rt_16x32_4_s> using attn_tile = rt<T, KV_BLOCK_SIZE, Q_BLOCK_SIZE, L, S>;
|
||||
|
||||
/**********************************************************/
|
||||
template<int THR_X, int THR_Y>
|
||||
|
|
@ -103,7 +103,7 @@ __device__ inline void mask_kv_tile(RT &dst, int q_abs, int k_abs, uint32_t neg_
|
|||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < dst.height; ++i) {
|
||||
// Row base of the 32x* chunk produced by MFMA
|
||||
// Row base of the 32x* chunk produced by MFMA
|
||||
const int row_base = (i * 32) + ((lane >> 5) << 2); // multiplesof 4
|
||||
|
||||
// Relative index of the FIRST element in this row-chunk w.r.t. q_pos
|
||||
|
|
@ -148,7 +148,7 @@ __device__ inline void mask_kv_tile(RT &dst, int q_abs, int k_abs, uint32_t neg_
|
|||
/**********************************************************/
|
||||
|
||||
template<int D> struct attn_globals {
|
||||
_gl_QKVO Qg, Kg, Vg, Og;
|
||||
_gl_QKVO Qg, Kg, Vg, Og;
|
||||
gl<float, -1, -1, -1, -1> L_vec;
|
||||
dim3 grid() { return dim3(ATTN_H, ((ATTN_N / Q_BLOCK_SIZE + NUM_WARPS - 1) / NUM_WARPS), ATTN_B); }
|
||||
dim3 block() { return dim3(NUM_THREADS); }
|
||||
|
|
@ -196,10 +196,10 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
|
||||
kv_tile<D, bf16, col_l, rt_16x32_4_s> v_reg;
|
||||
qo_tile_transposed<D, float, col_l, rt_32x32_s> o_reg; // Output tile.
|
||||
attn_tile<D, float, col_l, rt_32x32_s> att_block[2]; // attention tile, in float.
|
||||
attn_tile<D, bf16, col_l, rt_32x32_s> att_block_bf16;
|
||||
attn_tile<D, bf16, col_l, rt_16x32_4_s> att_block_bf16_in;
|
||||
typename attn_tile<D, float, col_l, rt_32x32_s>::row_vec max_vec, norm_vec, max_vec_prev, scale_vec;
|
||||
attn_tile<float, col_l, rt_32x32_s> att_block[2]; // attention tile, in float.
|
||||
attn_tile<bf16, col_l, rt_32x32_s> att_block_bf16;
|
||||
attn_tile<bf16, col_l, rt_16x32_4_s> att_block_bf16_in;
|
||||
typename attn_tile<float, col_l, rt_32x32_s>::row_vec max_vec, norm_vec, max_vec_prev, scale_vec;
|
||||
|
||||
zero(o_reg);
|
||||
zero(norm_vec);
|
||||
|
|
@ -241,8 +241,8 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
zero(att_block[0]);
|
||||
transpose(k_reg_transposed, k_reg);
|
||||
mma_AtB(att_block[0], k_reg_transposed, q_reg_transposed, att_block[0]);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr (causal) {
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr (causal) {
|
||||
const int kv_end_pos = (1) * KV_BLOCK_SIZE;
|
||||
if (__builtin_expect(q_start_pos < kv_end_pos, 0)) { // Only mask if needed
|
||||
mask_kv_tile(att_block[0], tile_idx, 0, neg_inf_v, lane);
|
||||
|
|
@ -269,7 +269,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
load(k_reg, k_smem[1]);
|
||||
// All warps then collaboratively load in the third slice of K (K2) into shared memory
|
||||
G::load<1, false>(k_smem[0], g.Kg, {batch_idx, 2, head_idx_kv, 0}, swizzled_offsets_K);
|
||||
// All warps then collaboratively load in the second slice of V (V1) into shared memory
|
||||
// All warps then collaboratively load in the second slice of V (V1) into shared memory
|
||||
G::load<1, false>(v_smem[1], g.Vg, {batch_idx, 1, head_idx_kv, 0}, swizzled_offsets_V);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
|
|
@ -288,7 +288,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
mul(norm_vec, norm_vec, scale_vec);
|
||||
col_sum(norm_vec, att_block[0], norm_vec);
|
||||
copy(att_block_bf16, att_block[0]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile< bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
sched_barrier_exp_pairs<6, 3, 1>();
|
||||
sched_barrier_pairs<10, 5, 1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
|
@ -296,7 +296,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Cluster 1:
|
||||
// Load K3 into shared
|
||||
// Load K3 into shared
|
||||
G::load<1, false>(k_smem[1], g.Kg, {batch_idx, j, head_idx_kv, 0}, swizzled_offsets_K);
|
||||
// Load V0 into registers
|
||||
load(v_reg, v_smem[0]);
|
||||
|
|
@ -348,7 +348,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
mul(norm_vec, norm_vec, scale_vec);
|
||||
col_sum(norm_vec, att_block[1], norm_vec);
|
||||
copy(att_block_bf16, att_block[1]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
sched_barrier_exp_pairs<6, 3, 3>();
|
||||
sched_barrier_pairs<10, 5, 3>();
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
|
|
@ -417,7 +417,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
|
||||
col_sum(norm_vec, att_block[0], norm_vec);
|
||||
copy(att_block_bf16, att_block[0]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
sched_barrier_exp_pairs<6, 3, 5>();
|
||||
sched_barrier_pairs<10, 5, 5>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
|
@ -482,7 +482,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
mul(norm_vec, norm_vec, scale_vec);
|
||||
col_sum(norm_vec, att_block[1], norm_vec);
|
||||
copy(att_block_bf16, att_block[1]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
sched_barrier_exp_pairs<6, 3, 7>();
|
||||
sched_barrier_pairs<10, 5, 7>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
|
@ -544,7 +544,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
mul(norm_vec, norm_vec, scale_vec);
|
||||
col_sum(norm_vec, att_block[0], norm_vec);
|
||||
copy(att_block_bf16, att_block[0]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
sched_barrier_exp_pairs<6, 3, 9>();
|
||||
sched_barrier_pairs<10, 5, 9>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
|
@ -586,7 +586,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
|
|||
|
||||
col_sum(norm_vec, att_block[1], norm_vec);
|
||||
copy(att_block_bf16, att_block[1]);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
mul_col(o_reg, o_reg, scale_vec);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import ctypes, hashlib, tempfile, subprocess, pathlib, shutil
|
||||
from tinygrad.helpers import system
|
||||
from tinygrad.helpers import system, getenv
|
||||
from tinygrad.runtime.autogen import comgr
|
||||
try:
|
||||
comgr.amd_comgr_get_version(ctypes.byref(major:=ctypes.c_uint64()), ctypes.byref(minor:=ctypes.c_uint64()))
|
||||
|
|
@ -110,8 +110,9 @@ class HIPCCCompiler(Compiler):
|
|||
srcf.write(src.encode())
|
||||
srcf.flush()
|
||||
|
||||
rocm_path = getenv("ROCM_PATH", "/opt/rocm")
|
||||
subprocess.run(["hipcc", "-c", "-emit-llvm", "--cuda-device-only", "-O3", "-mcumode",
|
||||
f"--offload-arch={self.arch}", "-I/opt/rocm/include/hip", "-o", bcf.name, srcf.name] + self.extra_options, check=True)
|
||||
f"--offload-arch={self.arch}", f"-I{rocm_path}/include/hip", "-o", bcf.name, srcf.name] + self.extra_options, check=True)
|
||||
subprocess.run(["hipcc", "-target", "amdgcn-amd-amdhsa", f"-mcpu={self.arch}",
|
||||
"-O3", "-mllvm", "-amdgpu-internalize-symbols", "-c", "-o", libf.name, bcf.name] + self.extra_options, check=True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue