mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
zero_grad there to match readme
This commit is contained in:
parent
c63f950348
commit
97fd9c1237
1 changed files with 1 additions and 1 deletions
|
|
@ -56,7 +56,6 @@ class TinyConvNet:
|
|||
def train(model, optim, steps, BS=128, gpu=False):
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
|
||||
optim.zero_grad()
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
|
||||
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
|
||||
|
|
@ -71,6 +70,7 @@ def train(model, optim, steps, BS=128, gpu=False):
|
|||
|
||||
# NLL loss function
|
||||
loss = out.mul(y).mean()
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue