Compare commits

...

7 commits

Author SHA1 Message Date
George Hotz
9d4e74a52a comments 2025-10-30 12:27:55 +08:00
George Hotz
e887ad6317 more syntax 2025-10-30 12:11:18 +08:00
George Hotz
98b76aa3aa more syntax 2025-10-30 12:04:26 +08:00
George Hotz
2f80994ce8
Merge branch 'master' into uop_prg 2025-10-30 11:22:36 +08:00
George Hotz
2ef52db298 work 2025-10-30 11:12:37 +08:00
George Hotz
80f9347e53 work 2025-10-30 11:01:53 +08:00
George Hotz
79f98a6624 uops programs 2025-10-30 10:52:44 +08:00
9 changed files with 95 additions and 34 deletions

View file

@ -2,13 +2,13 @@ from typing import Optional, Any
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Timing
from tinygrad.helpers import CI, DEBUG, getenv, Timing, Context
from tinygrad.dtype import dtypes, DType, AddrSpace
from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
from tinygrad.uop.spec import shared_spec
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
from tinygrad.codegen import full_rewrite
from tinygrad.uop.symbolic import sym
from tinygrad.device import is_dtype_supported
@ -39,7 +39,7 @@ def _test_single_value(vals, op, dts):
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
@ -56,7 +56,7 @@ def _test_single_value_const(vals, op, dts):
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out])
prg.exec([buf])
@ -565,5 +565,45 @@ class TestZeroRange(unittest.TestCase):
out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum()
self.assertEqual(out.item(), i)
class TestUOpPrograms(unittest.TestCase):
def _run(self, prog:UOp, *tensors:Tensor):
ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True)
def test_matmul(self):
a = Tensor.rand(10,10)
b = Tensor.rand(10,10)
c = Tensor.empty(10,10)
ref = a@b
with Context(DEBUG=0): Tensor.realize(a, b, c, ref)
# C[i,j] = sum_k A[i,k] * B[k,j]
# Shapes: A[M,K], B[K,N], C[M,N]
M = N = K = 10
DT = dtypes.float32
# Axes: i,j are spatial; k is a reduction axis over the shared dim K
i = UOp.range(M, axis_id=0) # rows of A/C
j = UOp.range(N, axis_id=1) # cols of B/C
k = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE) # reduction over K
# Placeholders (bind slots explicitly)
A = UOp.placeholder(DT, (M, K), slot=0)
B = UOp.placeholder(DT, (K, N), slot=1)
C = UOp.placeholder(DT, (M, N), slot=2)
# Zero-init: write a scalar 0 to each (i,j).
C = C[i, j].set(0.0)
# Accumulate: C_after(k) enforces the dependency along the reduction axis
C = C[i, j].set(C.after(k)[i, j] + A[i, k] * B[k, j])
# Finalize the loop nest / schedule in (i, j, k) order
prog = C.end(i, j, k)
# run program
self._run(prog.sink(arg=KernelInfo(opts_to_apply=())), a, b, c)
with Context(DEBUG=0): self.assertLessEqual((c-ref).square().mean().item(), 1e-6)
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -15,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
ReduceContext, correct_load_store, pm_render, pm_add_loads
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
@ -23,6 +23,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
if SPEC: type_verify(sink, kernel_spec)
# preprocess
sink = graph_rewrite(sink, pm_mops, name="early movement ops")
# first we optimize
if optimize:
# collapse loads reduce (indexing by a tensor)

View file

@ -84,7 +84,7 @@ expander = PatternMatcher([
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),

View file

@ -64,7 +64,7 @@ class Scheduler:
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
def _output_rngs(self) -> list[UOp]:
return flatten([list(UOp.sink(*s.src[1:]).ranges) for s in self.ast.src if s.op is Ops.END])
return flatten([[r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE] for s in self.ast.src if s.op is Ops.END])
def _globalizable_rngs(self) -> list[UOp]:
ret = self._output_rngs()
# exclude any output ranges from global that don't appear in all BUFFERIZE

View file

@ -90,7 +90,8 @@ class CompiledRunner(Runner):
def __reduce__(self): return self.__class__, (self.p, self.lib)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {}
has_local = Device[self.p.device].renderer.has_local
global_size, local_size = self.p.launch_dims(var_vals)
if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]

View file

@ -18,6 +18,8 @@ sys.setrecursionlimit(10000)
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
(UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
(UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
# *****************

View file

@ -341,7 +341,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, *idx): return self.index(*idx)
def __getitem__(self, idx): return self.index(*argfix(idx))
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
@ -390,10 +390,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(end:sint, *arg, dtype=dtypes.index, src=(), **kwargs):
if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=arg, **kwargs)
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]):
@ -745,6 +743,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def pyrender(self): return pyrender(self)
# *** uop high level syntactic sugar ***
@staticmethod
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int):
ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
if len(shape) > 1: ret = ret.reshape(shape)
return ret
# set is store+after
def set(self:UOp, val:UOp|ConstType):
return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val))
@dataclass(frozen=True)
class KernelInfo:
name: str = "test" # name of the kernel
@ -873,6 +883,7 @@ class UPat(MathTrait):
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.dtype, (self,)+src, **kwargs)
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):

View file

@ -115,7 +115,7 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
# allow AFTER on buffers, GROUP anywhere
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines|{Ops.AFTER}),), allow_any_len=True), lambda: True),
(UPat(Ops.GROUP, dtypes.void), lambda: True),
# RANGE/SPECIAL define loops, END closes them
@ -141,7 +141,7 @@ shared_codegen_spec = PatternMatcher([
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
# INDEX
(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
# SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
@ -150,11 +150,31 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
])
# ***** UOp spec in kernel graph *****
kernel_spec = PatternMatcher([
# RESHAPE (but only RESHAPE) is allowed here
(UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),
# index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
])+shared_codegen_spec+shared_spec
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
# INDEX with a gate as third src
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
@ -173,22 +193,6 @@ program_spec = PatternMatcher([
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
])+shared_codegen_spec+shared_spec
# ***** UOp spec in kernel graph *****
kernel_spec = PatternMatcher([
# index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
])+shared_codegen_spec+shared_spec
# *** this spec should match all UOps ever created ***
full_spec = PatternMatcher([

View file

@ -155,7 +155,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name = ctxs[ref]["name"]
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
info = f"{sym_infer(p.estimates.ops, ei.arg['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei.arg['var_vals'])/t:4.1f}"+ \
f"|{sym_infer(p.estimates.lds,ei.arg['var_vals'])/t:.1f} GB/s\n{[str(m) for m in ei.arg['metadata']]}"
f"|{sym_infer(p.estimates.lds,ei.arg['var_vals'])/t:.1f} GB/s\n{[str(m) for m in (ei.arg['metadata'] or ())]}"
key = ei.key
elif isinstance(e.name, TracingKey):
name = e.name.display_name