Simplify openpilot kernel (#2460)

* a conditional with the same results either way is a noop

* add unit test
This commit is contained in:
qazal 2023-11-27 13:02:27 -05:00 committed by GitHub
commit 262cd26d28
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View file

@ -2,8 +2,10 @@ import numpy as np
import unittest, os
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device, LoadOps
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps
from tinygrad.ops import BufferOps, Compiled, ConstBuffer, Device, LazyOp, LoadOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule
@ -117,6 +119,19 @@ class TestLinearizer(unittest.TestCase):
lin = Linearizer(sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_simplify_uop(self):
def helper_test_simplify(uop, dtype, vin, arg=None):
lin = Linearizer(ast=LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))) # this is a dummy ast
lin.uops = []
return lin.uop(uop, dtype, vin, arg, cachable=False)
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), arg=TernaryOps.WHERE) == c0
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1), arg=TernaryOps.WHERE).uop == UOps.ALU
def helper_realized_ast(r:Tensor):
s = r.lazydata.schedule()
run_schedule(s[:-1]) # run all kernels except the last one

View file

@ -459,6 +459,7 @@ class Linearizer(Kernel):
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
# zero folding
for x in [0,1]:
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]