mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix flaky test_disktensor (#16549)
This commit is contained in:
parent
fa400f9790
commit
d18ad49f20
1 changed files with 4 additions and 4 deletions
|
|
@ -203,9 +203,9 @@ class TestSafetensors(TempDirTestCase):
|
|||
"weight_I16": torch.tensor([127, 64], dtype=torch.short),
|
||||
"weight_BF16": torch.randn((2, 2), dtype=torch.bfloat16),
|
||||
}
|
||||
save_file(tensors, self.tmp("dtypes.safetensors"))
|
||||
save_file(tensors, self.tmp("dtypes_torch.safetensors"))
|
||||
|
||||
loaded = safe_load(self.tmp("dtypes.safetensors"))
|
||||
loaded = safe_load(self.tmp("dtypes_torch.safetensors"))
|
||||
for k,v in loaded.items():
|
||||
if v.dtype != dtypes.bfloat16:
|
||||
assert v.numpy().dtype == tensors[k].numpy().dtype
|
||||
|
|
@ -217,9 +217,9 @@ class TestSafetensors(TempDirTestCase):
|
|||
"weight_U32": np.array([1, 2, 3], dtype=np.uint32),
|
||||
"weight_U64": np.array([1, 2, 3], dtype=np.uint64),
|
||||
}
|
||||
np_save_file(tensors, self.tmp("dtypes.safetensors"))
|
||||
np_save_file(tensors, self.tmp("dtypes_numpy.safetensors"))
|
||||
|
||||
loaded = safe_load(self.tmp("dtypes.safetensors"))
|
||||
loaded = safe_load(self.tmp("dtypes_numpy.safetensors"))
|
||||
for k,v in loaded.items():
|
||||
assert v.numpy().dtype == tensors[k].dtype
|
||||
np.testing.assert_allclose(v.numpy(), tensors[k])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue