mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove const without STACK (#16639)
* remove const without STACK * fix GEP rewrite * fix null tests * fix openpilot regression * it's 10 in CI
This commit is contained in:
parent
36f6d1b064
commit
d631716858
8 changed files with 15 additions and 14 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -378,7 +378,7 @@ jobs:
|
|||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp32 (test correctness)
|
||||
run: |
|
||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||
|
|
|
|||
|
|
@ -46,9 +46,9 @@ class TestGraphRewriteConst(unittest.TestCase):
|
|||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
|
||||
ret = graph_rewrite(v1+v2, sym)
|
||||
self.assertEqual(ret.op, Ops.CONST)
|
||||
self.assertEqual(ret.op, Ops.STACK)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, 2)
|
||||
self.assertEqual(const_values(ret), (2,2,2))
|
||||
|
||||
def xfail_broken_const_wraparound(fn):
|
||||
fn = pytest.mark.xfail(reason="const folding does not properly implement modular arithmetic")(fn)
|
||||
|
|
|
|||
|
|
@ -355,13 +355,13 @@ class TestUOpRender(unittest.TestCase):
|
|||
self.assertEqual(u.render(), "{}")
|
||||
def test_render_vectorize_same(self):
|
||||
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
|
||||
self.assertEqual(u.render(simplify=False), "{0, ...}")
|
||||
self.assertEqual(u.render(simplify=False), "{0,0,0}")
|
||||
def test_render_vectorize_different(self):
|
||||
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
self.assertEqual(u.render(simplify=False), "{0,1,2}")
|
||||
def test_render_vectorize_same_simplified(self):
|
||||
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
|
||||
self.assertEqual(u.render(), "0")
|
||||
self.assertEqual(u.render(), "{0,0,0}")
|
||||
def test_render_vectorize_different_simplified(self):
|
||||
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
self.assertEqual(u.render(), "{0,1,2}")
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ pm_index_is_shrink = PatternMatcher([
|
|||
UOp(Ops.SHRINK, dtype=x.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count))) \
|
||||
if isinstance(buf.dtype, PtrDType) and x.dtype.count > 1 else None),
|
||||
# rewrite GEP to INDEX
|
||||
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg if len(x.arg) > 1 else x.arg[0]),), arg=None)),
|
||||
])
|
||||
|
||||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -63,7 +63,9 @@ load_store_indexing = PatternMatcher([
|
|||
def expand_index(ctx, buf:UOp, vec:UOp):
|
||||
# determine optimal image shapes
|
||||
if isinstance(dt:=buf.dtype, ImageDType):
|
||||
x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
|
||||
idxs, valids = vec.get_idx(), vec.get_valid()
|
||||
lane = next((i for i in range(vec.dtype.count) if valids.gep(i).vmax != 0), 0)
|
||||
x, valid = idxs.gep(lane), valids.gep(lane)
|
||||
# search for dims that drop the most valid statements
|
||||
best_drop, cands = -1, []
|
||||
for ch, cw in ImageDType.valid_dims(dt, ctx.target.arch):
|
||||
|
|
|
|||
|
|
@ -554,9 +554,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike, shape:tuple[sint, ...]|None=None):
|
||||
if isinstance(b, UOp): return b.cast(dtype)
|
||||
if isinstance(b, tuple) and all_same(b):
|
||||
assert len(b) > 0, "can't create const from empty tuple"
|
||||
b = b[0] # doesn't have to be a STACK if they are all the same
|
||||
# NOTE: it always has to be STACK now, even if they are all the same
|
||||
if isinstance(b, tuple):
|
||||
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=()) for c in b]
|
||||
ret = UOp.vectorize(*stk)
|
||||
|
|
@ -583,9 +581,11 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
return cond.where(self.cast(dtypes.weakint), UOp.invalid(self.dtype.count))
|
||||
def get_idx(self) -> UOp:
|
||||
assert self.dtype.scalar() is dtypes.weakint, "Can only call get_idx on index dtype"
|
||||
if self.op is Ops.STACK: return UOp.vectorize(*(x.get_idx() for x in self.src))
|
||||
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
|
||||
def get_valid(self) -> UOp:
|
||||
assert self.dtype.scalar() is dtypes.weakint, "Can only call get_valid on index dtype"
|
||||
if self.op is Ops.STACK: return UOp.vectorize(*(x.get_valid() for x in self.src))
|
||||
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
||||
def reduce(self, *src:UOp, **kwargs):
|
||||
arg = kwargs.pop('arg', None)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import cast
|
|||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
|
||||
from tinygrad.helpers import strip_parens, all_same
|
||||
from tinygrad.helpers import strip_parens
|
||||
|
||||
def pretty_print(x:UOp, cache=None, d=0)->str:
|
||||
def dfs(x:UOp, cache:dict):
|
||||
|
|
@ -50,8 +50,7 @@ renderer = PatternMatcher([
|
|||
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
|
||||
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
|
||||
(UPat((Ops.INDEX, Ops.STAGE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
|
||||
(UPat(Ops.STACK, name="x"),
|
||||
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not x.src or not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}"),
|
||||
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ spec_program = PatternMatcher([
|
|||
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
|
||||
|
||||
# STACK/GEP in program. TODO: this should match Tensor
|
||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
|
||||
(UPat(Ops.STACK), lambda: True),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue