optim flatten().shape[0] is numel (#10935)

This commit is contained in:
chenyu 2025-06-23 13:11:19 -04:00 committed by GitHub
commit 785b4ea8ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -21,7 +21,7 @@ class Optimizer:
# store lr in at least float32 precision
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.flatten().shape[0], initial=0))
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
def _new_optim_param(self) -> list[Tensor]:
param_dtype = getenv("OPTIM_DTYPE", "float32")