mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update test_conv2d_ceildiv_edge_case (#12779)
This commit is contained in:
parent
442218266d
commit
30ff84d050
1 changed files with 3 additions and 3 deletions
|
|
@ -287,7 +287,6 @@ class TestSymbolicOps(unittest.TestCase):
|
|||
symbolic = symbolic_result[:].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_conv2d_ceildiv_edge_case(self):
|
||||
v = Variable('v', 11, 50_000)
|
||||
val = 39601
|
||||
|
|
@ -295,9 +294,10 @@ class TestSymbolicOps(unittest.TestCase):
|
|||
weight = Tensor.randn(256, 22, 12)
|
||||
|
||||
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
|
||||
var_val = {v: val}
|
||||
var_val = {v.expr: val}
|
||||
shape = tuple(sym_infer(s, var_val) for s in result.shape)
|
||||
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
|
||||
# TODO: test output is correct
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue