mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
ranged_cal
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b384c27314 | ||
|
|
5ed68aeed5 | ||
|
|
f682af2a31 | ||
|
|
62dbf12655 | ||
|
|
b17e15d1aa |
5 changed files with 299 additions and 59 deletions
139
extra/callrange/amd_call_matmul.py
Normal file
139
extra/callrange/amd_call_matmul.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from typing import Callable
|
||||
from tinygrad import UOp, dtypes, Device, Tensor, getenv, function
|
||||
from tinygrad.uop.ops import AxisType, AddrSpace
|
||||
|
||||
def simple_function(fxn:Callable[..., UOp]) -> Callable[..., UOp]:
|
||||
def wrapper(*args:UOp) -> UOp:
|
||||
params:list[UOp] = [x.param_like(i) for i,x in enumerate(args)]
|
||||
return fxn(*params).call(*args)
|
||||
return wrapper
|
||||
|
||||
THREADS_PER_BLOCK = 128
|
||||
WARP_SIZE = 32
|
||||
|
||||
# Register tile sizes (per-thread accumulator tile of C)
|
||||
TN = 4 # columns per thread
|
||||
TM = 4 # rows per thread
|
||||
|
||||
WAVE_TILE_N = 128
|
||||
WAVE_TILE_M = 32
|
||||
|
||||
LANES_PER_WAVE_X = 8
|
||||
LANES_PER_WAVE_Y = 4
|
||||
ITERS_PER_WAVE_N = 4 #WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
|
||||
ITERS_PER_WAVE_M = 2 #WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
|
||||
|
||||
WAVES_IN_BLOCK_Y = 4
|
||||
WAVES_IN_BLOCK_X = 1
|
||||
|
||||
|
||||
N = getenv("N", 4096)
|
||||
M = K = N
|
||||
|
||||
# Threadblock tile sizes (block-level tile of C that a block computes)
|
||||
BLOCK_N = 128 # columns of C (N-dim) per block
|
||||
BLOCK_M = 128 # rows of C (M-dim) per block
|
||||
BLOCK_K = 8 # K-slice per block iteration
|
||||
|
||||
@simple_function
|
||||
def slice_matmul(c_regs, a_local, b_local):
|
||||
# 2x
|
||||
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
|
||||
|
||||
pass
|
||||
|
||||
@simple_function
|
||||
def compute_local(c:UOp, a_local:UOp, b_local:UOp) -> UOp:
|
||||
# this is the LID level on the GPU, here we can define regs
|
||||
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
|
||||
assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
|
||||
|
||||
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_X
|
||||
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_X
|
||||
assert laneIdy.vmax+1 == LANES_PER_WAVE_Y
|
||||
|
||||
A_col = UOp.placeholder((ITERS_PER_WAVE_M*TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder((ITERS_PER_WAVE_N*TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
|
||||
# do the math
|
||||
A_col = A_col.assign(a_local[k_tile].reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)[waveIdy, :, laneIdy, :].flatten())
|
||||
B_row = B_row.assign(b_local[k_tile].reshape(WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)[waveIdx, :, laneIdx, :].flatten())
|
||||
c_regs += A_col.reshape(-1, 1) * B_row.reshape(1, -1) #
|
||||
c_regs
|
||||
|
||||
|
||||
|
||||
@simple_function
|
||||
def load_local(a_local, b_local, a_global, b_global):
|
||||
# NOTE: it ends this range, so there's a BARRIER
|
||||
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
|
||||
return UOp.group(
|
||||
a_local[:, tid].store(a_global[tid, :]),
|
||||
b_local[:, tid].store(b_global[:, tid]))
|
||||
|
||||
@simple_function
|
||||
def reg_matmul(c_regs, a_local, b_local):
|
||||
A_col = UOp.placeholder((ITERS_PER_WAVE_M*TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder((ITERS_PER_WAVE_N*TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
|
||||
|
||||
|
||||
@simple_function
|
||||
def local_matmul(c:UOp, a:UOp, b:UOp, a_local:UOp, b_local:UOp):
|
||||
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
|
||||
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_X
|
||||
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_X
|
||||
|
||||
# this is the LID level on the GPU, this (and below) is where we define REGs
|
||||
c_regs = UOp.placeholder((ITERS_PER_WAVE_M*TM, ITERS_PER_WAVE_N*TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
# 128x128, Kx128, Kx128
|
||||
k_tile = UOp.range(N // BLOCK_K, 0, AxisType.REDUCE)*BLOCK_K
|
||||
fxn = reg_matmul(c_regs.assign(0),
|
||||
a_local[:, tid].assign(a[k_tile:k_tile+BLOCK_K, tid]),
|
||||
b_local[:, tid].assign(b[k_tile:k_tile+BLOCK_K, tid]))
|
||||
|
||||
# do math
|
||||
c = c.reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
return c[waveIdy, :, laneIdy, :, waveIdx, :, laneIdx, :].store(c_regs.after(fxn))
|
||||
|
||||
@simple_function
|
||||
def global_matmul(c:UOp, a:UOp, b:UOp):
|
||||
# this is the GID level on the GPU, this is where we define LOCAL buffers shared across lids
|
||||
gx = UOp.range(N//BLOCK_N, 0, AxisType.GLOBAL) * BLOCK_N
|
||||
gy = UOp.range(M//BLOCK_M, 1, AxisType.GLOBAL) * BLOCK_M
|
||||
a_local = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
b_local = UOp.placeholder((BLOCK_K, BLOCK_M), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
|
||||
return local_matmul(c[gx:gx+BLOCK_N, gy:gy+BLOCK_M], a.permute(1,0)[:, gx:gx+BLOCK_N], b[:, gy:gy+BLOCK_M], a_local, b_local)
|
||||
|
||||
#ll = load_local(a_local, b_local, a.permute(1,0)[:, gx:gx+BLOCK_N], b[:, gy:gy+BLOCK_M])
|
||||
#return compute_local(c[gx:gx+BLOCK_N, gy:gy+BLOCK_M], a_local.after(ll), b_local.after(ll))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# this is the outer lvel on the GPU, this is where we define GLOBAL buffers
|
||||
C = Tensor.empty(N, M)
|
||||
A = Tensor.randn(N, K)
|
||||
B = Tensor.randn(K, M)
|
||||
c_out = C.call(A, B, fxn=global_matmul).numpy()
|
||||
|
||||
#C = UOp.new_buffer(Device.DEFAULT, N*M, dtypes.float).reshape(N,M)
|
||||
#A = UOp.new_buffer(Device.DEFAULT, N*K, dtypes.float).reshape(N,K)
|
||||
#B = UOp.new_buffer(Device.DEFAULT, K*M, dtypes.float).reshape(K,M)
|
||||
#global_matmul(C, A, B).realize()
|
||||
|
||||
# input matmuls
|
||||
#c = UOp.param(0, dtypes.float, (N, M))
|
||||
#a = UOp.param(1, dtypes.float, (N, K))
|
||||
#b = UOp.param(2, dtypes.float, (K, M))
|
||||
|
||||
|
||||
#ba = a.rearrange("(n bn) (k bk) -> n k bn bk", bn=BLOCK_N, bk=BLOCK_K)[gx, k_tile_range]
|
||||
#bb = b.rearrange("(k bk) (m bm) -> k m bk bm", bk=BLOCK_K, bm=BLOCK_M)[k_tile_range, gy]
|
||||
#bc = c.rearrange("(n bn) (m bm) -> n m bn bm", bn=BLOCK_N, bm=BLOCK_M)[gx, gy]
|
||||
85
extra/callrange/test.py
Normal file
85
extra/callrange/test.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from tinygrad import UOp, dtypes, Device, Tensor
|
||||
|
||||
if __name__ == "__main__":
|
||||
B0 = UOp.new_buffer(Device.DEFAULT, 100, dtypes.float).reshape(10,10)
|
||||
B1 = UOp.new_buffer(Device.DEFAULT, 100, dtypes.float).reshape(10,10)
|
||||
|
||||
|
||||
b0 = UOp.param(0, dtypes.float, (10,10))
|
||||
b1 = UOp.param(1, dtypes.float, (10,10))
|
||||
r0 = UOp.range(10, axis_id=0)
|
||||
r1 = UOp.range(10, axis_id=1)
|
||||
|
||||
fxn = (b0[r0, r1] + b1[r0, r1]).call(B0, B1)
|
||||
t = Tensor(fxn)
|
||||
t.realize()
|
||||
|
||||
# gemm (N,N)
|
||||
|
||||
# (N//k, k, N//k, k)
|
||||
|
||||
|
||||
# what if call just implicitly ends all ranges and you don't need to connect them?
|
||||
# you do have to connect them, and it does end the ranges
|
||||
|
||||
# if assign (store+after) is on call, we move the store into the call (indexed with the ranges) and replace the assign with an after
|
||||
|
||||
|
||||
def gemm(A, B):
|
||||
N = 4096
|
||||
k = 128
|
||||
|
||||
ia = UOp.param(0, dtypes.float, (k, k)).reshape(k, 1, k)
|
||||
ib = UOp.param(1, dtypes.float, (k, k)).reshape(1, k, k)
|
||||
gemm_fxn = (ia * ib).sum(2) # <-- rangeify this
|
||||
|
||||
a = UOp.param(0, dtypes.float, (N, N))
|
||||
b = UOp.param(1, dtypes.float, (N, N))
|
||||
r0 = UOp.range(N//k, 0)
|
||||
r1 = UOp.range(N//k, 1)
|
||||
local_fxn = gemm_fxn.call(a.reshape(N//k, k, N//k, k)[r0, :, r1, :], b.reshape(N//k, k, N//k, k)[r0, :, r1, :], r0, r1).permute(0,2,1,3).reshape(N,N)
|
||||
|
||||
fxn = local_fxn.call(A,B)
|
||||
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
a = UOp.param(0, dtypes.float, (N//k, k, N//k, k))
|
||||
b = UOp.param(1, dtypes.float, (N//k, k, N//k, k))
|
||||
|
||||
|
||||
|
||||
# inner kxk GEMM (are WMMAs calls?)
|
||||
ia = UOp.param(0, dtypes.float, (k,k)).reshape(k, 1, k)
|
||||
ib = UOp.param(1, dtypes.float, (k,k)).reshape(1, k, k)
|
||||
r0 = UOp.range(N//k, 0)
|
||||
r1 = UOp.range(N//k, 1)
|
||||
fxn = (ia * ib).sum(2).call(a[:, r0, :, r1], b[:, r0, :, r1]) # this call ends these ranges implicitly
|
||||
assert fxn.shape == (N//k, N//k, k, k)
|
||||
|
||||
|
||||
#.call(A, B, UOp.range(N//k), UOp.range(N//k))
|
||||
|
||||
#r0 = UOp.param(2, dtypes.index, (), vmin_vmax=(0, N//k-1))
|
||||
#r1 = UOp.param(3, dtypes.index, (), vmin_vmax=(0, N//k-1))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Q = [batch, seq_len, heads, dim]
|
||||
# K = [batch, seq_len, head_kv, dim]
|
||||
# V = [batch, seq_len, head_kv, dim]
|
||||
|
||||
|
||||
|
||||
|
|
@ -6,8 +6,9 @@ from tinygrad.dtype import AddrSpace
|
|||
from tinygrad.helpers import getenv
|
||||
|
||||
N = getenv("N", 4096)
|
||||
M = K = N
|
||||
run_count = getenv("CNT", 5)
|
||||
M = getenv("M", N)
|
||||
K = getenv("K", N)
|
||||
NUM_RUNS = getenv("CNT", 5)
|
||||
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
|
|
@ -19,6 +20,9 @@ WARP_SIZE = 32
|
|||
BLOCK_N = 128 # columns of C (N-dim) per block
|
||||
BLOCK_M = 128 # rows of C (M-dim) per block
|
||||
BLOCK_K = 8 # K-slice per block iteration
|
||||
assert N % BLOCK_N == 0, f"N ({N}) must be a multiple of BLOCK_N ({BLOCK_N})"
|
||||
assert M % BLOCK_M == 0, f"M ({M}) must be a multiple of BLOCK_M ({BLOCK_M})"
|
||||
assert K % BLOCK_K == 0, f"K ({K}) must be a multiple of BLOCK_K ({BLOCK_K})"
|
||||
|
||||
# Register tile sizes (per-thread accumulator tile of C)
|
||||
TN = 4 # columns per thread
|
||||
|
|
@ -36,16 +40,16 @@ WAVE_TILE_N = 128 if is_kernel5 else 64
|
|||
WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N
|
||||
assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN"
|
||||
assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM"
|
||||
WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N
|
||||
WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M
|
||||
assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block"
|
||||
WAVES_PER_BLOCK_N = BLOCK_N // WAVE_TILE_N
|
||||
WAVES_PER_BLOCK_M = BLOCK_M // WAVE_TILE_M
|
||||
assert WAVES_PER_BLOCK_N * WAVES_PER_BLOCK_M == WARPS_PER_BLOCK, "wave grid must match warps/block"
|
||||
|
||||
LANES_PER_WAVE_X = 8
|
||||
LANES_PER_WAVE_Y = 4
|
||||
ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
|
||||
ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
|
||||
assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN"
|
||||
assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM"
|
||||
LANES_PER_WAVE_N = 8
|
||||
LANES_PER_WAVE_M = 4
|
||||
REG_TILES_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_N * TN)
|
||||
REG_TILES_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_M * TM)
|
||||
assert WAVE_TILE_N % (LANES_PER_WAVE_N * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_N*TN"
|
||||
assert WAVE_TILE_M % (LANES_PER_WAVE_M * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_M*TM"
|
||||
|
||||
def rngs_for_shape(shape:tuple[sint, ...], rng:int, axis_type=AxisType.LOOP): return [UOp.range(s, rng+i, axis_type) for i,s in enumerate(shape)]
|
||||
def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False):
|
||||
|
|
@ -58,41 +62,41 @@ def hand_spec_kernel3():
|
|||
# ---------------------------
|
||||
# block indices & placeholders
|
||||
# ---------------------------
|
||||
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
|
||||
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
|
||||
block_id_n = UOp.special(N // BLOCK_N, "gidx0")
|
||||
block_id_m = UOp.special(M // BLOCK_M, "gidx1")
|
||||
|
||||
a = UOp.placeholder((N, N), dtypes.float, slot=1)
|
||||
b = UOp.placeholder((N, N), dtypes.float, slot=2)
|
||||
c = UOp.placeholder((N, N), dtypes.float, slot=0)
|
||||
a = UOp.placeholder((M, K), dtypes.float, slot=1)
|
||||
b = UOp.placeholder((K, N), dtypes.float, slot=2)
|
||||
c = UOp.placeholder((M, N), dtypes.float, slot=0)
|
||||
|
||||
# index the output with the globals
|
||||
c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[blockIdx_y, :, blockIdx_x, :]
|
||||
c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[block_id_m, :, block_id_n, :]
|
||||
|
||||
# open the main reduction range
|
||||
k_tile_range = UOp.range(N // BLOCK_K, 0, AxisType.REDUCE)
|
||||
a = a.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_K, BLOCK_K)[blockIdx_y, :, k_tile_range, :]
|
||||
b = b.reshape(N // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, blockIdx_x, :]
|
||||
k_tile_range = UOp.range(K // BLOCK_K, 0, AxisType.REDUCE)
|
||||
a = a.reshape(M // BLOCK_M, BLOCK_M, K // BLOCK_K, BLOCK_K)[block_id_m, :, k_tile_range, :]
|
||||
b = b.reshape(K // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, block_id_n, :]
|
||||
|
||||
# globals are no longer used, they are already in the indexes
|
||||
del blockIdx_y, blockIdx_x
|
||||
del block_id_m, block_id_n
|
||||
|
||||
# ---------------------------
|
||||
# GLOBAL -> LOCAL (As, Bs)
|
||||
# GLOBAL -> LOCAL (A_local, B_local)
|
||||
# ---------------------------
|
||||
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
|
||||
|
||||
# A: read BM x BK tiles (permute on store into locals)
|
||||
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
|
||||
As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
|
||||
As_store = copy(As.permute((1,0)).reshape(-1, THREADS_PER_BLOCK)[:, tid], a.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=100)
|
||||
BM_A_local_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
|
||||
A_local = UOp.placeholder((BLOCK_K, BM_A_local_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
|
||||
A_local_store = copy(A_local.permute((1,0)).reshape(-1, THREADS_PER_BLOCK)[:, tid], a.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=100)
|
||||
|
||||
# B: read BK x BN tiles
|
||||
Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
|
||||
Bs_store = copy(Bs.reshape(-1, THREADS_PER_BLOCK)[:, tid], b.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=200)
|
||||
B_local = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
|
||||
B_local_store = copy(B_local.reshape(-1, THREADS_PER_BLOCK)[:, tid], b.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=200)
|
||||
|
||||
# TODO: can we automate barrier?
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
As, Bs = As.after(barrier), Bs.after(barrier)
|
||||
barrier = UOp.barrier(A_local_store, B_local_store)
|
||||
A_local, B_local = A_local.after(barrier), B_local.after(barrier)
|
||||
|
||||
# open inner k range
|
||||
k = UOp.range(BLOCK_K, 3, AxisType.REDUCE)
|
||||
|
|
@ -100,31 +104,33 @@ def hand_spec_kernel3():
|
|||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
|
||||
assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_PER_BLOCK_N
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_PER_BLOCK_N
|
||||
assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M
|
||||
|
||||
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_X
|
||||
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_X
|
||||
assert laneIdy.vmax+1 == LANES_PER_WAVE_Y
|
||||
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_N
|
||||
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_N
|
||||
assert laneIdy.vmax+1 == LANES_PER_WAVE_M
|
||||
|
||||
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
A_col = copy(A_col, As[k, :].reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)[waveIdy, :, laneIdy, :], 300, set=True, upcast=True)
|
||||
A_col = UOp.placeholder((REG_TILES_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
A_local_slice = A_local[k, :].reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM)[waveIdy, :, laneIdy, :]
|
||||
A_col = copy(A_col, A_local_slice , 300, set=True, upcast=True)
|
||||
|
||||
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
B_row = copy(B_row, Bs[k, :].reshape(WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)[waveIdx, :, laneIdx, :], 400, set=True, upcast=True)
|
||||
B_row = UOp.placeholder((REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
B_local_slice = B_local[k, :].reshape(WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)[waveIdx, :, laneIdx, :]
|
||||
B_row = copy(B_row, B_local_slice, 400, set=True, upcast=True)
|
||||
|
||||
# ---------------------------
|
||||
# FMA: c_regs += A_col * B_row
|
||||
# ---------------------------
|
||||
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder((REG_TILES_PER_WAVE_M, TM, REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
i = UOp.range(c_regs.size, 16)
|
||||
c_regs = c_regs.after(c_regs.flatten()[i].store(0.0).end(i))
|
||||
|
||||
# TODO: why don't these work as upcast?
|
||||
# why if the ranges merge is it slow?!? (if you change the order on end, they will merge. big slowdown on METAL)
|
||||
iterWaveM, yt, iterWaveN, xt = rngs = rngs_for_shape(c_regs.shape, 500)
|
||||
sink = c_regs[*rngs].store(c_regs.after(k)[*rngs] + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
|
||||
iter_m, t_m, iter_n, t_n = rngs = rngs_for_shape(c_regs.shape, 500)
|
||||
sink = c_regs[*rngs].store(c_regs.after(k)[*rngs] + A_col[iter_m, t_m] * B_row[iter_n, t_n]).end(iter_m, iter_n, t_m, t_n)
|
||||
|
||||
# Close k, sync, and close K tiles
|
||||
sink = sink.end(k).barrier().end(k_tile_range)
|
||||
|
|
@ -132,28 +138,28 @@ def hand_spec_kernel3():
|
|||
# ---------------------------
|
||||
# REG -> GLOBAL (epilogue)
|
||||
# ---------------------------
|
||||
c = c.reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
c = c.reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM,
|
||||
WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)
|
||||
c = c[waveIdy, :, laneIdy, :,
|
||||
waveIdx, :, laneIdx, :]
|
||||
sink = copy(c, c_regs.after(sink), rng=600)
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
|
||||
def test_matmul(sink:UOp, dtype=dtypes.float32, N=N):
|
||||
def test_matmul(sink:UOp, dtype=dtypes.float32, M=M, N=N, K=K):
|
||||
rng = np.random.default_rng()
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
|
||||
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
|
||||
hc = Tensor.empty(N, N, dtype=dtype)
|
||||
a = Tensor(rng.random((M, K), dtype=np.float32)-0.5, dtype=dtype)
|
||||
b = Tensor(rng.random((K, N), dtype=np.float32)-0.5, dtype=dtype)
|
||||
hc = Tensor.empty(M, N, dtype=dtype)
|
||||
Tensor.realize(a, b, hc)
|
||||
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))
|
||||
|
||||
ets = []
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count):
|
||||
for _ in range(NUM_RUNS):
|
||||
ets.append(ei.run(wait=True))
|
||||
print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")
|
||||
print(f"REAL TFLOPS {M * N * K * 2 / min(ets) * 1e-12:.2f}")
|
||||
|
||||
if getenv("VERIFY", 1):
|
||||
GlobalCounters.reset()
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class _function(Generic[ReturnType]):
|
|||
call_uops: list[UOp] = dedup(input_uops)
|
||||
|
||||
# disable realize/schedule while this is running
|
||||
# run it and do surgery later
|
||||
# run it and do surgery later. TODO: why am i not calling it with the params?
|
||||
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
assert isinstance(ret, Tensor), "only supports one tensor return for now"
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
|
||||
return None
|
||||
|
|
@ -218,15 +218,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
# non pointer index doesn't have a shape
|
||||
if not isinstance(self.dtype, PtrDType): return None
|
||||
# non pointer index
|
||||
if not isinstance(self.dtype, PtrDType):
|
||||
idxs = flatten([d.shape for d in self.src[1:]])
|
||||
return tuple(idxs) + self.src[0].shape[len(self.src)-1:]
|
||||
# fully indexed doesn't have a shape. TODO: remove this
|
||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||
# pointer index
|
||||
return self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
|
||||
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE: return ()
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
|
|
@ -246,7 +248,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
inner_shape = self.src[0]._shape
|
||||
if inner_shape is None: return None
|
||||
# substitute internal PARAMs in the shape with corresponding args
|
||||
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
|
||||
ret = tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
|
||||
prepend = tuple([x.vmax+1 for x in self.src[0].ranges])
|
||||
return prepend+ret
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
case Ops.BITCAST:
|
||||
|
|
@ -348,7 +352,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@recursive_property
|
||||
def _ranges(self) -> dict[UOp, None]:
|
||||
ret: dict[UOp, None] = {}
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
if self.op is Ops.CALL:
|
||||
# ranges do not flow through calls
|
||||
for s in self.src[1:]: ret.update(s.ranges)
|
||||
else:
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
for er in self.ended_ranges:
|
||||
if er.op is Ops.RANGE:
|
||||
# if it's a single RANGE, we don't flow through it.
|
||||
|
|
@ -414,7 +422,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx):
|
||||
idx = argfix(idx)
|
||||
assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
|
||||
#assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
|
||||
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
|
||||
perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
|
||||
return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
|
||||
|
|
@ -902,7 +910,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return p
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
# ranges don't leak through calls, they end!
|
||||
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
|
|
@ -1462,6 +1471,7 @@ renderer = PatternMatcher([
|
|||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.NOOP, name="x"))), lambda x: x.arg),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: f"p{x.arg}"),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda ctx,x,u: f"UNROLL({ctx[x.src[0]]}, {u.arg})"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({str(x.dtype)[7:]})({ctx[x.src[0]]})"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue