hlb cifar touchups (#2113)

* types and cnt and EVAL_STEPS

* eval time + always print eval
This commit is contained in:
George Hotz 2023-10-18 16:26:15 -07:00 committed by GitHub
commit 5cfec59abc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 14 deletions

View file

@ -172,7 +172,7 @@ def train_cifar():
return X
# return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
def make_square_mask(shape, mask_size):
def make_square_mask(shape, mask_size) -> Tensor:
is_even = int(mask_size % 2 == 0)
center_max = shape[-2]-mask_size//2-is_even
center_min = mask_size//2-is_even
@ -185,11 +185,10 @@ def train_cifar():
mask = d_y * d_x
return mask
def random_crop(X, crop_size=32):
def random_crop(X:Tensor, crop_size=32):
mask = make_square_mask(X.shape, crop_size)
mask = mask.repeat((1,3,1,1))
X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
return X_cropped.reshape((-1, 3, crop_size, crop_size))
def cutmix(X:Tensor, Y:Tensor, mask_size=3):
@ -206,7 +205,7 @@ def train_cifar():
# the operations that remain inside batch fetcher is the ones that involves random operations
def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
step = 0
step, cnt = 0, 0
while True:
st = time.monotonic()
X, Y = X_in, Y_in
@ -218,7 +217,7 @@ def train_cifar():
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
X, Y = X.numpy(), Y.numpy()
et = time.monotonic()
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms")
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
for i in range(0, X.shape[0], BS):
# pad the last batch
batch_end = min(i+BS, Y.shape[0])
@ -226,7 +225,7 @@ def train_cifar():
y = Tensor(Y[order[batch_end-BS:batch_end]])
step += 1
yield x, y
cnt += 1
if not is_train: break
transform = [
@ -349,13 +348,13 @@ def train_cifar():
model_ema: Optional[modelEMA] = None
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
best_eval = -1
i = 0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
st = time.monotonic()
while i <= STEPS:
if i%getenv("EVAL_STEPS", 100) == 0 and i > 1:
if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
st_eval = time.monotonic()
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
@ -399,10 +398,8 @@ def train_cifar():
if rank == 0:
acc = correct_sum/correct_len*100.0
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
if acc > best_eval:
best_eval = acc
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i}")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
if STEPS == 0 or i==STEPS: break
X, Y = next(batcher)

View file

@ -48,7 +48,7 @@ class Conv2d:
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
def __call__(self, x):
def __call__(self, x:Tensor):
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
@ -61,7 +61,7 @@ class ConvTranspose2d(Conv2d):
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.output_padding = output_padding
def __call__(self, x):
def __call__(self, x:Tensor):
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))