Compare commits

...

20 commits

Author SHA1 Message Date
George Hotz
30cc426916 order 2026-06-23 12:39:06 -07:00
George Hotz
c7fd74d523 order 2026-06-23 12:30:19 -07:00
George Hotz
4e398e3f1f lengths 2026-06-23 09:18:51 -07:00
George Hotz
5238f304c7 coalese 2026-06-23 09:04:23 -07:00
George Hotz
87e12fad81 cleanup 2026-06-23 08:43:00 -07:00
George Hotz
5eaf02c719 memory coalesing 2026-06-23 08:30:23 -07:00
George Hotz
4424af7bd8 revert that 2026-06-22 18:36:50 -07:00
George Hotz
c01d75a651 fixes 2026-06-22 18:18:19 -07:00
George Hotz
63ec8ad21d
Merge branch 'master' into codegen_try_2 2026-06-22 17:44:37 -07:00
George Hotz
cccd9c2c03 loads are grouped 2026-06-22 16:12:18 -07:00
George Hotz
303b6ba14c
Merge branch 'master' into codegen_try_2 2026-06-22 13:02:37 -07:00
George Hotz
2bf3c48c1b reduce hack remove 2026-06-22 13:01:36 -07:00
George Hotz
ba75b68c12 simple 2026-06-22 11:36:13 -07:00
George Hotz
958088cc13 Merge remote-tracking branch 'origin/master' into codegen_try_2 2026-06-22 11:25:46 -07:00
George Hotz
257ed03f57 fix store 2026-06-22 08:46:56 -07:00
George Hotz
15476572cd
Merge branch 'master' into codegen_try_2 2026-06-22 08:34:49 -07:00
George Hotz
f8949a0de1 fix exec_alu 2026-06-20 17:01:28 -07:00
George Hotz
62c6c75657 test tiny almost passes 2026-06-20 16:46:01 -07:00
George Hotz
530aed739d devec 2026-06-20 08:34:20 -07:00
George Hotz
da402953d9 new codegen, try 2 2026-06-20 08:30:38 -07:00
5 changed files with 194 additions and 20 deletions

View file

@ -12,11 +12,9 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.late.devectorizer import load_store_indexing, ReduceContext, pm_render, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
@ -24,6 +22,8 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local, pm_horizontal_reduce, memory_coalesing
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
@ -81,14 +81,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
# ** devectorizer (full_graph_rewrite) **
# remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
@ -96,15 +98,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
#sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
# hreduce
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
# devectorize
sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
ctx=ren, name="devectorize")
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
# ctx=ren, name="devectorize")
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
@ -113,12 +121,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# dtypes
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
# memory coalesing
sink = memory_coalesing(sink)
# again
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# GEP/STACK stuff

View file

@ -0,0 +1,161 @@
from typing import Any
import itertools, functools
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve, graph_rewrite
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
from tinygrad.helpers import all_same, flatten, getenv
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
from tinygrad.codegen.late.devectorizer import ReduceContext
from tinygrad.uop.symbolic import pm_clean_up_group_sink
from collections import defaultdict
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
# BITCAST?
(UPat(GroupOp.Elementwise|{Ops.REDUCE}, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
ctx[x.arg[0]] = len(ctx)
def fix_reduce(ctx, r:UOp):
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
expander2 = PatternMatcher([
(UPat(Ops.SINK, name="sink"), build_range_map),
(UPat(Ops.REDUCE, name="r"), fix_reduce),
(UPat(Ops.RANGE, name="r"),
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
def broadcast_binary(x:UOp):
shapes = [u.shape for u in x.src]
if all_same(shapes): return None
shaped_aligned = _align_left(*shapes)
broadcasted = _broadcast_shape(*shapes)
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
return x.replace(src=tuple(src_reshaped))
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
if b.shape == (): return None
# broadcasting needs to be already unpacked
if not all_same([x.shape for x in b.src]): return None
src = []
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
# const INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
# RESHAPE+EXPAND -> STACK
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
])+pm_clean_up_group_sink
pm_horizontal_reduce = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])
# *** memory coalesing ***
def memory_coalesing(sink:UOp):
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[UOp, UOp, UOp], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
assert u.src[0].op is Ops.INDEX
buf,idx_u = u.src[0].src
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# allowed lengths
lengths = [8,4,2,1]
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op is Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -275,7 +275,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True)
class Metadata:

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
elif len(arg) == 1: return UOp.const(dtypes.weakint, arg[0])
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
@ -305,10 +305,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce
return (self.dtype.count,) if self.dtype.count > 1 else ()
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
@ -473,7 +469,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):

View file

@ -192,8 +192,8 @@ spec_program = PatternMatcher([
# no more of these in programs
(UPat(Ops.GEP), lambda: False),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
# weakint is not allowed in programs, except on CONST and STACK
(UPat(GroupOp.All-{Ops.CONST,Ops.STACK}, dtypes.weakint), lambda: False),
# allow special SHRINK
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),