call matmul

This commit is contained in:
George Hotz 2026-03-11 21:50:31 +08:00
commit 62dbf12655
4 changed files with 138 additions and 11 deletions

View file

@ -0,0 +1,56 @@
from typing import Callable
from tinygrad import UOp, dtypes, Device, Tensor, getenv, function
from tinygrad.uop.ops import AxisType, AddrSpace
def simple_function(fxn:Callable[..., UOp]) -> Callable[..., UOp]:
  """Decorator that turns a UOp-building function into a call on its arguments.

  The decorated function is invoked with one `param_like` placeholder per positional
  argument (numbered by position); the UOp it returns is then `.call`ed with the
  original arguments.

  Args:
    fxn: function that receives the placeholder UOps and returns the UOp to be called.

  Returns:
    A wrapper that takes the real UOps positionally and returns the result of `.call`.
  """
  import functools  # local import so the file-level import block does not need to change
  @functools.wraps(fxn)  # keep the decorated function's name and docstring instead of "wrapper"
  def wrapper(*args:UOp) -> UOp:
    # one placeholder per argument, indexed by its position in the call
    params:list[UOp] = [x.param_like(i) for i,x in enumerate(args)]
    return fxn(*params).call(*args)
  return wrapper
# problem size, read through getenv with a default of 4096
N = getenv("N", 4096)
# square problem: all three matmul dimensions share the same size
M = K = N
# Threadblock tile sizes (block-level tile of C that a block computes)
# NOTE: C is allocated as (N, M) and global_matmul slices its first axis by BLOCK_N and its second axis by BLOCK_M
BLOCK_N = 128 # tile extent along the N-dim (first axis of C) per block
BLOCK_M = 128 # tile extent along the M-dim (second axis of C) per block
BLOCK_K = 8 # K-slice per block iteration
@simple_function
def slice_matmul(c_regs, a_local, b_local):
  """Multiply-accumulate one K-slice into the register tile (not implemented yet).

  Args:
    c_regs: accumulator tile.
    a_local: staged slice of A.
    b_local: staged slice of B.

  Raises:
    NotImplementedError: always; the kernel body has not been written. Raising here
      gives a clear failure instead of the decorator calling `.call` on None.
  """
  # operand shapes with the default tile sizes, as built in local_matmul:
  # 128x128, 8x128, 8x128
  raise NotImplementedError("slice_matmul is not implemented yet")
@simple_function
def local_matmul(c:UOp, a:UOp, b:UOp):
  """Build the per-block matmul: stage K-slices of `a` and `b` and store the register tile to `c`.

  Args:
    c: output tile; its shape is reused for the register accumulator.
    a: block of A, sliced along its second axis in BLOCK_K steps.
    b: block of B, sliced along its first axis in BLOCK_K steps.

  Returns:
    The UOp that stores the register accumulator into `c`.
  """
  # accumulate in registers
  a_local = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL)
  b_local = UOp.placeholder((BLOCK_K, BLOCK_M), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
  c_regs = UOp.placeholder(c.shape, dtypes.float, slot=2, addrspace=AddrSpace.REG)
  # the reduce loop walks the K dimension (second axis of a, first axis of b), so size it from K rather than N
  k_tile = UOp.range(K // BLOCK_K, 0, AxisType.REDUCE)*BLOCK_K
  # assign = store + after
  # the slice of a is permuted so it matches a_local's (BLOCK_K, BLOCK_N) layout
  # NOTE(review): the value returned by slice_matmul is not bound, and the store below does not reference it
  slice_matmul(c_regs.assign(0), a_local.assign(a[:, k_tile:k_tile+BLOCK_K].permute(1,0)), b_local.assign(b[k_tile:k_tile+BLOCK_K, :]))
  return c.store(c_regs)
@simple_function
def global_matmul(c, a, b):
  """Tile the full matmul over two GLOBAL ranges and hand each tile to `local_matmul`.

  Args:
    c: output matrix, tiled BLOCK_N x BLOCK_M.
    a: left operand; a BLOCK_N-row band is passed per tile.
    b: right operand; a BLOCK_M-column band is passed per tile.

  Returns:
    The UOp produced by `local_matmul` for the (gx, gy) tile.
  """
  # each global range needs its own axis id: with N == M the two ranges otherwise have the same
  # size, id and type, which would make gx and gy indistinguishable
  gx = UOp.range(N//BLOCK_N, 0, AxisType.GLOBAL) * BLOCK_N
  gy = UOp.range(M//BLOCK_M, 1, AxisType.GLOBAL) * BLOCK_M
  return local_matmul(c[gx:gx+BLOCK_N, gy:gy+BLOCK_M], a[gx:gx+BLOCK_N, :], b[:, gy:gy+BLOCK_M])
if __name__ == "__main__":
  # allocate C (N,M), A (N,K) and B (K,M) in that order: flat float buffers on the default device, then reshaped
  C, A, B = (UOp.new_buffer(Device.DEFAULT, rows*cols, dtypes.float).reshape(rows, cols) for rows, cols in ((N, M), (N, K), (K, M)))
  # build the nested call graph and run it
  global_matmul(C, A, B).realize()
# input matmuls
#c = UOp.param(0, dtypes.float, (N, M))
#a = UOp.param(1, dtypes.float, (N, K))
#b = UOp.param(2, dtypes.float, (K, M))
#ba = a.rearrange("(n bn) (k bk) -> n k bn bk", bn=BLOCK_N, bk=BLOCK_K)[gx, k_tile_range]
#bb = b.rearrange("(k bk) (m bm) -> k m bk bm", bk=BLOCK_K, bm=BLOCK_M)[k_tile_range, gy]
#bc = c.rearrange("(n bn) (m bm) -> n m bn bm", bn=BLOCK_N, bm=BLOCK_M)[gx, gy]

View file

@ -4,15 +4,82 @@ if __name__ == "__main__":
B0 = UOp.new_buffer(Device.DEFAULT, 100, dtypes.float).reshape(10,10)
B1 = UOp.new_buffer(Device.DEFAULT, 100, dtypes.float).reshape(10,10)
R0 = UOp.range(10, axis_id=0)
R1 = UOp.range(10, axis_id=1)
b0 = UOp.param(0, dtypes.float, (10,10))
b1 = UOp.param(1, dtypes.float, (10,10))
r0 = UOp.param(2, dtypes.index, ())
r1 = UOp.param(3, dtypes.index, ())
r0 = UOp.range(10, axis_id=0)
r1 = UOp.range(10, axis_id=1)
fxn = (b0[r0, r1] + b1[r0, r1]).call(B0, B1, R0, R1)
fxn = (b0[r0, r1] + b1[r0, r1]).call(B0, B1)
t = Tensor(fxn)
t.realize()
# gemm (N,N)
# (N//k, k, N//k, k)
# what if call just implicitly ends all ranges and you don't need to connect them?
# you do have to connect them, and it does end the ranges
# if assign (store+after) is on call, we move the store into the call (indexed with the ranges) and replace the assign with an after
def gemm(A, B):
  """Exploratory sketch of a tiled (N,N) GEMM expressed as nested calls.

  Builds the call graph for inputs `A` and `B` and then returns None via a bare
  `return`; nothing is realized here. Everything after that `return` is unreachable
  and kept as an alternative formulation.
  """
  N = 4096
  k = 128
  # inner kxk tile function: broadcast product of the two reshaped params, summed over the last axis
  ia = UOp.param(0, dtypes.float, (k, k)).reshape(k, 1, k)
  ib = UOp.param(1, dtypes.float, (k, k)).reshape(1, k, k)
  gemm_fxn = (ia * ib).sum(2) # <-- rangeify this
  # outer function: full-size params viewed as (N//k, k, N//k, k) and indexed per tile by two ranges
  a = UOp.param(0, dtypes.float, (N, N))
  b = UOp.param(1, dtypes.float, (N, N))
  r0 = UOp.range(N//k, 0)
  r1 = UOp.range(N//k, 1)
  # NOTE(review): a and b are both indexed with [r0, :, r1, :] and the ranges are also passed as call args — confirm intent
  local_fxn = gemm_fxn.call(a.reshape(N//k, k, N//k, k)[r0, :, r1, :], b.reshape(N//k, k, N//k, k)[r0, :, r1, :], r0, r1).permute(0,2,1,3).reshape(N,N)
  # the result is assigned but never returned or used
  fxn = local_fxn.call(A,B)
  return
  # --- unreachable below: variant where the params already carry the tiled 4D shape ---
  a = UOp.param(0, dtypes.float, (N//k, k, N//k, k))
  b = UOp.param(1, dtypes.float, (N//k, k, N//k, k))
  # inner kxk GEMM (are WMMAs calls?)
  ia = UOp.param(0, dtypes.float, (k,k)).reshape(k, 1, k)
  ib = UOp.param(1, dtypes.float, (k,k)).reshape(1, k, k)
  r0 = UOp.range(N//k, 0)
  r1 = UOp.range(N//k, 1)
  fxn = (ia * ib).sum(2).call(a[:, r0, :, r1], b[:, r0, :, r1]) # this call ends these ranges implicitly
  assert fxn.shape == (N//k, N//k, k, k)
  #.call(A, B, UOp.range(N//k), UOp.range(N//k))
  #r0 = UOp.param(2, dtypes.index, (), vmin_vmax=(0, N//k-1))
  #r1 = UOp.param(3, dtypes.index, (), vmin_vmax=(0, N//k-1))
# Q = [batch, seq_len, heads, dim]
# K = [batch, seq_len, head_kv, dim]
# V = [batch, seq_len, head_kv, dim]

View file

@ -36,7 +36,7 @@ class _function(Generic[ReturnType]):
call_uops: list[UOp] = dedup(input_uops)
# disable realize/schedule while this is running
# run it and do surgery later
# run it and do surgery later. TODO: why am i not calling it with the params?
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
ret = self.fxn(*args, **kwargs)
assert isinstance(ret, Tensor), "only supports one tensor return for now"

View file

@ -249,8 +249,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if inner_shape is None: return None
# substitute internal PARAMs in the shape with corresponding args
ret = tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
# NOTE: this requires the RANGEs directly on the call
prepend = tuple([x.vmax+1 for x in self.src[1:] if x.op is Ops.RANGE])
prepend = tuple([x.vmax+1 for x in self.src[0].ranges])
return prepend+ret
# TODO: disallow shape changing bitcast
@ -353,7 +352,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def _ranges(self) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
for s in self.src: ret.update(s.ranges)
if self.op is Ops.CALL:
# ranges do not flow through calls
for s in self.src[1:]: ret.update(s.ranges)
else:
for s in self.src: ret.update(s.ranges)
for er in self.ended_ranges:
if er.op is Ops.RANGE:
# if it's a single RANGE, we don't flow through it.
@ -419,7 +422,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):
idx = argfix(idx)
assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
#assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
@ -907,7 +910,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return p
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
# ranges don't leak through calls, they end!
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)