mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
optim flatten().shape[0] is numel (#10935)
This commit is contained in:
parent
ac39f27ae6
commit
785b4ea8ac
1 changed files with 1 additions and 1 deletions
|
|
@ -21,7 +21,7 @@ class Optimizer:
|
|||
# store lr in at least float32 precision
|
||||
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
|
||||
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
|
||||
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.flatten().shape[0], initial=0))
|
||||
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
|
||||
|
||||
def _new_optim_param(self) -> list[Tensor]:
|
||||
param_dtype = getenv("OPTIM_DTYPE", "float32")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue