tinygrad/test/test_masked_tensor.py
2026-02-03 00:11:24 +08:00

29 lines
750 B
Python

import unittest
from tinygrad.tensor import Tensor
class TestMaskedTensor(unittest.TestCase):
  """Elementwise ops on tensors whose trailing elements come from zero-padding via `pad`."""

  def _assert_result(self, out, shape, expected):
    """Check that `out` has `shape` and that `out.data().tolist()` equals `expected`.

    Uses unittest assertion methods instead of bare `assert`: bare asserts are removed when
    Python runs with `-O`, which would make these tests pass vacuously, and assertEqual also
    reports both values on failure.
    """
    self.assertEqual(out.shape, shape)
    self.assertEqual(out.data().tolist(), expected)

  def test_mul_masked(self):
    """Multiplying by a zero-padded tensor zeroes the padded tail of the product."""
    a = Tensor([1,1,1,1,1])
    b = Tensor([1,1]).pad(((0,3),))  # two real values followed by three padded zeros
    self._assert_result(a*b, a.shape, [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_mul_both_masked(self):
    """Product of two identically padded tensors keeps the padded tail at zero."""
    a = Tensor([1,1]).pad(((0,3),))
    b = Tensor([1,1]).pad(((0,3),))
    self._assert_result(a*b, a.shape, [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_add_masked(self):
    """Sum of two padded tensors is 0 + 0 in the padded region."""
    a = Tensor([1,1]).pad(((0,2),))
    b = Tensor([1,1]).pad(((0,2),))
    # shape check added for consistency with the multiplication tests
    self._assert_result(a+b, a.shape, [2.0, 2.0, 0.0, 0.0])
# Run this module's tests only when executed directly as a script, not when imported.
if __name__ == '__main__':
  unittest.main()