correctness test for multireduce nested locals (#4682)

* nested locals test

* move st
This commit is contained in:
qazal 2024-05-23 00:35:35 +08:00 committed by GitHub
commit c5f5755328
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -355,6 +355,21 @@ class TestLinearizer(unittest.TestCase):
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
@unittest.skip("multireduce isn't supported yet")
def test_upcast_multireduce_nested_local_upcast(self):
x, y, z, w = [Tensor.rand(1,128).realize() for _ in range(4)]
st0 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),))
st1 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 1, 128), offset=0, mask=None, contiguous=False),))
ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, st0))
ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, st1))
ld2 = LazyOp(BufferOps.LOAD, (), MemBuffer(3, dtypes.float, st0))
ld3 = LazyOp(BufferOps.LOAD, (), MemBuffer(4, dtypes.float, st1))
r0 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld0, ld1)), ), (2,))
r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld2, ld3)), ), (2,))
out_st = ShapeTracker(views=(View(shape=(1, 128, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))
ast = (LazyOp(BufferOps.STORE, (LazyOp(BinaryOps.ADD, (r0, r1)), ), MemBuffer(0, dtypes.float, out_st)),)
helper_linearizer_ast(ast, [x, y, z, w])
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])