assert view degrade to const tests post scheduler graph_rewrite [pr] (#7822)

* assert view degrade to const tests post scheduler graph_rewrite [pr]

* low pri, probably tricky, todo
This commit is contained in:
qazal 2024-11-21 06:00:41 -05:00 committed by GitHub
commit e378aeb94e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 22 deletions

View file

@ -113,26 +113,5 @@ class TestReduceOp(unittest.TestCase):
for s in sched:
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST Ops
a = Tensor.rand(10, 10)
assert a.lazydata.base.op is not Ops.CONST
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# mask out dim = 1 works too
b = a.pad((None, (0, 10)))[:, 10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
# partial masked out does not degrade into CONST
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is not Ops.CONST
if __name__ == "__main__":
unittest.main()

View file

@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis
@ -1864,5 +1864,40 @@ class TestSwizzle(unittest.TestCase):
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST Ops
a = Tensor.rand(10, 10).realize()
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
sched = check_schedule(b.contiguous(), 1)
# TODO: this VALID can clean up, where do we need st?
self.assertIs(store_val(sched[-1]), UOp(Ops.VALID, dtypes.bool, (b.lazydata.st.to_uop(),)).where(x:=UOp.const(b.dtype, 0), x))
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
def test_mask_dim_1(self):
# mask out dim = 1 works too
a = Tensor.rand(10, 10).realize()
b = a.pad((None, (0, 10)))[:, 10:]
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
self.assertIs(store_val(sched[-1]), UOp(Ops.VALID, dtypes.bool, (b.lazydata.st.to_uop(),)).where(x:=UOp.const(b.dtype, 0), x))
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
def test_partial_mask(self):
# partial masked out does not degrade into CONST
a = Tensor.rand(10, 10).realize()
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st)
run_schedule(sched)
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
if __name__ == '__main__':
unittest.main(verbosity=2)