fix flaky test_disktensor (#16549)

This commit is contained in:
qazal 2026-06-09 17:23:22 +08:00 committed by GitHub
commit d18ad49f20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -203,9 +203,9 @@ class TestSafetensors(TempDirTestCase):
"weight_I16": torch.tensor([127, 64], dtype=torch.short),
"weight_BF16": torch.randn((2, 2), dtype=torch.bfloat16),
}
save_file(tensors, self.tmp("dtypes.safetensors"))
save_file(tensors, self.tmp("dtypes_torch.safetensors"))
loaded = safe_load(self.tmp("dtypes.safetensors"))
loaded = safe_load(self.tmp("dtypes_torch.safetensors"))
for k,v in loaded.items():
if v.dtype != dtypes.bfloat16:
assert v.numpy().dtype == tensors[k].numpy().dtype
@ -217,9 +217,9 @@ class TestSafetensors(TempDirTestCase):
"weight_U32": np.array([1, 2, 3], dtype=np.uint32),
"weight_U64": np.array([1, 2, 3], dtype=np.uint64),
}
np_save_file(tensors, self.tmp("dtypes.safetensors"))
np_save_file(tensors, self.tmp("dtypes_numpy.safetensors"))
loaded = safe_load(self.tmp("dtypes.safetensors"))
loaded = safe_load(self.tmp("dtypes_numpy.safetensors"))
for k,v in loaded.items():
assert v.numpy().dtype == tensors[k].dtype
np.testing.assert_allclose(v.numpy(), tensors[k])