llama: fix FP8=1 FAKEDATA=1 (#15564)

This commit is contained in:
qazal 2026-04-01 14:53:03 +03:00 committed by GitHub
commit 09f60d80fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1397,7 +1397,7 @@ def train_llama3():
if getenv("FAKEDATA"):
for v in get_parameters(model):
v = v.assign(Tensor.empty(v.shape))
v = v.assign(Tensor.empty(v.shape, dtype=v.dtype))
is_dp = (DP := getenv("DP", 1)) > 1
is_mp = (MP := getenv("MP", 1)) > 1