Compare commits

...

18 commits

Author SHA1 Message Date
chenyu
687ade119e
IMAGE hand_coded_optimizations update (#16720) 2026-06-23 21:55:28 -04:00
George Hotz
0a8e61d0c5
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese

* move that stuff

* copy in allowed length logic

* mulitple buffers

* new coalese is better

* fine

* earlier

* fixes

* work

* work

* valid

* stack on index const
2026-06-23 18:03:48 -07:00
wozeparrot
dfea9e7994
llama: fused silu mul quantize mxfp8 (#16704) 2026-06-23 16:59:50 -07:00
chenyu
ce87d80911
better _drop_valid_stmts [pr] (#16719)
also dropped the unused is_increasing
2026-06-23 19:35:01 -04:00
George Hotz
5a2b3b7b06
early dtype decomp (#16718)
* early dtype decomp

* simplify

* cleanup

* that goes there

* doing too much

* stupid symbolic rules
2026-06-23 16:07:20 -07:00
Christopher Milan
116045cc8e
ci: remove tensorflow from testoptim (#16717) 2026-06-23 18:11:48 -04:00
nimlgen
7c1d0b6d9a
hcq2: use shrink(bitcast) (#16713)
* hcq2: use shrink(bitcast)

* x
2026-06-23 18:11:39 +03:00
George Hotz
c9dc1d63cc
small changes from new codegen (#16712)
* small changes from new codegen

* shrink/flatten
2026-06-22 17:44:15 -07:00
Christopher Milan
da98fae9e1
ci: try parallelizing tc tests (#16710) 2026-06-22 20:43:32 -04:00
chenyu
15988b5941
contiguous to mixin and cleanups [PR] (#16711) 2026-06-22 20:18:18 -04:00
Christopher Milan
cbfcf36e44
ci: remove generate_dataset and CL misc (#16709) 2026-06-22 18:01:07 -04:00
nimlgen
f9c8c697d6
hcq2: drop args after inner deps (#16708) 2026-06-22 23:26:11 +03:00
chenyu
0138480910
dropout and scaled_dot_product_attention to mixin (#16707) 2026-06-22 16:17:45 -04:00
chenyu
33b635d23a
Tensor.train -> TRAINING [PR] (#16705)
* Tensor.train -> TRAINING [PR]

* doc
2026-06-22 15:13:22 -04:00
chenyu
625d8bbd0d
TRAINING ContextVar (#16703) 2026-06-22 13:03:08 -04:00
wozeparrot
fe9b19b12d
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes

* clean: unused

* fix: batch
2026-06-22 10:54:35 -04:00
chenyu
267af9c601
full_like to CreationMixin [PR] (#16702) 2026-06-22 09:33:23 -04:00
chenyu
97da54b9d6
more method to CreationMixin [PR] (#16698) 2026-06-22 00:01:22 -04:00
53 changed files with 648 additions and 486 deletions

View file

@ -133,46 +133,26 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20 run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support - name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores - name: Test emulated tensor cores
env: env:
DEV: 'PYTHON::METAL' DEBUG: 2
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: | run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
python3 -m pytest -nauto test/opt/test_tensor_cores.py ::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
- name: Test emulated AMD tensor cores ::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
env: 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
DEV: 'PYTHON::gfx1100' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: | run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter: linter:
@ -267,13 +247,6 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines - name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py run: MAX_LINE_COUNT=25000 python sz.py
@ -338,31 +311,6 @@ jobs:
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot: testopenpilot:
name: openpilot Compile Tests name: openpilot Compile Tests
runs-on: *linux runs-on: *linux
@ -379,7 +327,7 @@ jobs:
llvm: 'true' llvm: 'true'
- name: Test openpilot model kernel count and gate usage - name: Test openpilot model kernel count and gate usage
run: | run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness) - name: Test openpilot CL compile fp32 (test correctness)
run: | run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -422,7 +370,6 @@ jobs:
with: with:
key: optim key: optim
deps: testing deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true' opencl: 'true'
#- name: Test Optimization Helpers #- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py # run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -431,7 +378,7 @@ jobs:
- name: Test Beam Search - name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff - name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20 run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu - name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training - name: Test Bert training

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need. Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python ```python
from tinygrad import Tensor, nn from tinygrad import Tensor, nn, Context
class LinearNet: class LinearNet:
def __init__(self): def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train(): with Context(TRAINING=1):
for i in range(10): for i in range(10):
optim.zero_grad() optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward() loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network. Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64. We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training. We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager. Upon exit, the flag is restored to its previous value by the context manager.
```python ```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist() X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train(): with Context(TRAINING=1):
for step in range(1000): for step in range(1000):
# random sample a batch # random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64)) samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym import gymnasium as gym
from tinygrad.helpers import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train(): with Context(TRAINING=1):
log_dist, value = model(x) log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float() action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor: def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs] X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1: if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor: def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step() -> Tensor: def train_step() -> Tensor:
with Tensor.train(): with Context(TRAINING=1):
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0 Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools import itertools
from typing import Callable from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition from tinygrad.helpers import getenv, trange, partition
class Model: class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads) Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def microbatch(): def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0]) samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0 i = 0
eval_acc_pct = 0.0 eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train(): with Context(TRAINING=1):
st = time.monotonic() st = time.monotonic()
while i <= STEPS: while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"): if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os, math, time import os, math, time
import numpy as np import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0) if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor: def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y) _, loss = model(x, y)
optimizer.zero_grad() optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40 top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist())) print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF # much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4) optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(): def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int') if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0]) else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape)) batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs return batch, unpadded_bs
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL, def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]: inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model # Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path from pathlib import Path
import multiprocessing import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset() if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False): with Context(TRAINING=0):
if not RUNMLPERF: if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True)) i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else: else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(model, x, y): def train_step(model, x, y):
optim.zero_grad() optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step() optim.step()
return loss.realize() return loss.realize()
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_step(model, x, y): def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y) y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu return lr_cpu, grad_norm_cpu
@TinyJit @TinyJit
@Tensor.train(False) @Context(TRAINING=0)
def eval_step(tokens:Tensor): def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0) if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device) if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext() else: bench_log_manager = contextlib.nullcontext()
with Tensor.train(): with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}" nm = f"train_{m}"
if nm in globals(): if nm in globals():

View file

@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None, def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None, x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]: grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
if not fp8: if not fp8:
if ASM_GEMM: if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@ -47,12 +47,14 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)" assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1])) if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T): if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale), out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0]) mx_w_stored=True).reshape(*l_shape, w.shape[0])
else: else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1]) x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None: if x_fp8 is None:
@ -126,10 +128,8 @@ class FlatTransformer:
# FeedForward # FeedForward
if SPLIT_W13: if SPLIT_W13:
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim) self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02) self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
else: else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2) self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std) self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@ -160,7 +160,7 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None): def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None: if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features) if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize() else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8 from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features)) w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
@ -216,6 +216,13 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"]) x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, x_w3]) saves.extend([*s, x_w3])
if FUSED_SILU_W13 and MXFP8:
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"], out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"]) grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax) amaxs.append(new_amax)
@ -247,20 +254,30 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None) for v in get_parameters(self): v.shard_(device, axis=None)
else: else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int): def _shard_fp8(name:str, axis:int, std:float=0.02):
getattr(self, name).shard_(device, axis=axis) w = getattr(self, name)
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in _shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
if SPLIT_W13: if SPLIT_W13:
_shard_fp8("w1", 1) _shard_fp8("w1", 1)
_shard_fp8("w3", 1) _shard_fp8("w3", 1)
else: else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in _shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize() self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize()

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange from tinygrad.helpers import trange, Context
from tinygrad.nn import optim from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium! optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5) optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop # training loop
with Tensor.train(): with Context(TRAINING=1):
for epoch in (t := trange(epochs)): for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0 loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps): for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal # - symbolic removal
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples") print("*** got samples")
with Tensor.train(): with Context(TRAINING=1):
""" """
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i) losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float) # accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG) acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like())) acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma: if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE) k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -2675,8 +2675,8 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]: def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis # 1x32 block scaling along the last axis
*batch, K = x.shape *batch, K = x.shape
scale_K, k_iters = K // 32, K // 128 scale_K = K // 32
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1) amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8) e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K) qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale x_scaled = x.float() * qscale

View file

@ -143,14 +143,17 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs): def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op) return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag): def make_cmdbuf(lin, devs, tag):
blob, patches = b'', [] blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src): for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s)) if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0) blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag) buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches] return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops)) def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -211,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call)) return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp: def prep_kernargs(call:UOp, prg:UOp) -> UOp:
data, info = prg.arg (data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))), buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
dtypes.uint64) for i,gi in enumerate(info.globals)] \ patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)] + [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([ pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering # bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -307,7 +306,7 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
for (_, lane), dep in latest.items(): deps[dep] += (lane,) for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps") if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:]))) new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
return linear.replace(src=tuple(new_src)) return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)]) pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -532,9 +531,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack), (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack), (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# index on slice is index # shrink on slice is shrink on base at offset
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True), (UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))), lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
# getaddr # getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x) (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
@ -542,8 +541,8 @@ pm_resolve_patches = PatternMatcher([
# folders # folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store), (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), (UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
fold_const_store), .store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
]) + symbolic_simple ]) + symbolic_simple
# ***************** # *****************

View file

@ -0,0 +1,104 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2 step = THREADS_PER_WG // 2
while step: while step:
active = tid < step active = tid < step
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active) other = lds[(tid + step).valid(active)].load()
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier()) lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2 step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0]) amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -1,6 +1,6 @@
import numpy as np import numpy as np
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange from tinygrad.helpers import trange, Context
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step) if allow_jit: train_step = TinyJit(train_step)
with Tensor.train(): with Context(TRAINING=1):
losses, accuracies = [], [] losses, accuracies = [], []
for i in (t := trange(steps, disable=None)): for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS)) samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,4 +55,3 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes) acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc) print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc return (acc, Y_test_pred) if return_predict else acc

View file

@ -25,7 +25,7 @@
import unittest import unittest
import numpy as np import numpy as np
import torch import torch
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn, Context
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.helpers import DEV from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these # we don't need more of these
def test_dropout_rate_one(self): def test_dropout_rate_one(self):
with Tensor.train(): with Context(TRAINING=1):
out = Tensor.ones(100).dropout(1.0) out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100)) np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True) torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with Tensor.train(): with Context(TRAINING=1):
Tensor.ones(10).dropout(-0.1) Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase): class TestInputValidation(unittest.TestCase):

View file

@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src) renderer=Device[Device.DEFAULT].renderer).src[2].src)
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD]) num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed" assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" assert num_loads >= 1, "expected at least one load uop"
@unittest.skip("this is handled at higher level now") @unittest.skip("this is handled at higher level now")
def test_upcast_cse(self): def test_upcast_cse(self):

View file

@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow @slow
class TestNN(unittest.TestCase): class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True): def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
with Tensor.train(training): with Context(TRAINING=training):
szs = [4, 8, 16, 32] szs = [4, 8, 16, 32]
for sz in szs: for sz in szs:
# create in tinygrad # create in tinygrad

View file

@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1) vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0): with Context(SPLIT_REDUCEOP=0):
with Tensor.train(): with Context(TRAINING=1):
losses = [] losses = []
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
vib = vi.bind(i) vib = vi.bind(i)

View file

@ -1,5 +1,5 @@
import unittest import unittest
from tinygrad import Tensor, Variable, GlobalCounters from tinygrad import Tensor, Variable, GlobalCounters, Context
from tinygrad.uop.ops import sym_infer from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from examples.gpt2 import Attention from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True) self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self): def test_attention_training(self):
with Tensor.train(): with Context(TRAINING=1):
self.test_attention(dropout_p=0.0) self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# symbolic shape dropout is not supported # symbolic shape dropout is not supported

View file

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import torch import torch
import unittest, copy, mmap, random, math, array import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn from tinygrad import Tensor, Device, dtypes, nn, Context
from tinygrad.helpers import getenv, temp, mv_address from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5) np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self): def test_dropout(self):
with Tensor.train(): with Context(TRAINING=1):
n, rate = 1_000_000, 0.1 n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate) w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy()) non_zeros = np.count_nonzero(w.numpy())

View file

@ -0,0 +1,67 @@
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
def __init__(self):
self.x = Tensor(x_init.copy())
self.W = Tensor(W_init.copy())
self.m = Tensor(m_init.copy())
def forward(self):
out = self.x.matmul(self.W).relu()
out = out.log_softmax(1)
out = out.mul(self.m).add(self.m).sum()
return out
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest, math import unittest
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.optimizers import Lamb from tensorflow.keras.optimizers import Lamb
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337) np.random.seed(1337)
@ -173,48 +173,5 @@ class ExternalTestOptim(unittest.TestCase):
'warmup': steps_per_epoch * warmup_epochs, 'warmup': steps_per_epoch * warmup_epochs,
}, 1e-5, 1e-5, do_optim=False) }, 1e-5, 1e-5, do_optim=False)
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -1,6 +1,6 @@
import unittest, os import unittest, os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from tinygrad import Tensor from tinygrad import Context
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion from examples.mlperf.model_train import train_stable_diffusion
@ -14,7 +14,7 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion" if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp: with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp os.environ["UNET_CKPTDIR"] = tmp
with Tensor.train(): with Context(TRAINING=1):
saved_ckpts = train_stable_diffusion() saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors" expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt

View file

@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tinygrad import Tensor from tinygrad import Tensor, Context
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from test.helpers import slow from test.helpers import slow
from extra.models.efficientnet import EfficientNet from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img return img
def _infer(model: EfficientNet, img): def _infer(model: EfficientNet, img):
with Tensor.train(False): with Context(TRAINING=0):
out = model.forward(Tensor(img)).argmax(axis=-1) out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist() return out.tolist()

View file

@ -5,10 +5,11 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y): def compare_tiny_torch(model, model_torch, X, Y):
with Tensor.train(): with Context(TRAINING=1):
model_torch.train() model_torch.train()
model_state_dict = get_state_dict(model) model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters(): for k,v in model_torch.named_parameters():

View file

@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_mnist(self): def test_train_mnist(self):
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
with Tensor.train(): with Context(TRAINING=1):
model = Model() model = Model()
optimizer = optim.Adam(get_parameters(model)) optimizer = optim.Adam(get_parameters(model))
BS = 32 BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self): def test_forward_cifar(self):
BS = 32 BS = 32
# with training batchnorm still though # with training batchnorm still though
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit @TinyJit
def run(X): return model(X) def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_cifar(self): def test_train_cifar(self):
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32 BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16") @unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self): def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16 dtypes.default_float = dtypes.float16
with Tensor.train(): with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay']) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor'] initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_bert(self): def test_bert(self):
with Tensor.train(): with Context(TRAINING=1):
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2, args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2} "max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny) model = BertForPretraining(**args_tiny)

View file

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this") #@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self): def test_fold_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,32,4,4) img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img) out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn)) check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self): def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True) bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self): def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self): def test_fold_conv_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False): def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15) optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.ones(1,3,4,4) img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True) def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self): def test_fold_batchnorm_backward(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty((2, 16, 8, 8)).contiguous() x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16) bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous() fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4) check_schedule(out, 4)
def test_adam_step_fusion(self): def test_adam_step_fusion(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty(4, 64, 32) x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4) layer = nn.Linear(32, 32*4)
_realize_weights(layer) _realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self): def test_adam_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self): def test_adam_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self): def test_sgd_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3? check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self): def test_sgd_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7) check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self): def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11) check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self): def test_sgd_4convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self): def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0) self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self): def test_resnet_block(self):
with Tensor.train(False): with Context(TRAINING=0):
in_planes, planes = 64, 64 in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes) bn1 = nn.BatchNorm2d(planes)

View file

@ -16,41 +16,17 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp): def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, ( return UOp(Ops.LOAD, dtypes.float, (
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True), UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
)) ))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), ( return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True), UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n) def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()): def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0] load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -506,6 +482,16 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))", self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0") "(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase): class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self): def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid # test that INDEX with a constant True valid gets simplified to drop the valid

View file

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed) # tensor tests that pass on NULL backend (no copyout needed)
import numpy as np import numpy as np
import unittest import unittest
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase): class TestTrainMode(unittest.TestCase):
def test_train_mode(self): def test_train_mode(self):
assert not Tensor.training assert not Tensor.training
@Tensor.train() @Context(TRAINING=1)
def f(): def f():
assert Tensor.training assert Tensor.training
f() f()

View file

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy() out.numpy()
def test_backprop_conv(self): def test_backprop_conv(self):
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2) for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv)) optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self): def test_dropout_on_shard(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(256).to(devices_2) X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0] assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self): def test_dropout_on_shard_axis(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(512).shard(devices_2, axis=0) X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass def setUp(self): pass
def test_unsynced_backprop_conv_bn(self): def test_unsynced_backprop_conv_bn(self):
with Tensor.train(): with Context(TRAINING=1):
from extra.lr_scheduler import OneCycleLR from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)] convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni) bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:]) return bn_ts[0].cat(*bn_ts[1:])
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16) bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2) GPUS = (d1, d2)
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS)) bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0) x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Tensor.train(is_training): with Context(TRAINING=is_training):
bns = [] bns = []
for _ in range(len(devices)): for _ in range(len(devices)):
bn = nn.BatchNorm2d(8) bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0) x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train(): with Context(TRAINING=1):
synced_bn = BatchNorm2d(8) synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices)) unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

View file

@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
@ -23,6 +23,7 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.late.coalese import memory_coalesing
pm_index_is_shrink = PatternMatcher([ pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK # rewrite non-image INDEX to SHRINK
@ -52,6 +53,10 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param), (UPat(Ops.PARAM, name="x"), do_number_param),
]) ])
pm_no_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
@ -78,7 +83,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = apply_opts(sink, ren, beam=ast.arg.beam) sink = apply_opts(sink, ren, beam=ast.arg.beam)
# ** expander (expand_rewrite) ** # ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic") sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand # expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
@ -113,18 +118,23 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# optional pre matcher # optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher") if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# decompositions # floordiv+mod / dtype decomp (early)
supported_ops = tuple(ren.code_for_op.keys()) supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV)) pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2) sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes") sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# GEP/STACK stuff # do memory coalesing (late)
sink = memory_coalesing(sink, ren)
# instruction selection decompositions
pm_decomp = pm_decomp+\
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
# this is new style (TODO: this should all be removed)
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack") sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style") sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
@ -133,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# final rules for the renderer (without sym) # final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
# this was the linearizer # this was the linearizer

View file

@ -0,0 +1,73 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
# TODO: this should handle images too, it's just memory coalesing
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
assert u.src[0].op is Ops.INDEX
buf, idx_u = u.src[0].src
if buf.addrspace == AddrSpace.REG: continue
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
# allowed lengths (copied in)
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# do the grouping
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]: def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False # can drop valid if idx is out of bound when valid is False
drop_stmt = [] drop_stmt = []
for stmt in valid.split_uop(Ops.AND): for i,stmt in enumerate(valid.split_uop(Ops.AND)):
if (res:=parse_valid(stmt)) is None: continue if (res:=parse_valid(stmt)) is None: continue
X, is_upper_bound, c = res X, is_upper_bound, c = res
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
drop_stmt.append(stmt) drop_stmt.append(stmt)
continue continue
# if X <= c, check if it's out of bound when X = c+1 # check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
# if X >= c, check if it's out of bound when X = c-1 lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
test_value = c + 1 if is_upper_bound else c - 1 if lo <= hi:
for i,b in zip(idx.src, (width, height)): fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
if i.is_increasing(): for coord,b in zip(idx.src, (width, height)):
rw = i.substitute({X:X.const_like(test_value)}) rw = coord.substitute({X:fake}).simplify()
if rw.vmin >= b or rw.vmax < 0: if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt) drop_stmt.append(stmt)
break break
@ -162,18 +162,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# determine fold lengths # determine fold lengths
lengths = [] lengths = []
must_divide = True must_divide = True
if ctx is not None and ctx.target.device == "DSP": # TODO: this belongs in coalese
lengths = [128,64,32,16,8,4] if isinstance(buf.dtype, ImageDType): lengths = [4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide # filter fold lengths that don't divide

View file

@ -101,6 +101,12 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates # for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = k.rngs[axis] in where_gate_rngs is_masked = k.rngs[axis] in where_gate_rngs
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
# upcasting a masked global axis moves that range out of the launch grid into each work-item
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
if DEBUG >= 4: print(f"upcasting masked axis : {axis}") if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis) to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0)) for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))

View file

@ -240,6 +240,7 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0) IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32) JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0) WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
TRAINING = ContextVar("TRAINING", 0)
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0) USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0) TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)

View file

@ -6,7 +6,7 @@ from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up from tinygrad.helpers import resolve_pool_pads, round_up
@ -17,9 +17,6 @@ ReductionStr = Literal["mean", "sum", "none"]
class OpMixin(ElementwiseMixin, ReduceMixin): class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory") def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
def item(self) -> PyConst: def item(self) -> PyConst:
@ -34,48 +31,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
assert self.numel() == 1, "must have one element for item" assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)] return self.data()[(0,) * len(self.shape)]
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def __getitem__(self, indices) -> Self: def __getitem__(self, indices) -> Self:
""" """
Retrieves a sub-tensor using indexing. Retrieves a sub-tensor using indexing.
@ -207,40 +162,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self) return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
@classmethod @classmethod
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self: def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
""" """

View file

@ -1,10 +1,24 @@
from typing import Self from typing import TYPE_CHECKING, Callable, Self
from tinygrad.dtype import ConstType, DType, DTypeLike from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
from tinygrad.helpers import argfix
from tinygrad.mixin.dtype import DTypeMixin from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
class CreationMixin(DTypeMixin, MovementMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
class CreationMixin(DTypeMixin):
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b)) def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self: def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
""" """
Creates an empty tensor with the same shape as `self`. Creates an empty tensor with the same shape as `self`.
@ -12,9 +26,76 @@ class CreationMixin(DTypeMixin):
""" """
return self._wrap_uop(self._uop.empty_like(dtype, device)) return self._wrap_uop(self._uop.empty_like(dtype, device))
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self: @classmethod
"""Creates a tensor with the same shape as `self`, filled with the given value.""" def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype) """
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
def zeros_like(self, **kwargs) -> Self: def zeros_like(self, **kwargs) -> Self:
""" """
@ -29,6 +110,23 @@ class CreationMixin(DTypeMixin):
""" """
return self.full_like(0, **kwargs) return self.full_like(0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
def ones_like(self, **kwargs) -> Self: def ones_like(self, **kwargs) -> Self:
""" """
Creates a tensor with the same shape as `self`, filled with ones. Creates a tensor with the same shape as `self`, filled with ones.

View file

@ -45,7 +45,11 @@ class ElementwiseMixin(CreationMixin):
""" """
return self.cast(dtypes.bool).ne(True) return self.cast(dtypes.bool).ne(True)
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError def contiguous(self, **kwargs) -> Self:
"""
Returns a contiguous tensor.
"""
return self._wrap_uop(self._uop.contiguous(**kwargs))
def contiguous_backward(self) -> Self: def contiguous_backward(self) -> Self:
""" """

View file

@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
import math import math
from typing import Self, cast from typing import Self, cast
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING
from tinygrad.mixin import OpMixin from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device from tinygrad.device import canonicalize_device
@ -273,3 +273,54 @@ class RandMixin(OpMixin):
# Efraimidis-Spirakis # Efraimidis-Spirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1] indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
def dropout(self, p=0.5) -> Self:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Context(TRAINING=1):
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not TRAINING or p == 0: return self
if p == 1: return self.const_like(0)
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Self:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value

View file

@ -22,6 +22,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None), if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
# GPU stuff # GPU stuff

View file

@ -18,7 +18,7 @@ import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def add_ranges_to_store(ctx, x): def add_ranges_to_store(ctx, x):
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None
assert x.src[0].shape == x.src[1].shape, "bad store shape" assert x.src[0].shape == x.src[1].shape, "bad store shape"
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape] idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs) return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)

View file

@ -1,14 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin.rand import RandMixin from tinygrad.mixin.rand import RandMixin
from tinygrad.schedule import create_linear_with_vars from tinygrad.schedule import create_linear_with_vars
@ -59,19 +58,25 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor" assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret return ret
class Tensor(RandMixin): # TODO: deprecate this, always use TRAINING
class TensorMeta(type):
@property
def training(cls) -> bool: return bool(TRAINING.value)
@training.setter
def training(cls, mode:bool): TRAINING.value = int(mode)
class Tensor(RandMixin, metaclass=TensorMeta):
""" """
A `Tensor` is a multi-dimensional matrix containing elements of a single data type. A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor" ```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn, Context
import numpy as np import numpy as np
import math import math
np.set_printoptions(precision=4) np.set_printoptions(precision=4)
``` ```
""" """
__slots__ = "uop", "is_param", "grad" __slots__ = "uop", "is_param", "grad"
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None, def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None): device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
@ -125,9 +130,9 @@ class Tensor(RandMixin):
@suppress_finalizing @suppress_finalizing
def __del__(self): all_tensors.pop(weakref.ref(self), None) def __del__(self): all_tensors.pop(weakref.ref(self), None)
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor: def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
srcs = (self,)+x srcs = (self,)+x
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs) new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,) if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
# directly create the Tensor # directly create the Tensor
ret = Tensor.__new__(Tensor) ret = Tensor.__new__(Tensor)
@ -148,11 +153,6 @@ class Tensor(RandMixin):
self.is_param = is_param self.is_param = is_param
return self return self
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
def __repr__(self): def __repr__(self):
ld = self.uop ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>" ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -264,6 +264,7 @@ class Tensor(RandMixin):
x = self.cast(self.dtype.base).contiguous() x = self.cast(self.dtype.base).contiguous()
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU") if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
return cast(Buffer, x.realize().uop.buffer).ensure_allocated() return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview() def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview: def data(self) -> memoryview:
@ -279,7 +280,7 @@ class Tensor(RandMixin):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}" assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12) assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape) return self._data().cast(self.dtype.base.fmt, self.shape)
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions) # NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]: def tolist(self) -> PyConst|list[Any]:
@ -503,25 +504,6 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff)) high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high) return Tensor._device_seeds[device], low.cat(high)
# ***** creation helper functions *****
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
# ***** toposort and backward pass ***** # ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor: def backward(self, gradient:Tensor|None=None) -> Tensor:
@ -548,7 +530,7 @@ class Tensor(RandMixin):
# ***** movement ops ***** # ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg) def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis) def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
@ -737,14 +719,6 @@ class Tensor(RandMixin):
if IMAGE: return self.image_dot(w, dtype) if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype) return super().dot(w, dtype)
# ***** unary ops *****
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
# ***** broadcasted elementwise ops ***** # ***** broadcasted elementwise ops *****
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
@ -804,59 +778,6 @@ class Tensor(RandMixin):
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec") fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos))) return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
# ***** functional nn ops *****
def dropout(self, p=0.5) -> Tensor:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `Tensor.training` is `True`.
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.const_like(0)
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
# ***** cast ops ***** # ***** cast ops *****
def bitcast(self, dtype:DTypeLike) -> Tensor: def bitcast(self, dtype:DTypeLike) -> Tensor:

View file

@ -454,7 +454,8 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
powers_of_two: dict[int, int] = {2**i:i for i in range(64)} powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache @functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher: def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
# these are rewrites that make things simpler
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)] pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod # FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)) if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
@ -463,6 +464,11 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
return PatternMatcher(pat)
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())] lambda x,y: (x | y).logical_not())]
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) # rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)

