loads are grouped

This commit is contained in:
George Hotz 2026-06-22 16:12:18 -07:00
commit cccd9c2c03
3 changed files with 47 additions and 7 deletions

View file

@ -58,7 +58,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if SPEC: type_verify(ast, spec_tensor)
# preprocess
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
sink = graph_rewrite(ast, pm_remove_vec_dtypes+pm_mops+pm_syntactic_sugar+pm_store_ranges,
ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
# first we optimize
if optimize:
@ -109,7 +110,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
# ctx=ren, name="devectorize")
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorize2")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")

View file

@ -1,12 +1,15 @@
from typing import Any
import itertools, functools
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import all_same, flatten
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
from tinygrad.helpers import all_same, flatten, getenv
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
from tinygrad.codegen.late.devectorizer import ReduceContext
from tinygrad.uop.symbolic import pm_clean_up_group_sink
from tinygrad.renderer import Renderer
from collections import defaultdict
def maybe_load(u:UOp):
  """Return `u.load()` when `u.addrspace` is GLOBAL, LOCAL or REG; otherwise return `u` unchanged."""
  # anything outside the three listed address spaces is passed through as-is
  if u.addrspace not in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG): return u
  return u.load()
pm_move_regs = PatternMatcher([
@ -56,13 +59,44 @@ def do_devectorize(b:UOp):
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
if b.op is Ops.STORE: return UOp.group(*src)
return UOp.vectorize(*src).reshape(b.shape)
def new_split_load_store(ls:UOp, midx:UOp):
  """Group the per-lane INDEX ops in `midx.src` into runs of consecutive constant offsets and load each run once.

  Args:
    ls: the matched op; it is not read by this function.
    midx: an op whose `src` entries each expose `.src[0]` and `.src[1]`, where `.src[1]` provides
      `get_idx()` and `get_valid()`.

  Returns:
    `UOp.vectorize(...)` over one entry per element of `midx.src`; each entry is an `.index(...)` with a
    `dtypes.int` constant into the load built for the run that lane belongs to.

  Raises:
    AssertionError: if any lane was left without an entry.
  """
  lanes = midx.src

  # bucket lane numbers first by (valid, root expression), then by constant offset from that root
  buckets: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for lane, index_op in enumerate(lanes):
    idx: Any = index_op.src[1].get_idx()
    if idx.op is Ops.ADD:
      # `root + CONST` or `CONST + root`; an ADD with no CONST operand is its own root at offset 0
      if idx.src[1].op is Ops.CONST: root, const_off = idx.src[0], idx.src[1].arg
      elif idx.src[0].op is Ops.CONST: root, const_off = idx.src[1], idx.src[0].arg
      else: root, const_off = idx, 0
    elif idx.op is Ops.CONST:
      # constant indexes share a string key; an Invalid constant gets a separate key and offset 0
      root, const_off = ("INVALID", 0) if idx.arg is Invalid else ("CONST", idx.arg)
    else:
      root, const_off = idx, 0
    bucket_key = (index_op.src[1].get_valid(), root)
    buckets[bucket_key].setdefault(const_off, []).append(lane)

  idxs: list[UOp|None] = [None]*len(lanes)
  for (valid, _), by_offset in buckets.items():
    # value minus position is constant along a run of consecutive sorted offsets, so it keys the runs
    ordered = sorted(by_offset)
    runs = [[off for _, off in members] for _, members in itertools.groupby(enumerate(ordered), key=lambda pair: pair[1]-pair[0])]
    for run in runs:
      # lanes sharing an offset share the same index, so the first lane of the run's first offset is used
      first_index = lanes[by_offset[run[0]][0]]
      if len(run) > 1: source = first_index.src[0]._mop(Ops.SHRINK, arg=[(first_index.src[1], len(run))])
      else: source = first_index
      # NOTE(review): `!= True` is kept exactly as written; the type `get_valid()` returns is not visible here
      if valid != True: loaded = source.load(source.vconst_like(0), valid)
      else: loaded = source.load()
      # point every lane at its position inside the run's load
      for pos, off in enumerate(run):
        for lane in by_offset[off]:
          idxs[lane] = loaded.index(UOp.const(dtypes.int, pos))

  assert None not in idxs, f"some idxs are missing {idxs}"
  return UOp.vectorize(*idxs)
#from tinygrad.codegen.late.devectorizer import fold_expanded_index
devectorizer2 = pm_mops+PatternMatcher([
# TODO: support STORE
(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
# INDEX into STACK is src
(UPat(GroupOp.Elementwise|{Ops.STORE}, name="b"), do_devectorize),
# const INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
@ -76,6 +110,9 @@ devectorizer2 = pm_mops+PatternMatcher([
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
# RESHAPE+EXPAND -> STACK
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):

View file

@ -1669,6 +1669,8 @@ pm_lower_index_dtype = PatternMatcher([
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("len", dtypes.ints).cast(),), name="shrink"),
lambda shrink,buf,idx,len: shrink.replace(src=(buf,idx,len))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for images