Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
7a214c4499
Merge branch 'master' into clean_load 2026-06-19 16:56:57 -07:00
George Hotz
3e16109eb6 okay w/e 2026-06-18 21:00:36 -07:00
George Hotz
f79a7fc7c6
Merge branch 'master' into clean_load 2026-06-18 20:54:45 -07:00
George Hotz
3526f8272b a few fixups 2026-06-18 20:53:30 -07:00
George Hotz
e143904deb cleanup loads 2026-06-18 18:24:59 -07:00
3 changed files with 12 additions and 18 deletions

View file

@ -16,7 +16,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images ReduceContext, correct_load_store, pm_render, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
@ -52,6 +52,14 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param), (UPat(Ops.PARAM, name="x"), do_number_param),
]) ])
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_load_to_alu = PatternMatcher([
# NOTE: the PtrDType thing is temporary
(UPat(GroupOp.Elementwise|{Ops.STACK,Ops.GEP}, name="x"), lambda x:
x.replace(src=tuple([maybe_load(u) for u in x.src])) if not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
@ -96,7 +104,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code **** # **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids # add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)") sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
# create image buffers # create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}: if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:

View file

@ -289,6 +289,7 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None), (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x), (UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.PTRCAT, src=(UPat(name='x'),)), lambda x: x),
]) ])
# *** Ops.REDUCE -> Ops.DEFINE_ACC *** # *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@ -356,21 +357,6 @@ pm_reduce = PatternMatcher([
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
]) ])
# add loads
def add_load(idx:UOp):
if isinstance(idx.dtype, PtrDType): return None
assert isinstance(idx.src[0].dtype, PtrDType), f"param is not PtrDType {idx.src[0].dtype}"
return idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)
pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), add_load),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
])
# make images # make images
pm_imageh_store = PatternMatcher([ pm_imageh_store = PatternMatcher([

View file

@ -797,7 +797,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}: if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
return self.src[0].addrspace return self.src[0].addrspace
if self.op in GroupOp.Movement: return self.src[0].addrspace if self.op in GroupOp.Movement: return self.src[0].addrspace
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise: if self.op in {Ops.STACK, Ops.PTRCAT, Ops.WMMA} or self.op in GroupOp.Elementwise:
ad = [x.addrspace for x in self.src if x.addrspace is not None] ad = [x.addrspace for x in self.src if x.addrspace is not None]
if not len(ad) or not all_same(ad): return None if not len(ad) or not all_same(ad): return None
return ad[0] return ad[0]