mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
correctness test for multireduce nested locals (#4682)
* nested locals test * move st
This commit is contained in:
parent
bc9be39dec
commit
c5f5755328
1 changed files with 15 additions and 0 deletions
|
|
@ -355,6 +355,21 @@ class TestLinearizer(unittest.TestCase):
|
|||
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
|
||||
assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
|
||||
|
||||
@unittest.skip("multireduce isn't supported yet")
|
||||
def test_upcast_multireduce_nested_local_upcast(self):
|
||||
x, y, z, w = [Tensor.rand(1,128).realize() for _ in range(4)]
|
||||
st0 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),))
|
||||
st1 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 1, 128), offset=0, mask=None, contiguous=False),))
|
||||
ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, st0))
|
||||
ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, st1))
|
||||
ld2 = LazyOp(BufferOps.LOAD, (), MemBuffer(3, dtypes.float, st0))
|
||||
ld3 = LazyOp(BufferOps.LOAD, (), MemBuffer(4, dtypes.float, st1))
|
||||
r0 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld0, ld1)), ), (2,))
|
||||
r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld2, ld3)), ), (2,))
|
||||
out_st = ShapeTracker(views=(View(shape=(1, 128, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))
|
||||
ast = (LazyOp(BufferOps.STORE, (LazyOp(BinaryOps.ADD, (r0, r1)), ), MemBuffer(0, dtypes.float, out_st)),)
|
||||
helper_linearizer_ast(ast, [x, y, z, w])
|
||||
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack([a, b])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue