mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
loads are grouped
This commit is contained in:
parent
303b6ba14c
commit
cccd9c2c03
3 changed files with 47 additions and 7 deletions
|
|
@ -58,7 +58,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
if SPEC: type_verify(ast, spec_tensor)
|
||||
|
||||
# preprocess
|
||||
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
||||
sink = graph_rewrite(ast, pm_remove_vec_dtypes+pm_mops+pm_syntactic_sugar+pm_store_ranges,
|
||||
ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
||||
|
||||
# first we optimize
|
||||
if optimize:
|
||||
|
|
@ -109,7 +110,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
||||
# ctx=ren, name="devectorize")
|
||||
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorize2")
|
||||
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from typing import Any
|
||||
import itertools, functools
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.helpers import all_same, flatten
|
||||
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
|
||||
from tinygrad.helpers import all_same, flatten, getenv
|
||||
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
|
||||
from tinygrad.codegen.late.devectorizer import ReduceContext
|
||||
from tinygrad.uop.symbolic import pm_clean_up_group_sink
|
||||
from tinygrad.renderer import Renderer
|
||||
from collections import defaultdict
|
||||
|
||||
def maybe_load(u:UOp):
  """Return `u.load()` when `u.addrspace` is GLOBAL, LOCAL or REG; otherwise return `u` untouched."""
  # anything outside these three address spaces is passed through as-is
  if u.addrspace not in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG): return u
  return u.load()
|
||||
pm_move_regs = PatternMatcher([
|
||||
|
|
@ -56,13 +59,44 @@ def do_devectorize(b:UOp):
|
|||
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||
if b.op is Ops.STORE: return UOp.group(*src)
|
||||
return UOp.vectorize(*src).reshape(b.shape)
|
||||
|
||||
def new_split_load_store(ls:UOp, midx:UOp):
  """Replace a load over stacked INDEXes with one load per run of consecutive constant offsets.

  Every source of `midx` is classified by the result of `src[1].get_idx()`: an ADD with a CONST operand
  is split into (other operand, constant), a CONST becomes a fixed "CONST"/"INVALID" root, and anything
  else is its own root with offset 0. Lanes are bucketed by `(src[1].get_valid(), root)`, and inside a
  bucket the sorted offsets are cut into runs of consecutive values. Each run is loaded once and every
  lane is rebuilt as an index into the load of its run.

  `ls` is the matched load; it is not read in this function.
  Returns `UOp.vectorize` over the per-lane results, in the original lane order.
  """
  # (valid, root) -> constant offset -> lanes of midx.src that use that offset
  lanes_by_root: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for lane, elem in enumerate(midx.src):
    pos: Any = elem.src[1].get_idx()
    if pos.op is Ops.ADD:
      # peel a constant operand off the add; the right-hand side is checked first
      if pos.src[1].op is Ops.CONST: base, off = pos.src[0], pos.src[1].arg
      elif pos.src[0].op is Ops.CONST: base, off = pos.src[1], pos.src[0].arg
      else: base, off = pos, 0
    elif pos.op is Ops.CONST:
      if pos.arg is Invalid: base, off = "INVALID", 0
      else: base, off = "CONST", pos.arg
    else:
      base, off = pos, 0
    bucket = lanes_by_root[(elem.src[1].get_valid(), base)]
    bucket.setdefault(off, []).append(lane)

  idxs: list[UOp|None] = [None]*len(midx.src)
  for (valid, _), by_offset in lanes_by_root.items():
    ordered = sorted(by_offset.keys())
    # value minus position is constant along a run of consecutive offsets, so groupby splits on gaps
    runs = [[o for _, o in run] for _, run in itertools.groupby(enumerate(ordered), lambda p: p[1]-p[0])]
    for run in runs:
      # all lanes sharing the run's first offset are interchangeable here, so take the first one
      src_idx = midx.src[by_offset[run[0]][0]]
      if len(run) > 1:
        # a multi-element run becomes a SHRINK on src[0] with (src[1], run length) as its arg
        src_idx = src_idx.src[0]._mop(Ops.SHRINK, arg=[(src_idx.src[1], len(run))])
      # comparison kept exactly as written: `!=` may be overloaded on the valid's type
      if valid != True:
        loaded = src_idx.load(src_idx.vconst_like(0), valid)
      else:
        loaded = src_idx.load()
      # hand every lane the element of this load that matches its position in the run
      for slot, o in enumerate(run):
        for lane in by_offset[o]:
          idxs[lane] = loaded.index(UOp.const(dtypes.int, slot))
  assert None not in idxs, f"some idxs are missing {idxs}"
  return UOp.vectorize(*idxs)
|
||||
|
||||
#from tinygrad.codegen.late.devectorizer import fold_expanded_index
|
||||
devectorizer2 = pm_mops+PatternMatcher([
|
||||
# TODO: support STORE
|
||||
(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
|
||||
# unpack broadcasting
|
||||
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
|
||||
# INDEX into STACK is src
|
||||
(UPat(GroupOp.Elementwise|{Ops.STORE}, name="b"), do_devectorize),
|
||||
# const INDEX into STACK is src
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||
# stacked INDEX is many INDEX
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
|
||||
|
|
@ -76,6 +110,9 @@ devectorizer2 = pm_mops+PatternMatcher([
|
|||
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
|
||||
# INDEX without src is nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
|
||||
# RESHAPE+EXPAND -> STACK
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
|
||||
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
|
||||
])
|
||||
|
||||
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||
|
|
|
|||
|
|
@ -1669,6 +1669,8 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("len", dtypes.ints).cast(),), name="shrink"),
|
||||
lambda shrink,buf,idx,len: shrink.replace(src=(buf,idx,len))),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
||||
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
|
||||
# remove hanging casts for images
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue