remove unused reduce rules + improve unparented (#5908)

* remove unused reduce rules [run_process_replay]

* this works

* those tests are meaningless now
This commit is contained in:
George Hotz 2024-08-04 18:18:27 -07:00 committed by GitHub
commit 159ac06b5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 29 additions and 33 deletions

View file

@ -257,18 +257,20 @@ class TestLinearizer(unittest.TestCase):
out = a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
# RANGE -> LOAD -> RANGE -> PHI
assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
def test_three_nested_range(self):
  """Sum over a doubly expanded tensor and check how many RANGE uops are linearized."""
  a = Tensor.randn(2, ).realize()
  out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
  lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
  # indices of every RANGE uop in the linearized program
  ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
  assert len(ranges) == 1 # NOTE: it collapses now
  # The earlier expected layout was RANGE -> RANGE -> LOAD -> RANGE -> PHI, with nothing
  # toposorted between the first two ranges. Those checks indexed ranges[1] and ranges[2],
  # which no longer exist once a single RANGE remains, so they are not asserted here.
def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
@ -289,23 +291,23 @@ class TestLinearizer(unittest.TestCase):
# LOAD -> RANGE -> LOAD -> PHI
assert lin.uops[ranges[0]-2].op is UOps.LOAD
# TODO: this test is brittle
def test_range_outer_op_before_phi_nested_range(self):
  """Add an outer operand before and after a reduction and check how many RANGE uops are linearized."""
  a = Tensor.randn(2, ).realize()
  b = Tensor.randn(1, 1).realize()
  out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
  lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0]
  # indices of every RANGE uop in the linearized program
  ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
  assert len(ranges) == 1 # NOTE: it collapses now
  # The earlier expected layouts were
  #   PTX:   LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI
  #   other: LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
  # Both sets of checks compared ranges[0] against ranges[1]; with a single RANGE left
  # there is no ranges[1], so the per-backend position checks are not asserted here.
def test_range_outer_op_after_phi(self):
a = Tensor.randn(4, 1).realize()

View file

@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@ -132,11 +132,6 @@ def reduce_before_expand(reduce, expand, x):
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
def sum_collapse(phi_input, loop, val1, val2):
  """Split an accumulated add into a PHI over one addend plus the other addend scaled by the loop extent.

  Of val1/val2, the first one that does not have `loop` among its parents is pulled out
  of the PHI and multiplied by (loop.src[1] - loop.src[0]); val1 is tried before val2.
  Returns None when both addends have `loop` in their parents.
  """
  if loop not in val1.parents: outside, inside = val1, val2
  elif loop not in val2.parents: outside, inside = val2, val1
  else: return None
  # the PHI keeps only the addend left inside; the other is added once, pre-multiplied
  phi = UOp(UOps.PHI, phi_input.dtype, (phi_input, inside))
  return phi + outside * (loop.src[1] - loop.src[0]).cast(outside.dtype)
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
if mval.arg >= 0 or loop_start.arg != 0:
@ -175,11 +170,6 @@ constant_folder = PatternMatcher([
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
# sum collapse to mul (with possible GEP)
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# extra arange loop folding because we don't fold adds. TODO: fold adds
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
NOp.var("idx2") + NOp.var("idx3"))
@ -207,12 +197,9 @@ constant_folder = PatternMatcher([
# const rules
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, name="acc"), NOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, src=(NOp.cvar(),)), NOp.var("x"))), lambda x: x),
(NOp(UOps.PHI, src=(NOp.cvar(), NOp.var("x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
(NOp(UOps.DEFINE_ACC, src=(NOp.cvar(),), name="root"), lambda root: UOp.cast(root.src[0], root.dtype)),
# a REDUCE without ranges is a NOOP
(NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x),
# GEP on a const is the const
(NOp(UOps.GEP, src=(NOp.cvar("x"),), name="root"), lambda root,x: root.const(x.arg)),
# a conditional with the same results either way is a noop, also fold const conditionals
(NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
@ -371,10 +358,17 @@ def do_expand(root:UOp):
acc_number = 0
def do_reduce(root):
  """Rewrite a REDUCE uop into an accumulator PHI and/or a multiply.

  root.src[0] is the value being reduced and root.src[1:] are the reduce sources.
  The reduce sources are split by whether they appear in the parents of root.src[0]:

  - those that do ("parented") become the extra sources of a DEFINE_ACC, and the result
    is a PHI that combines the accumulator with root.src[0] using root.arg
  - those that do not ("unparented") never get an accumulator; for ADD the result is
    multiplied by (r.src[1] - r.src[0]) for each such source, otherwise they are ignored

  Returns the rewritten uop; with no reduce sources this is root.src[0] unchanged.
  Bumps the module-level acc_number each time a DEFINE_ACC is created.
  """
  global acc_number
  # NOTE(review): assumes partition returns (matching, non-matching) in that order -- confirm against tinygrad.helpers
  reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
  ret = root.src[0]
  if len(reduce_parented):
    # initial accumulator value: 0 for ADD, the dtype minimum for everything else
    const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
    acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
    acc_number += 1
    ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
  # for MAX, we can just ignore the unparented
  if root.arg is BinaryOps.ADD:
    # presumably r.src[0]/r.src[1] are the start/end of the range, so this scales by its extent -- TODO confirm
    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
  return ret
def do_contract(con:UOp):
ex = con.src[0]