mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
weights_only=False (#8839)
This commit is contained in:
parent
741bbc900d
commit
07d3676019
1 changed files with 1 additions and 1 deletions
|
|
@ -10,7 +10,7 @@ def compare_weights_both(url):
|
|||
import torch
|
||||
fn = fetch(url)
|
||||
tg_weights = get_state_dict(torch_load(fn))
|
||||
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=True), tensor_type=torch.Tensor)
|
||||
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=False), tensor_type=torch.Tensor)
|
||||
assert list(tg_weights.keys()) == list(torch_weights.keys())
|
||||
for k in tg_weights:
|
||||
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue