mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese * move that stuff * copy in allowed length logic * multiple buffers * new coalese is better * fine * earlier * fixes * work * work * valid * stack on index const
This commit is contained in:
parent
dfea9e7994
commit
0a8e61d0c5
7 changed files with 90 additions and 18 deletions
|
|
@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
|
|||
step = THREADS_PER_WG // 2
|
||||
while step:
|
||||
active = tid < step
|
||||
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||
other = lds[(tid + step).valid(active)].load()
|
||||
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
|
||||
step //= 2
|
||||
|
||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
||||
assert num_loads <= 4, "more load uops than needed"
|
||||
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||
assert num_loads >= 1, "expected at least one load uop"
|
||||
|
||||
@unittest.skip("this is handled at higher level now")
|
||||
def test_upcast_cse(self):
|
||||
|
|
|
|||
|
|
@ -16,13 +16,11 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
|
|||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
||||
UOp.const(dtypes.float, 0.0)
|
||||
))
|
||||
|
||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
||||
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||
))
|
||||
|
||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
|
|||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
from tinygrad.codegen.late.coalese import memory_coalesing
|
||||
|
||||
pm_index_is_shrink = PatternMatcher([
|
||||
# rewrite non-image INDEX to SHRINK
|
||||
|
|
@ -52,6 +53,10 @@ pm_number_params = PatternMatcher([
|
|||
(UPat(Ops.PARAM, name="x"), do_number_param),
|
||||
])
|
||||
|
||||
def _weakint_to_int(x:UOp) -> UOp:
  """Return `x` with its dtype replaced by `dtypes.int`."""
  return x.replace(dtype=dtypes.int)

# matches any op whose dtype is weakint and rewrites it to the int dtype
pm_no_weakints = PatternMatcher([(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), _weakint_to_int)])
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
|
|
@ -119,6 +124,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||
|
||||
# do memory coalesing (late)
|
||||
sink = memory_coalesing(sink, ren)
|
||||
|
||||
# instruction selection decompositions
|
||||
pm_decomp = pm_decomp+\
|
||||
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
|
||||
|
|
@ -135,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
|
||||
# final rules for the renderer (without sym)
|
||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
|
||||
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||
|
||||
# this was the linearizer
|
||||
|
|
|
|||
73
tinygrad/codegen/late/coalese.py
Normal file
73
tinygrad/codegen/late/coalese.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from typing import Any
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def _allowed_fold_lengths(buf:UOp, ctx:Renderer) -> tuple[list[int], bool]:
  """Pick the group sizes that accesses to `buf` may be folded into.

  Returns `(lengths, must_divide)`. `lengths` is in the order the sizes are tried and always ends
  with 1 (the access is left unfolded). `must_divide` tells the caller whether the starting offset
  of a group has to pass `offset.divides(length)` for that length to be usable.
  """
  # DSP is the only case that drops the divisibility requirement
  if ctx is not None and ctx.target.device == "DSP": return [128,64,32,16,8,4,1], False
  foldable_dtype = buf.dtype.base in (dtypes.float, dtypes.half, *dtypes.fp8s) or isinstance(buf.dtype, ImageDType)
  if not foldable_dtype or buf.addrspace == AddrSpace.REG: return [1], True
  if isinstance(buf.dtype, ImageDType): return [4,1], True
  if ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8"): return [8,4,2,1], True
    return [4,2,1], True
  return [1], True

def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
  """Fold LOADs/STOREs that hit consecutive constant offsets of the same buffer into wider accesses.

  Every LOAD/STORE in `sink` whose buffer is neither an image nor in the REG address space is keyed
  by (op, buffer, non-constant part of the index, validity gate). Within one key, runs of consecutive
  constant offsets are cut into groups of an allowed length; each group becomes a single load/store
  on a SHRINK of the buffer (or a plain index when the group has one element), and the original uops
  are substituted by it. Loads are replaced by an INDEX into the grouped load when the group is wider
  than one element.

  Args:
    sink: root of the graph to rewrite.
    ctx: renderer; `ctx.target.device` and `ctx.supports_float4` select the allowed group lengths.

  Returns:
    `sink` unchanged when `getenv("DMC")` is truthy, otherwise the result of `sink.substitute(...)`.
  """
  if getenv("DMC"): return sink

  # phase 1: bucket the accesses -- (op, buffer, index base, gate) -> constant offset -> uops at that offset
  accesses: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
  for u in sink.toposort():
    if u.op not in {Ops.LOAD, Ops.STORE}: continue
    # TODO: this should handle images too, it's just memory coalesing
    if isinstance(u.src[0].src[0].dtype, ImageDType): continue
    assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
    assert u.src[0].op is Ops.INDEX
    buf, idx_u = u.src[0].src
    if buf.addrspace == AddrSpace.REG: continue

    # an index of the form gate.where(index, Invalid) is split into its gate and the index it guards
    is_gated = idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid
    index: Any = idx_u.src[1] if is_gated else idx_u
    gate: Any = idx_u.src[0] if is_gated else None

    # separate the index into a base and a constant offset
    if index.op is Ops.ADD and index.src[1].op is Ops.CONST: root, const_off = index.src[0], index.src[1].arg
    elif index.op is Ops.ADD and index.src[0].op is Ops.CONST: root, const_off = index.src[1], index.src[0].arg
    elif index.op is Ops.CONST and index.arg is Invalid: root, const_off = "INVALID", 0
    elif index.op is Ops.CONST: root, const_off = "CONST", index.arg
    else: root, const_off = index, 0
    accesses[(u.op, buf, root, gate)].setdefault(const_off, []).append(u)

  # phase 2: turn each bucket into grouped accesses and record what replaces what
  rewrites = {}
  for (op, buf, base, gate), by_offset in accesses.items():
    widths, must_divide = _allowed_fold_lengths(buf, ctx)

    # value minus position is constant along a run of consecutive offsets, so groupby splits on gaps
    runs = [[off for _, off in run] for _, run in itertools.groupby(enumerate(sorted(by_offset)), lambda pair: pair[1]-pair[0])]
    for remaining in runs:
      while remaining:
        start = (base+remaining[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, remaining[0])
        # first allowed width that fits in what is left of the run (1 is always in the list)
        width = [w for w in widths if w <= len(remaining) and (not must_divide or start.divides(w) is not None)][0]
        chunk, remaining = remaining[:width], remaining[width:]
        wide = len(chunk) > 1
        ptr = buf._mop(Ops.SHRINK, arg=[(start, len(chunk))]) if wide else buf.index(start)

        if op == Ops.STORE:
          for off in chunk: assert len(by_offset[off]) == 1, f"attempting multiple stores: {len(by_offset[off])}"
          values = [by_offset[off][0].src[1] for off in chunk]
          payload = UOp.vectorize(*values) if len(values) > 1 else values[0]
          merged = ptr.store(payload) if gate is None else ptr.store(payload, gate)
          # every original store in the group maps to the one grouped store
          for off in chunk: rewrites[by_offset[off][0]] = merged
        else:
          merged = ptr.load() if gate is None else ptr.load(ptr.vconst_like(0), gate)
          # each original load maps to its position inside the grouped load
          for lane, off in enumerate(chunk):
            for old in by_offset[off]:
              rewrites[old] = merged.index(UOp.const(dtypes.int, lane)) if wide else merged

  # apply
  return sink.substitute(rewrites, name="memory coalesing")
|
||||
|
|
@ -162,18 +162,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
# determine fold lengths
|
||||
lengths = []
|
||||
must_divide = True
|
||||
if ctx is not None and ctx.target.device == "DSP":
|
||||
lengths = [128,64,32,16,8,4]
|
||||
must_divide = False
|
||||
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
elif buf.addrspace == AddrSpace.REG:
|
||||
pass
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
lengths = [4]
|
||||
elif ctx is not None and ctx.supports_float4:
|
||||
# TODO: a better way to get this than ctx
|
||||
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
||||
# TODO: this belongs in coalese
|
||||
if isinstance(buf.dtype, ImageDType): lengths = [4]
|
||||
lengths.append(1) # worst case, it's not folded
|
||||
|
||||
# filter fold lengths that don't divide
|
||||
|
|
|
|||
|
|
@ -170,6 +170,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||
# STACK on INDEX CONST (TODO: remove all the GEP crap)
|
||||
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
|
||||
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
|
||||
])
|
||||
|
||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue