remove const without STACK (#16639)

* remove const without STACK

* fix GEP rewrite

* fix null tests

* fix openpilot regression

* it's 10 in CI
This commit is contained in:
George Hotz 2026-06-16 21:25:42 -07:00 committed by GitHub
commit d631716858
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 15 additions and 14 deletions

View file

@ -378,7 +378,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx

View file

@ -46,9 +46,9 @@ class TestGraphRewriteConst(unittest.TestCase):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
ret = graph_rewrite(v1+v2, sym)
self.assertEqual(ret.op, Ops.CONST)
self.assertEqual(ret.op, Ops.STACK)
self.assertEqual(ret.dtype, dtypes.int.vec(3))
self.assertEqual(ret.arg, 2)
self.assertEqual(const_values(ret), (2,2,2))
def xfail_broken_const_wraparound(fn):
fn = pytest.mark.xfail(reason="const folding does not properly implement modular arithmetic")(fn)

View file

@ -355,13 +355,13 @@ class TestUOpRender(unittest.TestCase):
self.assertEqual(u.render(), "{}")
def test_render_vectorize_same(self):
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
self.assertEqual(u.render(simplify=False), "{0, ...}")
self.assertEqual(u.render(simplify=False), "{0,0,0}")
def test_render_vectorize_different(self):
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(simplify=False), "{0,1,2}")
def test_render_vectorize_same_simplified(self):
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
self.assertEqual(u.render(), "0")
self.assertEqual(u.render(), "{0,0,0}")
def test_render_vectorize_different_simplified(self):
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "{0,1,2}")

View file

@ -31,7 +31,7 @@ pm_index_is_shrink = PatternMatcher([
UOp(Ops.SHRINK, dtype=x.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count))) \
if isinstance(buf.dtype, PtrDType) and x.dtype.count > 1 else None),
# rewrite GEP to INDEX
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg if len(x.arg) > 1 else x.arg[0]),), arg=None)),
])
pm_remove_vec_dtypes = PatternMatcher([

View file

@ -63,7 +63,9 @@ load_store_indexing = PatternMatcher([
def expand_index(ctx, buf:UOp, vec:UOp):
# determine optimal image shapes
if isinstance(dt:=buf.dtype, ImageDType):
x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
idxs, valids = vec.get_idx(), vec.get_valid()
lane = next((i for i in range(vec.dtype.count) if valids.gep(i).vmax != 0), 0)
x, valid = idxs.gep(lane), valids.gep(lane)
# search for dims that drop the most valid statements
best_drop, cands = -1, []
for ch, cw in ImageDType.valid_dims(dt, ctx.target.arch):

View file

@ -554,9 +554,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@staticmethod
def const(dtype:DType, b:ConstLike, shape:tuple[sint, ...]|None=None):
if isinstance(b, UOp): return b.cast(dtype)
if isinstance(b, tuple) and all_same(b):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a STACK if they are all the same
# NOTE: it always has to be STACK now, even if they are all the same
if isinstance(b, tuple):
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=()) for c in b]
ret = UOp.vectorize(*stk)
@ -583,9 +581,11 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
return cond.where(self.cast(dtypes.weakint), UOp.invalid(self.dtype.count))
def get_idx(self) -> UOp:
assert self.dtype.scalar() is dtypes.weakint, "Can only call get_idx on index dtype"
if self.op is Ops.STACK: return UOp.vectorize(*(x.get_idx() for x in self.src))
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
def get_valid(self) -> UOp:
assert self.dtype.scalar() is dtypes.weakint, "Can only call get_valid on index dtype"
if self.op is Ops.STACK: return UOp.vectorize(*(x.get_valid() for x in self.src))
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs):
arg = kwargs.pop('arg', None)

View file

@ -2,7 +2,7 @@ from typing import cast
from tinygrad.dtype import dtypes, Invalid
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
from tinygrad.helpers import strip_parens, all_same
from tinygrad.helpers import strip_parens
def pretty_print(x:UOp, cache=None, d=0)->str:
def dfs(x:UOp, cache:dict):
@ -50,8 +50,7 @@ renderer = PatternMatcher([
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat((Ops.INDEX, Ops.STAGE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not x.src or not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}"),
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
])

View file

@ -215,7 +215,7 @@ spec_program = PatternMatcher([
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
(UPat(Ops.STACK), lambda: True),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),