This commit is contained in:
George Hotz 2026-06-23 09:18:51 -07:00
commit 4e398e3f1f
2 changed files with 29 additions and 19 deletions

View file

@ -114,12 +114,15 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# ctx=ren, name="devectorize")
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
sink = graph_rewrite(sink, symbolic, name="pre memory coalese")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# memory coalescing
sink = memory_coalesing(sink)
# lower the index dtype to a concrete int again, after memory coalescing
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")

View file

@ -119,7 +119,7 @@ def memory_coalesing(sink:UOp):
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
assert u.src[0].op is Ops.INDEX
buf,idx_u = u.src[0].src
idx: Any = idx_u.get_idx()
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
@ -128,26 +128,33 @@ def memory_coalesing(sink:UOp):
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# allowed lengths for a coalesced access, in order of preference (the first one that fits is used)
lengths = [8,4,2,1]
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
offset = (base+grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, grp[0])
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op is Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op is Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")