mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add collapse_view to the scheduler [pr] (#8440)
This commit is contained in:
parent
98b2854f14
commit
a44cd1e6f7
3 changed files with 38 additions and 1 deletions
|
|
@ -1981,6 +1981,39 @@ class TestView(unittest.TestCase):
|
|||
run_schedule(sched)
|
||||
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
|
||||
|
||||
# a*VIEW(x), where VIEW(x) = 0
|
||||
# x collapses along with its children
|
||||
def test_parent_view_collapses(self):
|
||||
a = Tensor([1, 2])
|
||||
b = Tensor.arange(3).contiguous()
|
||||
bv = b.pad(((0, 2),))[-2:]
|
||||
# this becomes a late a*0
|
||||
late_mul = a*bv
|
||||
check_schedule(late_mul, 0)
|
||||
# the arange doesn't realize
|
||||
self.assertIsNone(b.lazydata.base.realized)
|
||||
# mul doesn't realize
|
||||
self.assertIsNone(late_mul.lazydata.base.realized)
|
||||
self.assertEqual(late_mul.tolist(), [0, 0])
|
||||
|
||||
# SINK has two branches:
|
||||
# a*VIEW(x), where VIEW(x) = 0
|
||||
# x+2
|
||||
# as long as one child realizes, x does not collapse
|
||||
def test_parent_multiple_children_no_collapse(self):
|
||||
a = Tensor([1, 2])
|
||||
b = Tensor.arange(3).contiguous()
|
||||
bv = b.pad(((0, 2),))[-2:]
|
||||
late_mul = a*bv
|
||||
other_child = b+2
|
||||
s = check_schedule([late_mul, other_child], 2)
|
||||
# the arange realizes
|
||||
self.assertIsNotNone(b.lazydata.base.realized)
|
||||
# mul still collapses
|
||||
self.assertIsNone(late_mul.lazydata.base.realized)
|
||||
run_schedule(s)
|
||||
self.assertEqual(other_child.tolist(), [2, 3, 4])
|
||||
|
||||
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic)
|
||||
class TestBigGraph(unittest.TestCase):
|
||||
def test_sink_childless_const(self):
|
||||
|
|
|
|||
|
|
@ -553,8 +553,13 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
|||
buf_uop.buffer.ref(1)
|
||||
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
# **** movement ops
|
||||
|
||||
remove_movement_ops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
|
||||
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
|
||||
(UPat(Ops.VIEW, name="view"),
|
||||
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
|
||||
# merge one src (unrealized) views
|
||||
# NOTE: we can't merge realized buffer views here, because the buffer is realized before the view
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
|
||||
|
|
|
|||
|
|
@ -487,7 +487,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
|
||||
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
# instant folding rules
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
|
||||
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
|
||||
return ret
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue