cleanup loads

This commit is contained in:
George Hotz 2026-06-18 18:24:59 -07:00
commit e143904deb
2 changed files with 10 additions and 17 deletions

View file

@ -17,7 +17,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
ReduceContext, correct_load_store, pm_render, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
@ -56,6 +56,14 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param),
])
def maybe_load(u:UOp):
  """Return `u.load()` when `u.addrspace` is GLOBAL, LOCAL or REG; otherwise return `u` unchanged."""
  # anything outside these three address spaces is passed through untouched
  if u.addrspace not in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG): return u
  return u.load()
# rewrite rules that route op sources through maybe_load, so sources in the address spaces
# maybe_load checks for are replaced by their .load() and all other sources are left alone
pm_load_to_alu = PatternMatcher([
# NOTE: the PtrDType thing is temporary
# elementwise ops and STACK: every source goes through maybe_load. if the op's own dtype is a PtrDType
# the lambda returns None instead (presumably "no rewrite" for the matcher -- confirm against PatternMatcher)
(UPat(GroupOp.Elementwise|{Ops.STACK}, name="x"), lambda x:
x.replace(src=tuple([maybe_load(u) for u in x.src])) if not isinstance(x.dtype, PtrDType) else None),
# STORE: only src[1] goes through maybe_load; src[0] and any sources after src[1] are kept as they are
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
@ -100,7 +108,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:

View file

@ -354,21 +354,6 @@ pm_reduce = PatternMatcher([
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
# add loads
def add_load(idx:UOp):
  """
  Wrap a non-pointer `idx` in a load.

  Returns None when `idx.dtype` is already a PtrDType. Otherwise `idx` is re-typed to the dtype of its
  first source (which must be a PtrDType) and a load with dtype `idx.dtype.base` is returned on top of it.
  """
  # pointer-typed idx: nothing to add
  if isinstance(idx.dtype, PtrDType): return None
  buf_dtype = idx.src[0].dtype
  assert isinstance(buf_dtype, PtrDType), f"param is not PtrDType {idx.src[0].dtype}"
  # the re-typed idx carries the pointer dtype, the load on top of it carries the value dtype
  ptr_idx = idx.replace(dtype=buf_dtype)
  return ptr_idx.load(dtype=idx.dtype.base)
# rewrite rules that run add_load on INDEX ops and then drop a LOAD that sits as the first source
# of a STORE or of another LOAD
pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), add_load),
# remove loads from stores
# STORE whose first source is a LOAD: that source is replaced by the LOAD's own first source, the rest are kept
# (allow_any_len=True presumably lets the pattern match with more sources than the one listed -- confirm)
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
# LOAD whose first source is a LOAD: same replacement, so a LOAD never takes another LOAD as its first source
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
])
# make images
pm_imageh_store = PatternMatcher([