mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: fix bf16 gemm oob (#16603)
This commit is contained in:
parent
8862c7549c
commit
67a4f129c2
2 changed files with 3 additions and 2 deletions
|
|
@@ -2752,7 +2752,8 @@ def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
|
|||
assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
|
||||
threads = UOp.special(64 * num_warps, "lidx0")
|
||||
workgroups = UOp.special((M // block_m) * (N // block_n), "gidx0")
|
||||
- sink = UOp.sink(C.base, A.base, B.base, threads, workgroups,
|
||||
+ b_extra = args[0].base if len(args) >= 1 else B.base
|
||||
+ sink = UOp.sink(C.base, A.base, B.base, b_extra, threads, workgroups,
|
||||
arg=KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
src = (kittens_path/"gemm_bf16.cpp").read_text()
|
||||
|
|
|
|||
|
|
@@ -26,7 +26,7 @@ constexpr int NUM_THREADS = WARP_THREADS * NUM_WARPS;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
- __global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr) {
|
||||
+ __global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr, bf16 *b_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
static_assert(M % BLOCK_SIZE == 0 && N % BLOCK_SIZE == 0 && K % K_STEP == 0);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue