mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
clean_load
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a214c4499 |
||
|
|
3e16109eb6 | ||
|
|
f79a7fc7c6 |
||
|
|
3526f8272b | ||
|
|
e143904deb |
3 changed files with 12 additions and 18 deletions
|
|
@ -16,7 +16,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
|
|||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
ReduceContext, correct_load_store, pm_render, pm_make_images
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
|
|
@ -52,6 +52,14 @@ pm_number_params = PatternMatcher([
|
|||
(UPat(Ops.PARAM, name="x"), do_number_param),
|
||||
])
|
||||
|
||||
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||
pm_load_to_alu = PatternMatcher([
|
||||
# NOTE: the PtrDType thing is temporary
|
||||
(UPat(GroupOp.Elementwise|{Ops.STACK,Ops.GEP}, name="x"), lambda x:
|
||||
x.replace(src=tuple([maybe_load(u) for u in x.src])) if not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||
])
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
|
|
@ -96,7 +104,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# **** optimizations are done, now we lower to actual code ****
|
||||
|
||||
# add loads and remove invalids
|
||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||
sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
|
||||
|
||||
# create image buffers
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||
|
|
|
|||
|
|
@ -289,6 +289,7 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
(UPat(Ops.PTRCAT, src=(UPat(name='x'),)), lambda x: x),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
@ -356,21 +357,6 @@ pm_reduce = PatternMatcher([
|
|||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
])
|
||||
|
||||
# add loads
|
||||
|
||||
def add_load(idx:UOp):
|
||||
if isinstance(idx.dtype, PtrDType): return None
|
||||
assert isinstance(idx.src[0].dtype, PtrDType), f"param is not PtrDType {idx.src[0].dtype}"
|
||||
return idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)
|
||||
|
||||
pm_add_loads = PatternMatcher([
|
||||
# add loads to non ptr index
|
||||
(UPat(Ops.INDEX, name="idx"), add_load),
|
||||
# remove loads from stores
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
|
||||
])
|
||||
|
||||
# make images
|
||||
|
||||
pm_imageh_store = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -797,7 +797,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
return self.src[0].addrspace
|
||||
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
||||
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise:
|
||||
if self.op in {Ops.STACK, Ops.PTRCAT, Ops.WMMA} or self.op in GroupOp.Elementwise:
|
||||
ad = [x.addrspace for x in self.src if x.addrspace is not None]
|
||||
if not len(ad) or not all_same(ad): return None
|
||||
return ad[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue