better _drop_valid_stmts [pr] (#16719)

also dropped the unused is_increasing
This commit is contained in:
chenyu 2026-06-23 19:35:01 -04:00 committed by GitHub
commit ce87d80911
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 36 deletions

View file

@ -327,7 +327,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=4 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx

View file

@ -29,28 +29,6 @@ def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtyp
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -506,6 +484,16 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid

View file

@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in valid.split_uop(Ops.AND):
for i,stmt in enumerate(valid.split_uop(Ops.AND)):
if (res:=parse_valid(stmt)) is None: continue
X, is_upper_bound, c = res
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (width, height)):
if i.is_increasing():
rw = i.substitute({X:X.const_like(test_value)})
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
if lo <= hi:
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
for coord,b in zip(idx.src, (width, height)):
rw = coord.substitute({X:fake}).simplify()
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break

View file

@ -919,12 +919,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int:
"""largest known int that divides self"""
# TODO: for negatives it's not the largest