mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
momentum support in SGD
This commit is contained in:
parent
0a2035e015
commit
9152bb5b4a
2 changed files with 13 additions and 6 deletions
|
|
@ -87,7 +87,8 @@ def train_cifar():
|
|||
if getenv("ADAM"):
|
||||
optimizer = optim.Adam(get_parameters(model), lr=3e-4)
|
||||
else:
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.001)
|
||||
#optimizer = optim.SGD(get_parameters(model), lr=0.001)
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.001, momentum=0.85, nesterov=True)
|
||||
|
||||
# 97 steps in 2 seconds = 20ms / step
|
||||
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
|
||||
|
|
|
|||
|
|
@ -28,15 +28,21 @@ class Optimizer:
|
|||
p.realize()
|
||||
|
||||
class SGD(Optimizer):
|
||||
def __init__(self, params : List[Tensor], lr=0.001):
|
||||
def __init__(self, params : List[Tensor], lr=0.001, momentum=0, nesterov=False):
|
||||
super().__init__(params)
|
||||
self.lr = lr
|
||||
self.lr, self.momentum, self.nesterov = lr, momentum, nesterov
|
||||
self.b = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params] if self.momentum else []
|
||||
|
||||
# https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
|
||||
def step(self) -> None:
|
||||
for t in self.params:
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
t.assign(t.detach() - t.grad * self.lr)
|
||||
self.realize()
|
||||
g = t.grad
|
||||
if self.momentum:
|
||||
self.b[i].assign(self.momentum * self.b[i] + g)
|
||||
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
|
||||
t.assign(t.detach() - g * self.lr)
|
||||
self.realize(self.b)
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
def __init__(self, params : List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue