copy in allowed length logic

This commit is contained in:
George Hotz 2026-06-23 16:15:02 -07:00
commit 919a253ea1
2 changed files with 22 additions and 7 deletions

View file

@ -121,7 +121,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
# do memory coalesing (late)
sink = memory_coalesing(sink)
sink = memory_coalesing(sink, ren)
# instruction selection decompositions
pm_decomp = pm_decomp+\

View file

@ -1,15 +1,16 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp):
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[UOp, UOp, UOp|str, UOp|None], dict[int, list[UOp]]] = defaultdict(dict)
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
assert u.src[0].op is Ops.INDEX
@ -24,7 +25,21 @@ def memory_coalesing(sink:UOp):
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# allowed lengths
lengths = [8,4,2,1]
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# build replacements
replacements = {}
@ -33,10 +48,10 @@ def memory_coalesing(sink:UOp):
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op is Ops.STORE:
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1