cleanup test_conv2d_ceildiv_edge_case [pr] (#10317)

This commit is contained in:
chenyu 2025-05-15 11:35:28 +08:00 committed by GitHub
commit f6cf25fce4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,9 +1,8 @@
import unittest
from tinygrad import Variable
from tinygrad import Tensor, Variable
from tinygrad.shape.shapetracker import View
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.ops import UOp
from tinygrad.ops import sym_infer
from examples.gpt2 import Attention
import numpy as np
@ -224,16 +223,16 @@ class TestSymbolicOps(unittest.TestCase):
@unittest.expectedFailure
def test_conv2d_ceildiv_edge_case(self):
def eval_uops(a): return a.sym_infer(dict(v.unbind() for v in a.vars()))
v = Variable('qwe', 11, 50_000).bind(39601)
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v)
v = Variable('v', 11, 50_000)
val = 39601
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val))
weight = Tensor.randn(256, 22, 12)
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
shape = tuple(eval_uops(i) if isinstance(i, UOp) else i for i in result.shape)
self.assertEqual(shape, (1, 256, 6600)) # fails if ceildiv is incorrect
var_val = {v: val}
shape = tuple(sym_infer(s, var_val) for s in result.shape)
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
# TODO: test output is correct
if __name__ == '__main__':
unittest.main()