mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cleanup loads
This commit is contained in:
parent
05249466ed
commit
e143904deb
2 changed files with 10 additions and 17 deletions
|
|
@ -17,7 +17,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, p
|
|||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
ReduceContext, correct_load_store, pm_render, pm_make_images
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
|
|
@ -56,6 +56,14 @@ pm_number_params = PatternMatcher([
|
|||
(UPat(Ops.PARAM, name="x"), do_number_param),
|
||||
])
|
||||
|
||||
def maybe_load(u:UOp):
  # wrap `u` in a load only when its addrspace is GLOBAL, LOCAL or REG; anything else is returned untouched
  if u.addrspace not in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG): return u
  return u.load()
|
||||
pm_load_to_alu = PatternMatcher([
  # every source of an elementwise/STACK op goes through maybe_load; ops whose own dtype is a PtrDType are skipped
  # NOTE: the PtrDType thing is temporary
  (UPat(GroupOp.Elementwise|{Ops.STACK}, name="x"), lambda x:
    None if isinstance(x.dtype, PtrDType) else x.replace(src=tuple(maybe_load(s) for s in x.src))),
  # for a STORE only src[1] goes through maybe_load; src[0] and any trailing sources are kept as they are
  (UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]), *x.src[2:]))),
])
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
|
|
@ -100,7 +108,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# **** optimizations are done, now we lower to actual code ****
|
||||
|
||||
# add loads and remove invalids
|
||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||
sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
|
||||
|
||||
# create image buffers
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||
|
|
|
|||
|
|
@ -354,21 +354,6 @@ pm_reduce = PatternMatcher([
|
|||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
])
|
||||
|
||||
# add loads
|
||||
|
||||
def add_load(idx:UOp):
  # an INDEX whose dtype is already a PtrDType is not rewritten (None means "no match")
  if isinstance(idx.dtype, PtrDType): return None
  buf_dtype = idx.src[0].dtype
  assert isinstance(buf_dtype, PtrDType), f"param is not PtrDType {buf_dtype}"
  # give the index the dtype of its first source, then load from it using the base of the original dtype
  ptr_idx = idx.replace(dtype=buf_dtype)
  return ptr_idx.load(dtype=idx.dtype.base)
|
||||
|
||||
pm_add_loads = PatternMatcher([
  # add loads to non ptr index
  (UPat(Ops.INDEX, name="idx"), add_load),
  # remove loads from stores: a STORE whose first source is a LOAD takes that LOAD's first source instead
  (UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0], *s.src[1:]))),
  # same unwrapping for a LOAD whose first source is itself a LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0], *l.src[1:]))),
])
|
||||
|
||||
# make images
|
||||
|
||||
pm_imageh_store = PatternMatcher([
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue