model/test fix that failed with WEBGPU=1 DEBUG=2 (#14706)

This commit is contained in:
chenyu 2026-02-12 09:08:16 -05:00 committed by GitHub
commit 557134e1c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View file

@ -150,7 +150,7 @@ class ResNet:
continue # Skip FC if transfer learning
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
obj.assign(dat.to(obj.device).reshape(obj.shape))
obj.assign(dat.to(obj.device).cast(obj.dtype).reshape(obj.shape))
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)

View file

@ -113,13 +113,13 @@ class TestIdxUpcast(unittest.TestCase):
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
def test_int64_unsupported_overflow_sym(self):
with self.assertRaises(KeyError):
with self.assertRaises((KeyError, RuntimeError)):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
@unittest.expectedFailure # bug in gpu dims limiting
def test_int64_unsupported_overflow(self):
with self.assertRaises(KeyError):
with self.assertRaises((KeyError, RuntimeError)):
self.do_op_then_assert(dtypes.long, 2048, 2048, 2048)
@unittest.skip("This is kept for reference, it requires large memory to run")