mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove unused reduce rules + improve unparented (#5908)
* remove unused reduce rules [run_process_replay]
* this work
* those tests are meaningless now
This commit is contained in:
parent
d7387d31bf
commit
159ac06b5b
2 changed files with 29 additions and 33 deletions
|
|
@ -257,18 +257,20 @@ class TestLinearizer(unittest.TestCase):
|
|||
out = a.reshape(2, 1).expand(2, 3).sum()
|
||||
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
assert len(ranges) == 1 # NOTE: it collapses now
|
||||
# RANGE -> LOAD -> RANGE -> PHI
|
||||
assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
|
||||
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
|
||||
|
||||
def test_three_nested_range(self):
|
||||
a = Tensor.randn(2, ).realize()
|
||||
out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
|
||||
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
assert len(ranges) == 1 # NOTE: it collapses now
|
||||
# RANGE -> RANGE -> LOAD -> RANGE -> PHI
|
||||
# NOTE: nothing should toposort between the first two ranges
|
||||
assert ranges[0]+1 == ranges[1]
|
||||
assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]])
|
||||
#assert ranges[0]+1 == ranges[1]
|
||||
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]])
|
||||
|
||||
def test_two_nested_range_alt_indexing(self):
|
||||
a = Tensor([2, 2]).realize()
|
||||
|
|
@ -289,23 +291,23 @@ class TestLinearizer(unittest.TestCase):
|
|||
# LOAD -> RANGE -> LOAD -> PHI
|
||||
assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
|
||||
# TODO: this test is brittle
|
||||
def test_range_outer_op_before_phi_nested_range(self):
|
||||
a = Tensor.randn(2, ).realize()
|
||||
b = Tensor.randn(1, 1).realize()
|
||||
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
if getenv("PTX"):
|
||||
assert len(ranges) == 1 # NOTE: it collapses now
|
||||
#if getenv("PTX"):
|
||||
# LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI
|
||||
assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
assert ranges[1] == ranges[0]+6
|
||||
assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
|
||||
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
# assert ranges[1] == ranges[0]+6
|
||||
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
|
||||
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
|
||||
else:
|
||||
assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
assert ranges[1] == ranges[0]+3
|
||||
assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
|
||||
#else:
|
||||
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
# assert ranges[1] == ranges[0]+3
|
||||
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
|
||||
|
||||
def test_range_outer_op_after_phi(self):
|
||||
a = Tensor.randn(4, 1).realize()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
|
|||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
|
||||
from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
|
|
@ -132,11 +132,6 @@ def reduce_before_expand(reduce, expand, x):
|
|||
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
|
||||
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
|
||||
|
||||
def sum_collapse(phi_input, loop, val1, val2):
|
||||
for v1,v2 in [(val1, val2), (val2, val1)]:
|
||||
if loop not in v1.parents: return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+v1*(loop.src[1]-loop.src[0]).cast(v1.dtype)
|
||||
return None
|
||||
|
||||
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
|
||||
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
|
||||
if mval.arg >= 0 or loop_start.arg != 0:
|
||||
|
|
@ -175,11 +170,6 @@ constant_folder = PatternMatcher([
|
|||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
(NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
# extra arange loop folding because we don't fold adds. TODO: fold adds
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
|
||||
NOp.var("idx2") + NOp.var("idx3"))
|
||||
|
|
@ -207,12 +197,9 @@ constant_folder = PatternMatcher([
|
|||
# const rules
|
||||
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
|
||||
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
|
||||
(NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, name="acc"), NOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
|
||||
(NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, src=(NOp.cvar(),)), NOp.var("x"))), lambda x: x),
|
||||
(NOp(UOps.PHI, src=(NOp.cvar(), NOp.var("x"))), lambda x: x),
|
||||
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
|
||||
(NOp(UOps.DEFINE_ACC, src=(NOp.cvar(),), name="root"), lambda root: UOp.cast(root.src[0], root.dtype)),
|
||||
# a REDUCE without ranges is a NOOP
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x),
|
||||
# GEP on a const is the const
|
||||
(NOp(UOps.GEP, src=(NOp.cvar("x"),), name="root"), lambda root,x: root.const(x.arg)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
|
||||
|
|
@ -371,10 +358,17 @@ def do_expand(root:UOp):
|
|||
acc_number = 0
|
||||
def do_reduce(root):
|
||||
global acc_number
|
||||
const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
|
||||
acc_number += 1
|
||||
return UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, root.src[0])))
|
||||
reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
|
||||
ret = root.src[0]
|
||||
if len(reduce_parented):
|
||||
const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
|
||||
acc_number += 1
|
||||
ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
|
||||
# for MAX, we can just ignore the unparented
|
||||
if root.arg is BinaryOps.ADD:
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
|
||||
return ret
|
||||
|
||||
def do_contract(con:UOp):
|
||||
ex = con.src[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue