move up migrate + new gated fold (#7403)

* move up migrate + new gated fold [pr]

* vcount for const ptr

* move those rules there

* fix openpilot
This commit is contained in:
George Hotz 2024-10-30 21:14:01 +07:00 committed by GitHub
commit adccfade7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 69 deletions

View file

@ -5,7 +5,7 @@ from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
@ -389,12 +389,10 @@ class TestUOpGraph(unittest.TestCase):
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[-1].src
# ld0 becomes the invalid value
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
@ -404,12 +402,10 @@ class TestUOpGraph(unittest.TestCase):
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[-1].src
# ld0 becomes the invalid value
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.index(lidx+2))
@ -449,7 +445,8 @@ class TestUOpGraph(unittest.TestCase):
def expander_rewrite(sink):
sink = graph_rewrite(sink, sym + expander)
return graph_rewrite(sink, sym + reducer)
sink = graph_rewrite(sink, sym + reducer)
return graph_rewrite(sink, sym + finalize)
def float4_rewrite(sink): return graph_rewrite(sink, sym + expander + float4_folding)
class TestExpander(unittest.TestCase):
@ -660,8 +657,6 @@ class TestLoadStoreFolder(unittest.TestCase):
print(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer)
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
@ -675,12 +670,12 @@ class TestIFUOps(unittest.TestCase):
lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
sink = UOp(UOps.SINK, dtypes.void, (store,))
sink = gate_rewrite(sink)
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
self.assertEqual(len(st.src), 2)
def test_expand_ifs_one_gate(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
@ -693,12 +688,12 @@ class TestIFUOps(unittest.TestCase):
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
sink = gate_rewrite(sink)
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
self.assertEqual(len(st.src), 2)
# this will be fixed with the merge gated stores bounty
@unittest.expectedFailure
@ -709,12 +704,12 @@ class TestIFUOps(unittest.TestCase):
gate = valid&(lidx.ne(2))
stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
sink = gate_rewrite(sink)
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
self.assertEqual(len(st.src), 2)
if __name__ == '__main__':

View file

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
@ -61,13 +61,13 @@ def fold_expanded(ex, buf):
return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
def fix_unfoldable_image_load(load:UOp, buf:UOp):
if not isinstance(buf.dtype, ImageDType) or load.src[1].dtype.count == 2: return None
id4 = load.src[1] % 4
if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
id4 = oidx % 4
new_src = list(load.src)
# TODO: copied logic from above
new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((load.src[1] // 4) % buf.dtype.shape[1], (load.src[1] // (4 * buf.dtype.shape[1]))))
if len(new_src) >= 4:
new_src[2] = UOp(UOps.VECTORIZE, new_src[2].dtype.vec(4), tuple(new_src[2] for _ in range(4)))
new_src[0] = load.src[0].src[0].index(
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
vec_load = UOp(UOps.LOAD, load.dtype.vec(4), tuple(new_src))
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
@ -78,16 +78,12 @@ float4_folding = PatternMatcher([
# ***** image load valid simplification *****
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
def simplify_image_load(load:UOp) -> Optional[UOp]:
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
@ -105,7 +101,7 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
if is_increasing(i):
rw = graph_rewrite(i.substitute({X:X.const_like(test_value)}), sym)
if rw.vmin >= b or rw.vmax < 0:
@ -114,7 +110,8 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
return buf.index(idx, new_valid)
# ***** optional patterns *****
@ -290,12 +287,10 @@ sym = symbolic_flat+PatternMatcher([
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# fold gated LOAD/STORE
(UPat.load(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), name="ld"), lambda ld: ld.replace(src=ld.src[:2])),
(UPat.load(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), UPat.var("bar"), name="ld"), lambda ld,bar: ld.replace(src=ld.src[:2]+(bar,))),
(UPat.load(UPat(), UPat(), UPat.var("var"), UPat.const(dtypes.bool, False)), lambda var: var),
(UPat.load(UPat(), UPat(), UPat.var("var"), UPat.const(dtypes.bool, False), UPat()), lambda var: var),
(UPat.store(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, True), name="store"), lambda store: store.replace(src=store.src[:3])),
(UPat.store(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
(UPat(UOps.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(UOps.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(UOps.NOOP)), # NULL pointer store does nothing
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
@ -450,14 +445,6 @@ def no_vectorized_acc(acc:UOp):
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def find_gate(x:UOp) -> Optional[UOp]:
if x.op is UOps.IF: return x
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
just_reduce = PatternMatcher([
# do reduce
(UPat(UOps.REDUCE, name="root"), do_reduce),
@ -472,20 +459,12 @@ devectorize = PatternMatcher([
])
reducer = PatternMatcher([
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# delete_redundant_gates (after expand, is this still needed?)
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
# late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification
(UPat(UOps.LOAD, name="load"), simplify_image_load),
# buffer load valid idx simplification
(UPat(UOps.LOAD, name="load"), simplify_buffer_load),
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
])
def idx_load_store(x:UOp):
@ -506,14 +485,25 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
masked_index = UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask")))
move_masks = PatternMatcher([
# NOTE: this shouldn't be here
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def find_gate(x:UOp) -> Optional[UOp]:
if x.op is UOps.IF: return x
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
if len(root.src) == 2 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[2]: return None
return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
finalize = PatternMatcher([
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
# fix up loads/stores
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# move masks of loads/stores
# TODO: this should be an IF instead of a masked STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index, masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
# delete_redundant_gates (after expand)
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
])
# *** uop graph ***
@ -532,13 +522,15 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# convert REDUCE to DEFINE_ACC + ASSIGN
sink = graph_rewrite(sink, sym+just_reduce)
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, sym+reducer)
# temp for indexing migration
sink = graph_rewrite(sink, sym+migrate_indexing)
# cleanups
sink = graph_rewrite(sink, sym+reducer)
# finalize
sink = graph_rewrite(sink, sym+move_masks+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
return sink

View file

@ -561,7 +561,7 @@ class UPat(MathTrait):
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
# copied from UOp
def index(self, idx:UPat): return UPat(UOps.INDEX, self.dtype, (self,idx))
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))