mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
indexing an after with all fully invalid stores is invalid (#16643)
* indexing an after with all fully invalid stores is invalid * typing cast
This commit is contained in:
parent
0f0c622086
commit
1acc40600d
2 changed files with 33 additions and 1 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, UOp, GlobalCounters, Context, Device
|
||||
from tinygrad.dtype import AddrSpace, dtypes
|
||||
from tinygrad.dtype import AddrSpace, dtypes, Invalid
|
||||
from tinygrad.uop.ops import KernelInfo, AxisType, Ops
|
||||
|
||||
# **** kernels ****
|
||||
|
|
@ -328,6 +328,27 @@ class TestCustomKernel(unittest.TestCase):
|
|||
if prg.op is not Ops.PROGRAM: continue
|
||||
self.assertTrue(len(prg.arg.globals) > 0, f"empty kernel compiled (no globals): name={prg.arg.name}")
|
||||
|
||||
def test_multi_invalids_custom_kernel_no_copy(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
a = Tensor.ones(4, 4).shard(devs, axis=0).realize()
|
||||
c = Tensor(UOp.const(dtypes.float, Invalid, shape=(2, 4)).clone(device=devs).multi(0), device=devs)
|
||||
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
|
||||
GlobalCounters.reset()
|
||||
c.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, len(devs))
|
||||
self.assertTrue((c == 2).all().item())
|
||||
|
||||
def test_partial_invalid_store_keeps_uncovered_reads(self):
|
||||
x = Tensor([10., 20., 30., 40.])
|
||||
after = x.uop.after(x.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float, Invalid, shape=(2,))))
|
||||
self.assertEqual(Tensor(after).contiguous().tolist(), [10., 20., 30., 40.])
|
||||
|
||||
def test_expand_view_invalid_assign_keeps_uncovered_reads(self):
|
||||
x = Tensor([[10., 11., 12., 13.], [20., 21., 22., 23.], [30., 31., 32., 33.], [40., 41., 42., 43.]]).realize()
|
||||
v = x[:1, :].expand(4, 4)
|
||||
v.assign(Tensor.invalids(4, 4, dtype=dtypes.float))
|
||||
self.assertEqual(v.contiguous().tolist(), [[10., 11., 12., 13.]]*4)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "kernel timing not supported")
|
||||
def test_invalids_into_custom_kernel_with_beam(self):
|
||||
a = Tensor.full((4, 4), 3.).contiguous()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, ParamArg
|
||||
|
|
@ -313,6 +314,13 @@ def remove_noop_bufferize(idx,b2):
|
|||
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.SLICE: return None
|
||||
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
|
||||
|
||||
def after_all_invalid(after:UOp):
|
||||
buf = after.src[0].buf_uop
|
||||
# check all ranges are used (no expand), and same size (no pad and shrink)
|
||||
return all(s.op is Ops.END and (st:=s.src[0]).op is Ops.STORE and st.src[1].base.arg is Invalid and st.src[0].buf_uop is buf
|
||||
and all(r in st.src[0].ranges for r in s.ended_ranges)
|
||||
and resolve(cast(UOp, prod(r.src[0] for r in s.ended_ranges)).eq(buf.numel()), False) for s in after.src[1:])
|
||||
|
||||
pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
|
||||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
|
|
@ -323,6 +331,9 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
|
||||
# indexing a const is a const
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
# indexing an after with all fully invalid stores is invalid
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, name="after"),), allow_any_len=True, name="idx"),
|
||||
lambda idx,after: idx.const_like(Invalid) if after_all_invalid(after) else None),
|
||||
# copy on CONST is CONST
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
|
||||
# hack if a noop turned to a const
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue