Tensor.train -> TRAINING [PR] (#16705)

* Tensor.train -> TRAINING [PR]

* doc
This commit is contained in:
chenyu 2026-06-22 15:13:22 -04:00 committed by GitHub
commit 33b635d23a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 80 additions and 86 deletions

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python
from tinygrad import Tensor, nn
from tinygrad import Tensor, nn, Context
class LinearNet:
def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train():
with Context(TRAINING=1):
for i in range(10):
optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager.
```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train():
with Context(TRAINING=1):
for step in range(1000):
# randomly sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym
from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train():
with Context(TRAINING=1):
log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit
def train_step() -> Tensor:
with Tensor.train():
with Context(TRAINING=1):
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools
from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition
class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0
eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
with Context(TRAINING=1):
st = time.monotonic()
while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass
@dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y)
optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# largely adapted from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False):
with Context(TRAINING=0):
if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(model, x, y):
optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step()
return loss.realize()
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu
@TinyJit
@Tensor.train(False)
@Context(TRAINING=0)
def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext()
with Tensor.train():
with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}"
if nm in globals():

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.helpers import trange, Context
from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
with Tensor.train():
with Context(TRAINING=1):
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal
from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples")
with Tensor.train():
with Context(TRAINING=1):
"""
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -1,6 +1,6 @@
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.helpers import trange, Context
from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step)
with Tensor.train():
with Context(TRAINING=1):
losses, accuracies = [], []
for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,4 +55,3 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc

View file

@ -25,7 +25,7 @@
import unittest
import numpy as np
import torch
from tinygrad import Tensor, dtypes, nn
from tinygrad import Tensor, dtypes, nn, Context
from tinygrad.device import Device
from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these
def test_dropout_rate_one(self):
with Tensor.train():
with Context(TRAINING=1):
out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError):
with Tensor.train():
with Context(TRAINING=1):
Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase):

View file

@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow
class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
with Tensor.train(training):
with Context(TRAINING=training):
szs = [4, 8, 16, 32]
for sz in szs:
# create in tinygrad

View file

@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0):
with Tensor.train():
with Context(TRAINING=1):
losses = []
for i in range(samples.shape[0]):
vib = vi.bind(i)

View file

@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, Variable, GlobalCounters
from tinygrad import Tensor, Variable, GlobalCounters, Context
from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes
from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self):
with Tensor.train():
with Context(TRAINING=1):
self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError):
# symbolic shape dropout is not supported

View file

@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad import Tensor, Device, dtypes, nn, Context
from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self):
with Tensor.train():
with Context(TRAINING=1):
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy())

View file

@ -1,6 +1,6 @@
import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad import Context
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp
with Tensor.train():
with Context(TRAINING=1):
saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
if __name__=="__main__":
unittest.main()
unittest.main()

View file

@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np
from PIL import Image
from tinygrad import Tensor
from tinygrad import Tensor, Context
from tinygrad.helpers import getenv
from test.helpers import slow
from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img
def _infer(model: EfficientNet, img):
with Tensor.train(False):
with Context(TRAINING=0):
out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist()

View file

@ -5,10 +5,11 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y):
with Tensor.train():
with Context(TRAINING=1):
model_torch.train()
model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters():

View file

@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_mnist(self):
from examples.beautiful_mnist import Model
with Tensor.train():
with Context(TRAINING=1):
model = Model()
optimizer = optim.Adam(get_parameters(model))
BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self):
BS = 32
# note: batchnorm still runs in training mode here
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit
def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_cifar(self):
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_bert(self):
with Tensor.train():
with Context(TRAINING=1):
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny)

View file

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False):
with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False):
with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self):
with Tensor.train():
with Context(TRAINING=1):
x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4)
def test_adam_step_fusion(self):
with Tensor.train():
with Context(TRAINING=1):
x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4)
_realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
_realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
_realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self):
with Tensor.train(False):
with Context(TRAINING=0):
in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes)

View file

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed)
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase):
def test_train_mode(self):
assert not Tensor.training
@Tensor.train()
@Context(TRAINING=1)
def f():
assert Tensor.training
f()

View file

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy()
def test_backprop_conv(self):
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self):
with Tensor.train():
with Context(TRAINING=1):
X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self):
with Tensor.train():
with Context(TRAINING=1):
X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass
def test_unsynced_backprop_conv_bn(self):
with Tensor.train():
with Context(TRAINING=1):
from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:])
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2)
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Tensor.train(is_training):
with Context(TRAINING=is_training):
bns = []
for _ in range(len(devices)):
bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train():
with Context(TRAINING=1):
synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

View file

@ -1,7 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
@ -71,7 +70,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn
from tinygrad import Tensor, dtypes, nn, Context
import numpy as np
import math
np.set_printoptions(precision=4)
@ -154,11 +153,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
self.is_param = is_param
return self
# Context manager / decorator that forces the global TRAINING flag to a given mode for the duration of a scope.
# Usable both as `with Tensor.train(): ...` and as `@Tensor.train()` (via ContextDecorator).
# NOTE(review): the saved value lives on the instance (`self.prev`), so re-entering the SAME instance while it is
# already active (e.g. a decorated function that calls itself) overwrites the saved value.
class train(ContextDecorator):
# mode: the training state to apply while the context is active (defaults to True)
def __init__(self, mode:bool = True): self.mode = mode
# remember the current TRAINING.value on the instance, then overwrite it with int(mode); returns None
def __enter__(self): self.prev, TRAINING.value = TRAINING.value, int(self.mode)
# put back the value saved by __enter__; returns None, so exceptions raised in the body are not suppressed
def __exit__(self, exc_type, exc_value, traceback): TRAINING.value = self.prev
def __repr__(self):
ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -804,7 +798,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
with Context(TRAINING=1):
print(t.dropout().numpy())
```
"""