View file

@ -919,12 +919,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff *** # *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int: def const_factor(self) -> int:
"""largest known int that divides self""" """largest known int that divides self"""
# TODO: for negatives it's not the largest # TODO: for negatives it's not the largest
@ -1190,8 +1184,9 @@ python_alu: dict[Ops, Callable] = {
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq} Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1: if any(isinstance(x, tuple) for x in operands):
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)]) count = max(len(x) for x in operands if isinstance(x, tuple))
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count)])
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands) alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
@ -1672,6 +1667,8 @@ pm_lower_index_dtype = PatternMatcher([
lambda var,val: var.bind(val).cast(dtypes.weakint)), lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts # remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("slen", dtypes.ints).cast(),), name="shrink"),
lambda shrink,buf,idx,slen: shrink.replace(src=(buf,idx,slen))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)), lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for images # remove hanging casts for images

View file

@ -121,6 +121,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# TODO: combine this with "# rules for threefry" below # TODO: combine this with "# rules for threefry" below
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"), ((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None), lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding ** # ** constant folding **
@ -160,6 +162,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x), (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x), (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
# ** simple where folding ** # ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals # a conditional with the same results either way is a noop, also fold const conditionals
@ -167,6 +170,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), (UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# a.where(b.where(c, d), d) -> (a & b).where(c, d) # a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# STACK on INDEX CONST (TODO: remove all the GEP crap)
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
]) ])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********