indexing an after with all fully invalid stores is invalid (#16643)

* indexing an after with all fully invalid stores is invalid

* typing cast
This commit is contained in:
chenyu 2026-06-17 11:06:36 -04:00 committed by GitHub
commit 1acc40600d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 1 deletions

View file

@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, UOp, GlobalCounters, Context, Device
from tinygrad.dtype import AddrSpace, dtypes
from tinygrad.dtype import AddrSpace, dtypes, Invalid
from tinygrad.uop.ops import KernelInfo, AxisType, Ops
# **** kernels ****
@ -328,6 +328,27 @@ class TestCustomKernel(unittest.TestCase):
if prg.op is not Ops.PROGRAM: continue
self.assertTrue(len(prg.arg.globals) > 0, f"empty kernel compiled (no globals): name={prg.arg.name}")
def test_multi_invalids_custom_kernel_no_copy(self):
devs = ("CPU:0", "CPU:1")
a = Tensor.ones(4, 4).shard(devs, axis=0).realize()
c = Tensor(UOp.const(dtypes.float, Invalid, shape=(2, 4)).clone(device=devs).multi(0), device=devs)
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
GlobalCounters.reset()
c.realize()
self.assertEqual(GlobalCounters.kernel_count, len(devs))
self.assertTrue((c == 2).all().item())
def test_partial_invalid_store_keeps_uncovered_reads(self):
x = Tensor([10., 20., 30., 40.])
after = x.uop.after(x.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float, Invalid, shape=(2,))))
self.assertEqual(Tensor(after).contiguous().tolist(), [10., 20., 30., 40.])
def test_expand_view_invalid_assign_keeps_uncovered_reads(self):
x = Tensor([[10., 11., 12., 13.], [20., 21., 22., 23.], [30., 31., 32., 33.], [40., 41., 42., 43.]]).realize()
v = x[:1, :].expand(4, 4)
v.assign(Tensor.invalids(4, 4, dtype=dtypes.float))
self.assertEqual(v.contiguous().tolist(), [[10., 11., 12., 13.]]*4)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "kernel timing not supported")
def test_invalids_into_custom_kernel_with_beam(self):
a = Tensor.full((4, 4), 3.).contiguous()

View file

@ -1,4 +1,5 @@
from dataclasses import dataclass, field, replace
from typing import cast
import itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, ParamArg
@ -313,6 +314,13 @@ def remove_noop_bufferize(idx,b2):
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.SLICE: return None
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
def after_all_invalid(after:UOp):
buf = after.src[0].buf_uop
# check all ranges are used (no expand), and same size (no pad and shrink)
return all(s.op is Ops.END and (st:=s.src[0]).op is Ops.STORE and st.src[1].base.arg is Invalid and st.src[0].buf_uop is buf
and all(r in st.src[0].ranges for r in s.ended_ranges)
and resolve(cast(UOp, prod(r.src[0] for r in s.ended_ranges)).eq(buf.numel()), False) for s in after.src[1:])
pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
# remove noop buffers. if we look at the next index we can remove even more of these
@ -323,6 +331,9 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
# indexing an after with all fully invalid stores is invalid
(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, name="after"),), allow_any_len=True, name="idx"),
lambda idx,after: idx.const_like(Invalid) if after_all_invalid(after) else None),
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
# hack if a noop turned to a const