continue work on beautiful cifar (#10555)

This commit is contained in:
George Hotz 2025-05-28 21:42:01 -07:00 committed by GitHub
commit e4e7b5d7e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,16 +3,19 @@ start_tm = time.perf_counter()
import math
from typing import Tuple, cast
import numpy as np
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes, Device
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
# override tinygrad defaults
dtypes.default_float = dtypes.half
Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)
assert batchsize % len(GPUS) == 0, f"{batchsize=} is not a multiple of {len(GPUS)=}"
bias_scaler = 64
hyp = {
'opt': {
@ -94,8 +97,13 @@ if __name__ == "__main__":
# *** model ***
model = SpeedyConvNet()
state_dict = nn.state.get_state_dict(model)
if len(GPUS) > 1:
cifar10_std.to_(GPUS)
cifar10_mean.to_(GPUS)
for x in state_dict.values(): x.to_(GPUS)
params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
@ -117,8 +125,12 @@ if __name__ == "__main__":
@TinyJit
@Tensor.train()
def train_step(idxs:Tensor) -> Tensor:
out = model(preprocess(X_train[idxs]))
loss = loss_fn(out, Y_train[idxs])
X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1:
X.shard_(GPUS, axis=0)
Y.shard_(GPUS, axis=0)
out = model(preprocess(X))
loss = loss_fn(out, Y)
opt.zero_grad()
loss.backward()
return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
@ -130,8 +142,11 @@ if __name__ == "__main__":
def val_step() -> Tuple[Tensor, Tensor]:
loss, acc = [], []
for i in range(0, X_test.size(0), eval_batchsize):
Y = Y_test[i:i+eval_batchsize]
out = model(preprocess(X_test[i:i+eval_batchsize]))
X, Y = X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize]
if len(GPUS) > 1:
X.shard_(GPUS, axis=0)
Y.shard_(GPUS, axis=0)
out = model(preprocess(X))
loss.append(loss_fn(out, Y))
acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()