mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Simplify openpilot kernel (#2460)
* a conditional with the same results either way is a noop * add unit test
This commit is contained in:
parent
61a80a0675
commit
262cd26d28
2 changed files with 18 additions and 2 deletions
|
|
@ -2,8 +2,10 @@ import numpy as np
|
|||
import unittest, os
|
||||
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps
|
||||
from tinygrad.ops import Compiled, Device, LoadOps
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps
|
||||
from tinygrad.ops import BufferOps, Compiled, ConstBuffer, Device, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
|
|
@ -117,6 +119,19 @@ class TestLinearizer(unittest.TestCase):
|
|||
lin = Linearizer(sched[0].ast)
|
||||
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
def test_simplify_uop(self):
|
||||
def helper_test_simplify(uop, dtype, vin, arg=None):
|
||||
lin = Linearizer(ast=LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))) # this is a dummy ast
|
||||
lin.uops = []
|
||||
return lin.uop(uop, dtype, vin, arg, cachable=False)
|
||||
|
||||
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
|
||||
assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), arg=TernaryOps.WHERE) == c0
|
||||
|
||||
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
|
||||
assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1), arg=TernaryOps.WHERE).uop == UOps.ALU
|
||||
|
||||
def helper_realized_ast(r:Tensor):
|
||||
s = r.lazydata.schedule()
|
||||
run_schedule(s[:-1]) # run all kernels except the last one
|
||||
|
|
|
|||
|
|
@ -459,6 +459,7 @@ class Linearizer(Kernel):
|
|||
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
|
||||
# constant folding
|
||||
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
|
||||
if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue