Tensor.train -> TRAINING [PR] (#16705)

* Tensor.train -> TRAINING [PR]

* doc
This commit is contained in:
chenyu 2026-06-22 15:13:22 -04:00 committed by GitHub
commit 33b635d23a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 80 additions and 86 deletions

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need. Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python ```python
from tinygrad import Tensor, nn from tinygrad import Tensor, nn, Context
class LinearNet: class LinearNet:
def __init__(self): def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train(): with Context(TRAINING=1):
for i in range(10): for i in range(10):
optim.zero_grad() optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward() loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network. Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64. We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training. We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager. Upon exit, the flag is restored to its previous value by the context manager.
```python ```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist() X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train(): with Context(TRAINING=1):
for step in range(1000): for step in range(1000):
# random sample a batch # random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64)) samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym import gymnasium as gym
from tinygrad.helpers import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train(): with Context(TRAINING=1):
log_dist, value = model(x) log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float() action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor: def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs] X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1: if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor: def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step() -> Tensor: def train_step() -> Tensor:
with Tensor.train(): with Context(TRAINING=1):
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0 Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools import itertools
from typing import Callable from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition from tinygrad.helpers import getenv, trange, partition
class Model: class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads) Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def microbatch(): def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0]) samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0 i = 0
eval_acc_pct = 0.0 eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train(): with Context(TRAINING=1):
st = time.monotonic() st = time.monotonic()
while i <= STEPS: while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"): if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os, math, time import os, math, time
import numpy as np import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0) if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor: def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y) _, loss = model(x, y)
optimizer.zero_grad() optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40 top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist())) print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF # much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4) optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(): def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int') if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0]) else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape)) batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs return batch, unpadded_bs
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL, def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]: inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model # Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path from pathlib import Path
import multiprocessing import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset() if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False): with Context(TRAINING=0):
if not RUNMLPERF: if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True)) i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else: else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(model, x, y): def train_step(model, x, y):
optim.zero_grad() optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step() optim.step()
return loss.realize() return loss.realize()
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_step(model, x, y): def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y) y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu return lr_cpu, grad_norm_cpu
@TinyJit @TinyJit
@Tensor.train(False) @Context(TRAINING=0)
def eval_step(tokens:Tensor): def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0) if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device) if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext() else: bench_log_manager = contextlib.nullcontext()
with Tensor.train(): with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}" nm = f"train_{m}"
if nm in globals(): if nm in globals():

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange from tinygrad.helpers import trange, Context
from tinygrad.nn import optim from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium! optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5) optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop # training loop
with Tensor.train(): with Context(TRAINING=1):
for epoch in (t := trange(epochs)): for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0 loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps): for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal # - symbolic removal
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples") print("*** got samples")
with Tensor.train(): with Context(TRAINING=1):
""" """
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i) losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -1,6 +1,6 @@
import numpy as np import numpy as np
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange from tinygrad.helpers import trange, Context
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step) if allow_jit: train_step = TinyJit(train_step)
with Tensor.train(): with Context(TRAINING=1):
losses, accuracies = [], [] losses, accuracies = [], []
for i in (t := trange(steps, disable=None)): for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS)) samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,4 +55,3 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes) acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc) print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc return (acc, Y_test_pred) if return_predict else acc

View file

@ -25,7 +25,7 @@
import unittest import unittest
import numpy as np import numpy as np
import torch import torch
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn, Context
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.helpers import DEV from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these # we don't need more of these
def test_dropout_rate_one(self): def test_dropout_rate_one(self):
with Tensor.train(): with Context(TRAINING=1):
out = Tensor.ones(100).dropout(1.0) out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100)) np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True) torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with Tensor.train(): with Context(TRAINING=1):
Tensor.ones(10).dropout(-0.1) Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase): class TestInputValidation(unittest.TestCase):

View file

@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow @slow
class TestNN(unittest.TestCase): class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True): def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
with Tensor.train(training): with Context(TRAINING=training):
szs = [4, 8, 16, 32] szs = [4, 8, 16, 32]
for sz in szs: for sz in szs:
# create in tinygrad # create in tinygrad

View file

@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1) vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0): with Context(SPLIT_REDUCEOP=0):
with Tensor.train(): with Context(TRAINING=1):
losses = [] losses = []
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
vib = vi.bind(i) vib = vi.bind(i)

View file

@ -1,5 +1,5 @@
import unittest import unittest
from tinygrad import Tensor, Variable, GlobalCounters from tinygrad import Tensor, Variable, GlobalCounters, Context
from tinygrad.uop.ops import sym_infer from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from examples.gpt2 import Attention from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True) self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self): def test_attention_training(self):
with Tensor.train(): with Context(TRAINING=1):
self.test_attention(dropout_p=0.0) self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# symbolic shape dropout is not supported # symbolic shape dropout is not supported

View file

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import torch import torch
import unittest, copy, mmap, random, math, array import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn from tinygrad import Tensor, Device, dtypes, nn, Context
from tinygrad.helpers import getenv, temp, mv_address from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5) np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self): def test_dropout(self):
with Tensor.train(): with Context(TRAINING=1):
n, rate = 1_000_000, 0.1 n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate) w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy()) non_zeros = np.count_nonzero(w.numpy())

View file

@ -1,6 +1,6 @@
import unittest, os import unittest, os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from tinygrad import Tensor from tinygrad import Context
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion from examples.mlperf.model_train import train_stable_diffusion
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion" if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp: with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp os.environ["UNET_CKPTDIR"] = tmp
with Tensor.train(): with Context(TRAINING=1):
saved_ckpts = train_stable_diffusion() saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors" expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
if __name__=="__main__": if __name__=="__main__":
unittest.main() unittest.main()

View file

@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tinygrad import Tensor from tinygrad import Tensor, Context
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from test.helpers import slow from test.helpers import slow
from extra.models.efficientnet import EfficientNet from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img return img
def _infer(model: EfficientNet, img): def _infer(model: EfficientNet, img):
with Tensor.train(False): with Context(TRAINING=0):
out = model.forward(Tensor(img)).argmax(axis=-1) out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist() return out.tolist()

View file

@ -5,10 +5,11 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y): def compare_tiny_torch(model, model_torch, X, Y):
with Tensor.train(): with Context(TRAINING=1):
model_torch.train() model_torch.train()
model_state_dict = get_state_dict(model) model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters(): for k,v in model_torch.named_parameters():

View file

@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_mnist(self): def test_train_mnist(self):
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
with Tensor.train(): with Context(TRAINING=1):
model = Model() model = Model()
optimizer = optim.Adam(get_parameters(model)) optimizer = optim.Adam(get_parameters(model))
BS = 32 BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self): def test_forward_cifar(self):
BS = 32 BS = 32
# with training batchnorm still though # with training batchnorm still though
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit @TinyJit
def run(X): return model(X) def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_cifar(self): def test_train_cifar(self):
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32 BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16") @unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self): def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16 dtypes.default_float = dtypes.float16
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay']) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor'] initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_bert(self): def test_bert(self):
with Tensor.train(): with Context(TRAINING=1):
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2, args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2} "max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny) model = BertForPretraining(**args_tiny)

View file

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this") #@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self): def test_fold_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,32,4,4) img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img) out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn)) check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self): def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True) bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self): def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self): def test_fold_conv_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False): def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15) optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.ones(1,3,4,4) img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True) def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self): def test_fold_batchnorm_backward(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty((2, 16, 8, 8)).contiguous() x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16) bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous() fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4) check_schedule(out, 4)
def test_adam_step_fusion(self): def test_adam_step_fusion(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty(4, 64, 32) x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4) layer = nn.Linear(32, 32*4)
_realize_weights(layer) _realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self): def test_adam_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self): def test_adam_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self): def test_sgd_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3? check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self): def test_sgd_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7) check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self): def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11) check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self): def test_sgd_4convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self): def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0) self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self): def test_resnet_block(self):
with Tensor.train(False): with Context(TRAINING=0):
in_planes, planes = 64, 64 in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes) bn1 = nn.BatchNorm2d(planes)

View file

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed) # tensor tests that pass on NULL backend (no copyout needed)
import numpy as np import numpy as np
import unittest import unittest
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase): class TestTrainMode(unittest.TestCase):
def test_train_mode(self): def test_train_mode(self):
assert not Tensor.training assert not Tensor.training
@Tensor.train() @Context(TRAINING=1)
def f(): def f():
assert Tensor.training assert Tensor.training
f() f()

View file

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy() out.numpy()
def test_backprop_conv(self): def test_backprop_conv(self):
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2) for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv)) optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self): def test_dropout_on_shard(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(256).to(devices_2) X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0] assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self): def test_dropout_on_shard_axis(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(512).shard(devices_2, axis=0) X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass def setUp(self): pass
def test_unsynced_backprop_conv_bn(self): def test_unsynced_backprop_conv_bn(self):
with Tensor.train(): with Context(TRAINING=1):
from extra.lr_scheduler import OneCycleLR from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)] convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni) bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:]) return bn_ts[0].cat(*bn_ts[1:])
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16) bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2) GPUS = (d1, d2)
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS)) bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0) x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Tensor.train(is_training): with Context(TRAINING=is_training):
bns = [] bns = []
for _ in range(len(devices)): for _ in range(len(devices)):
bn = nn.BatchNorm2d(8) bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0) x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train(): with Context(TRAINING=1):
synced_bn = BatchNorm2d(8) synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices)) unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

View file

@ -1,7 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
@ -71,7 +70,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
A `Tensor` is a multi-dimensional matrix containing elements of a single data type. A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor" ```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn, Context
import numpy as np import numpy as np
import math import math
np.set_printoptions(precision=4) np.set_printoptions(precision=4)
@ -154,11 +153,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
self.is_param = is_param self.is_param = is_param
return self return self
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, TRAINING.value = TRAINING.value, int(self.mode)
def __exit__(self, exc_type, exc_value, traceback): TRAINING.value = self.prev
def __repr__(self): def __repr__(self):
ld = self.uop ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>" ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -804,7 +798,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
```python exec="true" source="above" session="tensor" result="python" ```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42) Tensor.manual_seed(42)
t = Tensor.randn(2, 2) t = Tensor.randn(2, 2)
with Tensor.train(): with Context(TRAINING=1):
print(t.dropout().numpy()) print(t.dropout().numpy())
``` ```
""" """