new codegen, try 2

This commit is contained in:
George Hotz 2026-06-20 08:30:38 -07:00
commit da402953d9
4 changed files with 111 additions and 6 deletions

View file

@ -24,6 +24,8 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.codegen2 import expander2, pm_move_regs
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
@ -81,7 +83,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
@ -96,7 +99,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
#sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:

View file

@ -0,0 +1,101 @@
import itertools, functools
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import all_same, flatten
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
from tinygrad.codegen.late.devectorizer import ReduceContext
from tinygrad.uop.symbolic import pm_clean_up_group_sink
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
# BITCAST?
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
ctx[x.arg[0]] = len(ctx)
def fix_reduce(ctx, r:UOp):
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
expander2 = PatternMatcher([
(UPat(Ops.SINK, name="sink"), build_range_map),
(UPat(Ops.REDUCE, name="r"), fix_reduce),
(UPat(Ops.RANGE, name="r"),
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
def broadcast_binary(x:UOp):
shapes = [u.shape for u in x.src]
if all_same(shapes): return None
shaped_aligned = _align_left(*shapes)
broadcasted = _broadcast_shape(*shapes)
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
return x.replace(src=tuple(src_reshaped))
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
if b.shape == (): return None
# broadcasting needs to be already unpacked
if not all_same([x.shape for x in b.src]): return None
src = []
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape)
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
# INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])+pm_clean_up_group_sink

View file

@ -274,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True)
class Metadata:

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
elif all_int(arg): return UOp.const(dtypes.weakint, arg)
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
@ -473,7 +473,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):