Compare commits

...

2 commits

Author SHA1 Message Date
Woze Parrot
5241ee5a6c
fix: override rocm path 2026-03-04 10:19:47 +00:00
George Hotz
138a6b6c40 fix fa forward building with clang 22 2026-03-04 17:42:44 +08:00
2 changed files with 20 additions and 19 deletions

View file

@ -66,7 +66,7 @@ template<int D, typename T=bf16, typename L=row_l, typename S=rt_32x16_s> using
template<int D, typename T=bf16, typename L=col_l, typename S=rt_16x32_s> using qo_tile_transposed = rt<T, D, Q_BLOCK_SIZE, L, S>;
template<int D, typename T=bf16, typename L=row_l, typename S=rt_32x16_s> using kv_tile = rt<T, KV_BLOCK_SIZE, D, L, S>;
template<int D, typename T=bf16, typename L=col_l, typename S=rt_16x32_s> using kv_tile_transposed = rt<T, D, KV_BLOCK_SIZE, L, S>;
template<int D, typename T=float, typename L=col_l, typename S=rt_16x32_4_s> using attn_tile = rt<T, KV_BLOCK_SIZE, Q_BLOCK_SIZE, L, S>;
template<typename T=float, typename L=col_l, typename S=rt_16x32_4_s> using attn_tile = rt<T, KV_BLOCK_SIZE, Q_BLOCK_SIZE, L, S>;
/**********************************************************/
template<int THR_X, int THR_Y>
@ -103,7 +103,7 @@ __device__ inline void mask_kv_tile(RT &dst, int q_abs, int k_abs, uint32_t neg_
#pragma unroll
for (int i = 0; i < dst.height; ++i) {
// Row base of the 32x* chunk produced by MFMA
// Row base of the 32x* chunk produced by MFMA
const int row_base = (i * 32) + ((lane >> 5) << 2); // multiplesof 4
// Relative index of the FIRST element in this row-chunk w.r.t. q_pos
@ -148,7 +148,7 @@ __device__ inline void mask_kv_tile(RT &dst, int q_abs, int k_abs, uint32_t neg_
/**********************************************************/
template<int D> struct attn_globals {
_gl_QKVO Qg, Kg, Vg, Og;
_gl_QKVO Qg, Kg, Vg, Og;
gl<float, -1, -1, -1, -1> L_vec;
dim3 grid() { return dim3(ATTN_H, ((ATTN_N / Q_BLOCK_SIZE + NUM_WARPS - 1) / NUM_WARPS), ATTN_B); }
dim3 block() { return dim3(NUM_THREADS); }
@ -196,10 +196,10 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
kv_tile<D, bf16, col_l, rt_16x32_4_s> v_reg;
qo_tile_transposed<D, float, col_l, rt_32x32_s> o_reg; // Output tile.
attn_tile<D, float, col_l, rt_32x32_s> att_block[2]; // attention tile, in float.
attn_tile<D, bf16, col_l, rt_32x32_s> att_block_bf16;
attn_tile<D, bf16, col_l, rt_16x32_4_s> att_block_bf16_in;
typename attn_tile<D, float, col_l, rt_32x32_s>::row_vec max_vec, norm_vec, max_vec_prev, scale_vec;
attn_tile<float, col_l, rt_32x32_s> att_block[2]; // attention tile, in float.
attn_tile<bf16, col_l, rt_32x32_s> att_block_bf16;
attn_tile<bf16, col_l, rt_16x32_4_s> att_block_bf16_in;
typename attn_tile<float, col_l, rt_32x32_s>::row_vec max_vec, norm_vec, max_vec_prev, scale_vec;
zero(o_reg);
zero(norm_vec);
@ -241,8 +241,8 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
zero(att_block[0]);
transpose(k_reg_transposed, k_reg);
mma_AtB(att_block[0], k_reg_transposed, q_reg_transposed, att_block[0]);
__builtin_amdgcn_sched_barrier(0);
if constexpr (causal) {
__builtin_amdgcn_sched_barrier(0);
if constexpr (causal) {
const int kv_end_pos = (1) * KV_BLOCK_SIZE;
if (__builtin_expect(q_start_pos < kv_end_pos, 0)) { // Only mask if needed
mask_kv_tile(att_block[0], tile_idx, 0, neg_inf_v, lane);
@ -269,7 +269,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
load(k_reg, k_smem[1]);
// All warps then collaboratively load in the third slice of K (K2) into shared memory
G::load<1, false>(k_smem[0], g.Kg, {batch_idx, 2, head_idx_kv, 0}, swizzled_offsets_K);
// All warps then collaboratively load in the second slice of V (V1) into shared memory
// All warps then collaboratively load in the second slice of V (V1) into shared memory
G::load<1, false>(v_smem[1], g.Vg, {batch_idx, 1, head_idx_kv, 0}, swizzled_offsets_V);
asm volatile("s_waitcnt lgkmcnt(0)");
asm volatile("s_waitcnt vmcnt(4)");
@ -288,7 +288,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
mul(norm_vec, norm_vec, scale_vec);
col_sum(norm_vec, att_block[0], norm_vec);
copy(att_block_bf16, att_block[0]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile< bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
sched_barrier_exp_pairs<6, 3, 1>();
sched_barrier_pairs<10, 5, 1>();
__builtin_amdgcn_sched_barrier(0);
@ -296,7 +296,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
__builtin_amdgcn_sched_barrier(0);
// Cluster 1:
// Load K3 into shared
// Load K3 into shared
G::load<1, false>(k_smem[1], g.Kg, {batch_idx, j, head_idx_kv, 0}, swizzled_offsets_K);
// Load V0 into registers
load(v_reg, v_smem[0]);
@ -348,7 +348,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
mul(norm_vec, norm_vec, scale_vec);
col_sum(norm_vec, att_block[1], norm_vec);
copy(att_block_bf16, att_block[1]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
sched_barrier_exp_pairs<6, 3, 3>();
sched_barrier_pairs<10, 5, 3>();
__builtin_amdgcn_s_setprio(0);
@ -417,7 +417,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
col_sum(norm_vec, att_block[0], norm_vec);
copy(att_block_bf16, att_block[0]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
sched_barrier_exp_pairs<6, 3, 5>();
sched_barrier_pairs<10, 5, 5>();
__builtin_amdgcn_sched_barrier(0);
@ -482,7 +482,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
mul(norm_vec, norm_vec, scale_vec);
col_sum(norm_vec, att_block[1], norm_vec);
copy(att_block_bf16, att_block[1]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
sched_barrier_exp_pairs<6, 3, 7>();
sched_barrier_pairs<10, 5, 7>();
__builtin_amdgcn_sched_barrier(0);
@ -544,7 +544,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
mul(norm_vec, norm_vec, scale_vec);
col_sum(norm_vec, att_block[0], norm_vec);
copy(att_block_bf16, att_block[0]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
sched_barrier_exp_pairs<6, 3, 9>();
sched_barrier_pairs<10, 5, 9>();
__builtin_amdgcn_sched_barrier(0);
@ -586,7 +586,7 @@ __global__ void attend_ker(bf16 *O_ptr, float *L_vec_ptr, bf16 *Q_ptr, bf16 *K_p
col_sum(norm_vec, att_block[1], norm_vec);
copy(att_block_bf16, att_block[1]);
att_block_bf16_in = *reinterpret_cast<attn_tile<D, bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
att_block_bf16_in = *reinterpret_cast<attn_tile<bf16, col_l, rt_16x32_4_s>*>(&att_block_bf16);
__builtin_amdgcn_sched_barrier(0);
mul_col(o_reg, o_reg, scale_vec);

View file

@ -1,5 +1,5 @@
import ctypes, hashlib, tempfile, subprocess, pathlib, shutil
from tinygrad.helpers import system
from tinygrad.helpers import system, getenv
from tinygrad.runtime.autogen import comgr
try:
comgr.amd_comgr_get_version(ctypes.byref(major:=ctypes.c_uint64()), ctypes.byref(minor:=ctypes.c_uint64()))
@ -110,8 +110,9 @@ class HIPCCCompiler(Compiler):
srcf.write(src.encode())
srcf.flush()
rocm_path = getenv("ROCM_PATH", "/opt/rocm")
subprocess.run(["hipcc", "-c", "-emit-llvm", "--cuda-device-only", "-O3", "-mcumode",
f"--offload-arch={self.arch}", "-I/opt/rocm/include/hip", "-o", bcf.name, srcf.name] + self.extra_options, check=True)
f"--offload-arch={self.arch}", f"-I{rocm_path}/include/hip", "-o", bcf.name, srcf.name] + self.extra_options, check=True)
subprocess.run(["hipcc", "-target", "amdgcn-amd-amdhsa", f"-mcpu={self.arch}",
"-O3", "-mllvm", "-amdgpu-internalize-symbols", "-c", "-o", libf.name, bcf.name] + self.extra_options, check=True)