mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
lengths
This commit is contained in:
parent
5238f304c7
commit
4e398e3f1f
2 changed files with 29 additions and 19 deletions
|
|
@ -114,12 +114,15 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# ctx=ren, name="devectorize")
|
||||
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
|
||||
sink = graph_rewrite(sink, symbolic, name="pre memory coalese")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
# memory coalescing
|
||||
sink = memory_coalesing(sink)
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
# again
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def memory_coalesing(sink:UOp):
|
|||
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
|
||||
assert u.src[0].op is Ops.INDEX
|
||||
buf,idx_u = u.src[0].src
|
||||
idx: Any = idx_u.get_idx()
|
||||
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
|
||||
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
|
|
@ -128,26 +128,33 @@ def memory_coalesing(sink:UOp):
|
|||
else: root_src, arg = idx, 0
|
||||
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
||||
|
||||
# allowed lengths
|
||||
lengths = [8,4,2,1]
|
||||
|
||||
# build replacements
|
||||
replacements = {}
|
||||
for (op,buf,base,valid),offsets in memory.items():
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
offset = (base+grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, grp[0])
|
||||
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
||||
if op is Ops.STORE:
|
||||
datas = []
|
||||
for i,g in enumerate(grp):
|
||||
assert len(offsets[g]) == 1
|
||||
datas.append(offsets[g][0].src[1])
|
||||
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
||||
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
||||
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
||||
else:
|
||||
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
||||
for i,g in enumerate(grp):
|
||||
for oo in offsets[g]:
|
||||
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
||||
for full_grp in grouped_offsets:
|
||||
while len(full_grp):
|
||||
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
|
||||
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
|
||||
grp = full_grp[:length]
|
||||
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
||||
if op is Ops.STORE:
|
||||
datas = []
|
||||
for i,g in enumerate(grp):
|
||||
assert len(offsets[g]) == 1
|
||||
datas.append(offsets[g][0].src[1])
|
||||
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
||||
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
||||
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
||||
else:
|
||||
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
||||
for i,g in enumerate(grp):
|
||||
for oo in offsets[g]:
|
||||
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
||||
full_grp = full_grp[length:]
|
||||
|
||||
# apply
|
||||
return sink.substitute(replacements, name="memory coalesing")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue