mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
model/test fix that failed with WEBGPU=1 DEBUG=2 (#14706)
This commit is contained in:
parent
10c94d2c2d
commit
557134e1c7
2 changed files with 3 additions and 3 deletions
|
|
@ -150,7 +150,7 @@ class ResNet:
|
|||
continue # Skip FC if transfer learning
|
||||
|
||||
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat.to(obj.device).reshape(obj.shape))
|
||||
obj.assign(dat.to(obj.device).cast(obj.dtype).reshape(obj.shape))
|
||||
|
||||
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
||||
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
||||
|
|
|
|||
|
|
@ -113,13 +113,13 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises(KeyError):
|
||||
with self.assertRaises((KeyError, RuntimeError)):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
@unittest.expectedFailure # bug in gpu dims limiting
|
||||
def test_int64_unsupported_overflow(self):
|
||||
with self.assertRaises(KeyError):
|
||||
with self.assertRaises((KeyError, RuntimeError)):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)
|
||||
|
||||
@unittest.skip("This is kept for reference, it requires large memory to run")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue