mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assert view degrade to const tests post scheduler graph_rewrite [pr] (#7822)
* assert view degrade to const tests post scheduler graph_rewrite [pr] * low pri, probably tricky, todo
This commit is contained in:
parent
cdc431803f
commit
e378aeb94e
2 changed files with 36 additions and 22 deletions
|
|
@ -113,26 +113,5 @@ class TestReduceOp(unittest.TestCase):
|
|||
for s in sched:
|
||||
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
# start with non CONST Ops
|
||||
a = Tensor.rand(10, 10)
|
||||
assert a.lazydata.base.op is not Ops.CONST
|
||||
|
||||
# all masked out, degrades to const 0
|
||||
b = a.pad(((0, 10), None))[10:]
|
||||
assert b.shape == (10, 10)
|
||||
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
|
||||
|
||||
# mask out dim = 1 works too
|
||||
b = a.pad((None, (0, 10)))[:, 10:]
|
||||
assert b.shape == (10, 10)
|
||||
assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0
|
||||
|
||||
# partial masked out does not degrade into CONST
|
||||
b = a.pad(((0, 5), None))[5:]
|
||||
assert b.shape == (10, 10)
|
||||
assert b.lazydata.base.op is not Ops.CONST
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
|||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
|
@ -1864,5 +1864,40 @@ class TestSwizzle(unittest.TestCase):
|
|||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
# start with non CONST Ops
|
||||
a = Tensor.rand(10, 10).realize()
|
||||
# all masked out, degrades to const 0
|
||||
b = a.pad(((0, 10), None))[10:]
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
# TODO: this VALID can clean up, where do we need st?
|
||||
self.assertIs(store_val(sched[-1]), UOp(Ops.VALID, dtypes.bool, (b.lazydata.st.to_uop(),)).where(x:=UOp.const(b.dtype, 0), x))
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
def test_mask_dim_1(self):
|
||||
# mask out dim = 1 works too
|
||||
a = Tensor.rand(10, 10).realize()
|
||||
b = a.pad((None, (0, 10)))[:, 10:]
|
||||
assert b.shape == (10, 10)
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
|
||||
self.assertIs(store_val(sched[-1]), UOp(Ops.VALID, dtypes.bool, (b.lazydata.st.to_uop(),)).where(x:=UOp.const(b.dtype, 0), x))
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
def test_partial_mask(self):
|
||||
# partial masked out does not degrade into CONST
|
||||
a = Tensor.rand(10, 10).realize()
|
||||
b = a.pad(((0, 5), None))[5:]
|
||||
assert b.shape == (10, 10)
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
|
||||
self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st)
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue