mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
7 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d4e74a52a | ||
|
|
e887ad6317 | ||
|
|
98b76aa3aa | ||
|
|
2f80994ce8 |
||
|
|
2ef52db298 | ||
|
|
80f9347e53 | ||
|
|
79f98a6624 |
9 changed files with 95 additions and 34 deletions
|
|
@ -2,13 +2,13 @@ from typing import Optional, Any
|
|||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Timing
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Timing, Context
|
||||
from tinygrad.dtype import dtypes, DType, AddrSpace
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
|
||||
from tinygrad.uop.spec import shared_spec
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
|
@ -39,7 +39,7 @@ def _test_single_value(vals, op, dts):
|
|||
output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True)]) for i, dtype in enumerate(dts))
|
||||
loads = (buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0)) for i, dtype in enumerate(dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
|
|
@ -56,7 +56,7 @@ def _test_single_value_const(vals, op, dts):
|
|||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, op, output_dtype, loads)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
|
||||
out = buf_store[UOp.const(dtypes.int32, 0)].store(alu)
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
prg = _uops_to_prg([out])
|
||||
prg.exec([buf])
|
||||
|
|
@ -565,5 +565,45 @@ class TestZeroRange(unittest.TestCase):
|
|||
out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum()
|
||||
self.assertEqual(out.item(), i)
|
||||
|
||||
class TestUOpPrograms(unittest.TestCase):
  """Tests that build a kernel by hand out of UOps and execute it on real buffers."""

  def _run(self, prog:UOp, *tensors:Tensor):
    """Get a runner for `prog` on the default device and run it, waiting for completion, on the tensors' buffers."""
    runner = get_runner(Device.DEFAULT, prog)
    bufs = [t.uop.buffer for t in tensors]
    ExecItem(runner, bufs).run(wait=True)

  def test_matmul(self):
    # C[row,col] = sum over red of A[row,red] * B[red,col]
    # shapes: A is (M,K), B is (K,N), C is (M,N)
    M = N = K = 10
    DT = dtypes.float32

    # inputs, an output to be filled by the hand-written kernel, and the reference result
    a = Tensor.rand(10,10)
    b = Tensor.rand(10,10)
    c = Tensor.empty(10,10)
    ref = a@b
    with Context(DEBUG=0):
      Tensor.realize(a, b, c, ref)

    # row/col are the spatial axes; red is the reduction axis over the shared dimension K
    row = UOp.range(M, axis_id=0)
    col = UOp.range(N, axis_id=1)
    red = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE)

    # placeholders are bound to explicit buffer slots, matching the order passed to _run below
    buf_a = UOp.placeholder(DT, (M, K), slot=0)
    buf_b = UOp.placeholder(DT, (K, N), slot=1)
    out = UOp.placeholder(DT, (M, N), slot=2)

    # zero-init: write a scalar 0 to every (row, col)
    out = out[row, col].set(0.0)

    # accumulate: the running value is read through out.after(red), which (per the
    # original author's note) enforces the dependency along the reduction axis
    running = out.after(red)[row, col]
    out = out[row, col].set(running + buf_a[row, red] * buf_b[red, col])

    # close the loop nest in (row, col, red) order
    prog = out.end(row, col, red)

    # run the program with no optimizations applied
    sink = prog.sink(arg=KernelInfo(opts_to_apply=()))
    self._run(sink, a, b, c)

    # mean squared error against the reference matmul
    with Context(DEBUG=0):
      mse = (c-ref).square().mean().item()
      self.assertLessEqual(mse, 1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
|||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
|
|
@ -23,6 +23,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
|
||||
if SPEC: type_verify(sink, kernel_spec)
|
||||
|
||||
# preprocess
|
||||
sink = graph_rewrite(sink, pm_mops, name="early movement ops")
|
||||
|
||||
# first we optimize
|
||||
if optimize:
|
||||
# collapse loads reduce (indexing by a tensor)
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ expander = PatternMatcher([
|
|||
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# BARRIERs aren't actually expanded
|
||||
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class Scheduler:
|
|||
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
|
||||
|
||||
def _output_rngs(self) -> list[UOp]:
|
||||
return flatten([list(UOp.sink(*s.src[1:]).ranges) for s in self.ast.src if s.op is Ops.END])
|
||||
return flatten([[r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE] for s in self.ast.src if s.op is Ops.END])
|
||||
def _globalizable_rngs(self) -> list[UOp]:
|
||||
ret = self._output_rngs()
|
||||
# exclude any output ranges from global that don't appear in all BUFFERIZE
|
||||
|
|
|
|||
|
|
@ -90,7 +90,8 @@ class CompiledRunner(Runner):
|
|||
|
||||
def __reduce__(self): return self.__class__, (self.p, self.lib)
|
||||
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
|
||||
if var_vals is None: var_vals = {}
|
||||
has_local = Device[self.p.device].renderer.has_local
|
||||
global_size, local_size = self.p.launch_dims(var_vals)
|
||||
if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ sys.setrecursionlimit(10000)
|
|||
def _index_through_movement(r, idx):
  # INDEX(movement_op(x), *idxs): index x directly, with the indices mapped back through the movement op
  mapped = apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:])
  return r.src[0].index(*mapped, dtype=idx.dtype, arg=idx.arg)  # type: ignore

def _reshape_over_after(r, a):
  # AFTER(RESHAPE(x), ...) -> RESHAPE(AFTER(x, ...)), reshaped to the RESHAPE's own shape
  return a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)

def _strip_reshape_under_end(r, a):
  # END(RESHAPE(x), ...) -> END(x, ...)
  return a.replace(src=(r.src[0],)+a.src[1:])

# rewrite rules run as the "early movement ops" pass; rule order is unchanged
pm_mops = PatternMatcher([
  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), _index_through_movement),
  (UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), _reshape_over_after),
  (UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), _strip_reshape_under_end),
])
|
||||
|
||||
# *****************
|
||||
|
|
|
|||
|
|
@ -341,7 +341,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, *idx): return self.index(*idx)
|
||||
def __getitem__(self, idx):
  """Subscript sugar for `index`: `u[i, j]` forwards the subscript, passed through `argfix`, as positional sources."""
  # NOTE(review): argfix presumably unpacks a tuple subscript into separate indices; confirm against helpers
  indices = argfix(idx)
  return self.index(*indices)
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
|
||||
|
|
@ -390,10 +390,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(end:sint, *arg, dtype=dtypes.index, src=(), **kwargs):
|
||||
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
||||
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=arg, **kwargs)
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
  """Create a RANGE UOp.

  The first source is `end` converted to a UOp of `dtype`, followed by any extra `src`.
  The arg is `(axis_id, axis_type)` followed by any extra positional `arg` values;
  `axis_type` defaults to AxisType.LOOP. Remaining keyword arguments go to the UOp constructor.
  """
  bound = sint_to_uop(end, dtype)
  return UOp(Ops.RANGE, dtype=dtype, src=(bound,)+src, arg=(axis_id, axis_type, *arg), **kwargs)
|
||||
@staticmethod
|
||||
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
|
|
@ -745,6 +743,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
def pyrender(self): return pyrender(self)
|
||||
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@staticmethod
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int):
  """Create a DEFINE_GLOBAL UOp for buffer slot `slot`.

  The pointer dtype is sized to `prod(shape)` elements. When `shape` has more than
  one dimension the result is reshaped to `shape`; otherwise the bare DEFINE_GLOBAL is returned.
  """
  buf = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
  return buf.reshape(shape) if len(shape) > 1 else buf
|
||||
|
||||
# set is store+after
def set(self:UOp, val:UOp|ConstType):
  """Store `val` into this UOp and return its first source wrapped in an AFTER on that store.

  A non-UOp `val` is first turned into a constant of this UOp's dtype.
  Meant to be called on the result of indexing (e.g. `buf[i, j].set(v)`), where `src[0]` is the indexed UOp.
  """
  value = val if isinstance(val, UOp) else UOp.const(self.dtype, val)
  return self.src[0].after(self.store(value))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelInfo:
|
||||
name: str = "test" # name of the kernel
|
||||
|
|
@ -873,6 +883,7 @@ class UPat(MathTrait):
|
|||
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs):
  """Build an AFTER pattern with this pattern as the first source, followed by `src`; it keeps this pattern's dtype."""
  return UPat(Ops.AFTER, self.dtype, (self, *src), **kwargs)
|
||||
def end(self, *src:UPat, **kwargs):
  """Build an END pattern with this pattern as the first source, followed by `src`; it keeps this pattern's dtype."""
  return UPat(Ops.END, self.dtype, (self, *src), **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ shared_codegen_spec = PatternMatcher([
|
|||
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
|
||||
|
||||
# allow AFTER on buffers, GROUP anywhere
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines|{Ops.AFTER}),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.GROUP, dtypes.void), lambda: True),
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
|
|
@ -141,7 +141,7 @@ shared_codegen_spec = PatternMatcher([
|
|||
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
|
||||
# INDEX
|
||||
(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
|
@ -150,11 +150,31 @@ shared_codegen_spec = PatternMatcher([
|
|||
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
|
||||
def _reduce_srcs_are_index(x):
  # every source after the first must have index dtype
  return all(src.dtype == dtypes.index for src in x.src[1:])

# rules specific to the kernel graph come first, then the shared codegen and shared specs
kernel_spec = PatternMatcher([
  # RESHAPE (but only RESHAPE) is allowed here
  (UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
  (UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),

  # index is allowed here
  (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),

  # END can end multiple axes here
  (UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),

  # bufferize can be on anything
  (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),

  # reduce must be on ranges
  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), _reduce_srcs_are_index),
])+shared_codegen_spec+shared_spec
|
||||
|
||||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
# INDEX with a gate as third src
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
|
|
@ -173,22 +193,6 @@ program_spec = PatternMatcher([
|
|||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
])+shared_codegen_spec+shared_spec
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
|
||||
kernel_spec = PatternMatcher([
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# END can end multiple axes here
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# bufferize can be on anything
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
||||
|
||||
# reduce must be on ranges
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
])+shared_codegen_spec+shared_spec
|
||||
|
||||
# *** this spec should match all UOps ever created ***
|
||||
|
||||
full_spec = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
|
|||
name = ctxs[ref]["name"]
|
||||
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
|
||||
info = f"{sym_infer(p.estimates.ops, ei.arg['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei.arg['var_vals'])/t:4.1f}"+ \
|
||||
f"|{sym_infer(p.estimates.lds,ei.arg['var_vals'])/t:.1f} GB/s\n{[str(m) for m in ei.arg['metadata']]}"
|
||||
f"|{sym_infer(p.estimates.lds,ei.arg['var_vals'])/t:.1f} GB/s\n{[str(m) for m in (ei.arg['metadata'] or ())]}"
|
||||
key = ei.key
|
||||
elif isinstance(e.name, TracingKey):
|
||||
name = e.name.display_name
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue