mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
copy in allowed length logic
This commit is contained in:
parent
087ae6436a
commit
919a253ea1
2 changed files with 22 additions and 7 deletions
|
|
@ -121,7 +121,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||
|
||||
# do memory coalescing (late)
|
||||
sink = memory_coalesing(sink)
|
||||
sink = memory_coalesing(sink, ren)
|
||||
|
||||
# instruction selection decompositions
|
||||
pm_decomp = pm_decomp+\
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
from typing import Any
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, AddrSpace, Invalid
|
||||
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def memory_coalesing(sink:UOp):
|
||||
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
|
||||
if getenv("DMC"): return sink
|
||||
|
||||
# collect
|
||||
memory: defaultdict[tuple[UOp, UOp, UOp|str, UOp|None], dict[int, list[UOp]]] = defaultdict(dict)
|
||||
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
|
||||
for u in sink.toposort():
|
||||
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
|
||||
assert u.src[0].op is Ops.INDEX
|
||||
|
|
@ -24,7 +25,21 @@ def memory_coalesing(sink:UOp):
|
|||
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
||||
|
||||
# allowed lengths
|
||||
lengths = [8,4,2,1]
|
||||
lengths = []
|
||||
must_divide = True
|
||||
if ctx is not None and ctx.target.device == "DSP":
|
||||
lengths = [128,64,32,16,8,4]
|
||||
must_divide = False
|
||||
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
elif buf.addrspace == AddrSpace.REG:
|
||||
pass
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
lengths = [4]
|
||||
elif ctx is not None and ctx.supports_float4:
|
||||
# TODO: find a better way to get this than through ctx
|
||||
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
||||
lengths.append(1) # worst case, it's not folded
|
||||
|
||||
# build replacements
|
||||
replacements = {}
|
||||
|
|
@ -33,10 +48,10 @@ def memory_coalesing(sink:UOp):
|
|||
for full_grp in grouped_offsets:
|
||||
while len(full_grp):
|
||||
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
|
||||
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
|
||||
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
|
||||
grp = full_grp[:length]
|
||||
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
||||
if op is Ops.STORE:
|
||||
if op == Ops.STORE:
|
||||
datas = []
|
||||
for i,g in enumerate(grp):
|
||||
assert len(offsets[g]) == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue