mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hotfix: examples/transformer.py
This commit is contained in:
parent
145718a90f
commit
ae83733431
2 changed files with 6 additions and 4 deletions
|
|
@ -28,7 +28,8 @@ if __name__ == "__main__":
|
|||
lr = 0.003
|
||||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
# TODO: BUG! why doesn't the JIT work here?
|
||||
train(model, X_train, Y_train, optim, 50, BS=64, allow_jit=False)
|
||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.4f}')
|
||||
|
|
@ -37,6 +38,6 @@ if __name__ == "__main__":
|
|||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong+=1
|
||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
||||
a,b,c,x = X_test[k,:2].astype(np.int32), X_test[k,2:4].astype(np.int32), Y_test[k,-3:].astype(np.int32), Y_test_preds[k,-3:].astype(np.int32)
|
||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@ from tinygrad.jit import TinyJit
|
|||
|
||||
|
||||
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
|
||||
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
|
||||
transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):
|
||||
|
||||
@TinyJit
|
||||
def train_step(x, y):
|
||||
# network
|
||||
out = model.forward(x) if hasattr(model, 'forward') else model(x)
|
||||
|
|
@ -22,6 +21,8 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
|
|||
accuracy = (cat == y).mean()
|
||||
return loss.realize(), accuracy.realize()
|
||||
|
||||
if allow_jit: train_step = TinyJit(train_step)
|
||||
|
||||
with Tensor.train():
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=CI)):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue