Compare commits

...

39 commits

Author SHA1 Message Date
George Hotz
d319b5f614
Merge branch 'master' into codegen2 2026-06-21 19:18:24 -07:00
George Hotz
0decf136fe
Merge branch 'master' into codegen2 2026-06-19 18:29:13 -07:00
George Hotz
a96db15419 fixes 2026-06-19 17:57:59 -07:00
George Hotz
b0f7e2a5e1 fix imports 2026-06-19 17:44:39 -07:00
George Hotz
566fe4c7dc
Merge branch 'master' into codegen2 2026-06-19 16:57:09 -07:00
George Hotz
c2f73b102e
Merge branch 'master' into codegen2 2026-06-19 12:39:00 -07:00
George Hotz
c6f195f410 fixes 2026-06-18 18:08:14 -07:00
George Hotz
f26378264b
Merge branch 'master' into codegen2 2026-06-18 18:03:15 -07:00
George Hotz
1168ed9730
Merge branch 'master' into codegen2 2026-06-17 00:37:09 -07:00
George Hotz
017edbbbb5 param -1 2026-06-16 21:52:07 -07:00
George Hotz
daa72812b0 add gpu dims 2026-06-16 21:37:59 -07:00
George Hotz
fd325d662c
Merge branch 'master' into codegen2 2026-06-16 21:29:09 -07:00
George Hotz
8d36539656 test tiny passes 2026-06-16 14:57:12 -07:00
George Hotz
db2c71536b almost passing 2026-06-16 13:29:38 -07:00
George Hotz
4112b34a32 closer 2026-06-16 13:23:30 -07:00
George Hotz
1ad72dff08 more passing 2026-06-16 12:54:54 -07:00
George Hotz
6f1eaa8d46 fixes 2026-06-16 12:38:17 -07:00
George Hotz
35d2882991 no vec 2026-06-16 10:47:19 -07:00
George Hotz
a31732d819
Merge branch 'master' into codegen2 2026-06-16 10:33:34 -07:00
George Hotz
43d62c4211 hreduce 2026-06-16 09:36:47 -07:00
George Hotz
4d0429090c split reduce types 2026-06-16 09:27:09 -07:00
George Hotz
2c7a1450e7 fix reduce 2026-06-16 08:40:00 -07:00
George Hotz
6ffb55cc74
Merge branch 'master' into codegen2 2026-06-15 17:19:25 -07:00
George Hotz
1a280829ca
Merge branch 'master' into codegen2 2026-06-15 12:48:46 -07:00
George Hotz
3b426b1072 devec 2026-06-15 08:57:52 -07:00
George Hotz
ce2cdc3708
Merge branch 'master' into codegen2 2026-06-14 16:43:48 -07:00
George Hotz
333f062eee new expander 2026-06-14 13:54:13 -07:00
George Hotz
0d5bf3ca6d revert that 2026-06-14 13:28:28 -07:00
George Hotz
56bad940df disable that 2026-06-14 13:28:02 -07:00
George Hotz
f98deb9250 preprocess 2026-06-14 13:24:19 -07:00
George Hotz
bdfcb1cb98 test ops passes 2026-06-14 12:58:18 -07:00
George Hotz
a6fdb53a1e
Merge branch 'master' into codegen2 2026-06-14 10:09:00 -07:00
George Hotz
49deb9714b test_tiny passes 2026-06-14 09:36:51 -07:00
George Hotz
afab220947
Merge branch 'master' into codegen2 2026-06-14 08:52:36 -07:00
George Hotz
a7523b2596 simpler 2026-06-13 10:40:52 -07:00
George Hotz
21806848df improve new codegen 2026-06-12 20:08:20 -07:00
George Hotz
6fda6c704d
Merge branch 'master' into codegen2 2026-06-12 20:01:43 -07:00
George Hotz
3f7ec187df work 2026-06-12 19:24:56 -07:00
George Hotz
af9284e9b1 try for a full rewrite of codegen 2026-06-12 19:11:54 -07:00
5 changed files with 200 additions and 12 deletions

View file

@ -1,14 +1,16 @@
from typing import cast
from dataclasses import replace
import itertools
import functools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import AxisType, _align_left, _broadcast_shape, identity_element
from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
@ -43,6 +45,98 @@ pm_remove_vec_dtypes = PatternMatcher([
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
# BITCAST?
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
ctx[x.arg[0]] = len(ctx)
def fix_reduce(ctx, r:UOp):
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
expander2 = PatternMatcher([
(UPat(Ops.SINK, name="sink"), build_range_map),
(UPat(Ops.REDUCE, name="r"), fix_reduce),
(UPat(Ops.RANGE, name="r"),
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
def broadcast_binary(x:UOp):
shapes = [u.shape for u in x.src]
if all_same(shapes): return None
shaped_aligned = _align_left(*shapes)
broadcasted = _broadcast_shape(*shapes)
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
return x.replace(src=tuple(src_reshaped))
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
if b.shape == (): return None
# broadcasting needs to be already unpacked
if not all_same([x.shape for x in b.src]): return None
src = []
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape)
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
# INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp):
if x.arg.slot != -1: return None
ctx[0] += 1
@ -56,6 +150,90 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor)
sink = ast
# preprocess. we need to simplify these
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
# first we optimize
if optimize:
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# do expander
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals (STAGE -> BUFFER)
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
# rewrite reduce after optimizations
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
# add loads
sink = graph_rewrite(sink, pm_move_regs, name="move to registers", walk=True)
# symbolic (note: this does POW decomp)
sink = graph_rewrite(sink, sym, name="post index symbolic")
# ***** make it rendererable (within spec, tighten) *****
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="*** decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions more")
# split ends
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# ***** this is where it gets large *****
# unbroadcast
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
# devectorizer
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
# ***** make it rendererable (outside spec, transform) *****
# final symbolic
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# put registers in slots
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
name_to_slot = {x:x.replace(arg=replace(x.arg, slot=num_params+i))
for i,x in enumerate(sorted([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot == -1]))}
sink = sink.substitute(name_to_slot, name="put variables in slots")
# remove all weakints
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
if SPEC: type_verify(sink, spec_program)
# return the rewritten sink
return sink
def old_full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor)
# preprocess
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)

View file

@ -274,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True)
class Metadata:

View file

@ -75,7 +75,9 @@ def fold_divmod_general(d: UOp) -> UOp|None:
# divide_by_gcd: x//y -> (x//gcd)//(y//gcd)
gcd = UOp.gcd(*all_uops, y).simplify()
if not (gcd.op is Ops.CONST and gcd.arg==1):
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd)))
x_div, y_div = x.divide_exact(gcd), y.divide_exact(gcd)
if x_div is None or y_div is None: return None
ret = x_div.alu(d.op, y_div)
return ret*gcd if d.op is Ops.FLOORMOD else ret
# factor_remainder: (d*x+y)//d -> x+y//d

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
elif all_int(arg): return UOp.const(dtypes.weakint, arg)
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
@ -306,9 +306,10 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# no longer true
#case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce
return (self.dtype.count,) if self.dtype.count > 1 else ()
#return (self.dtype.count,) if self.dtype.count > 1 else ()
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
@ -473,12 +474,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):
# pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path
if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
#if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
idx = self._normalize_indices(list(argfix(idx)))
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
# apply SHRINK for slices that aren't the full range
@ -1682,7 +1683,7 @@ pm_lower_index_dtype = PatternMatcher([
UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)),
gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
(UPat((Ops.SINK, Ops.NOOP, Ops.END, Ops.AFTER, Ops.BUFFER), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
])
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View file

@ -29,7 +29,14 @@ def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
def fold_const_alu(a:UOp) -> UOp|None:
vals = [const_arg(s) for s in a.src]
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False))
if any(v is None for v in vals): return None
if any(isinstance(v, tuple) for v in vals):
out_len = prod(a.shape)
if not all(not isinstance(v, tuple) or len(v) in {1, out_len} for v in vals): return None
return a.const_like(tuple(exec_alu(a.op, a.dtype.scalar(),
[v[0] if isinstance(v, tuple) and len(v) == 1 else v[i] if isinstance(v, tuple) else v for v in vals], False)
for i in range(out_len)))
return a.const_like(exec_alu(a.op, a.dtype, vals, False))
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)