Compare commits

..

39 commits

Author SHA1 Message Date
George Hotz
d319b5f614
Merge branch 'master' into codegen2 2026-06-21 19:18:24 -07:00
George Hotz
0decf136fe
Merge branch 'master' into codegen2 2026-06-19 18:29:13 -07:00
George Hotz
a96db15419 fixes 2026-06-19 17:57:59 -07:00
George Hotz
b0f7e2a5e1 fix imports 2026-06-19 17:44:39 -07:00
George Hotz
566fe4c7dc
Merge branch 'master' into codegen2 2026-06-19 16:57:09 -07:00
George Hotz
c2f73b102e
Merge branch 'master' into codegen2 2026-06-19 12:39:00 -07:00
George Hotz
c6f195f410 fixes 2026-06-18 18:08:14 -07:00
George Hotz
f26378264b
Merge branch 'master' into codegen2 2026-06-18 18:03:15 -07:00
George Hotz
1168ed9730
Merge branch 'master' into codegen2 2026-06-17 00:37:09 -07:00
George Hotz
017edbbbb5 param -1 2026-06-16 21:52:07 -07:00
George Hotz
daa72812b0 add gpu dims 2026-06-16 21:37:59 -07:00
George Hotz
fd325d662c
Merge branch 'master' into codegen2 2026-06-16 21:29:09 -07:00
George Hotz
8d36539656 test tiny passes 2026-06-16 14:57:12 -07:00
George Hotz
db2c71536b almost passing 2026-06-16 13:29:38 -07:00
George Hotz
4112b34a32 closer 2026-06-16 13:23:30 -07:00
George Hotz
1ad72dff08 more passing 2026-06-16 12:54:54 -07:00
George Hotz
6f1eaa8d46 fixes 2026-06-16 12:38:17 -07:00
George Hotz
35d2882991 no vec 2026-06-16 10:47:19 -07:00
George Hotz
a31732d819
Merge branch 'master' into codegen2 2026-06-16 10:33:34 -07:00
George Hotz
43d62c4211 hreduce 2026-06-16 09:36:47 -07:00
George Hotz
4d0429090c split reduce types 2026-06-16 09:27:09 -07:00
George Hotz
2c7a1450e7 fix reduce 2026-06-16 08:40:00 -07:00
George Hotz
6ffb55cc74
Merge branch 'master' into codegen2 2026-06-15 17:19:25 -07:00
George Hotz
1a280829ca
Merge branch 'master' into codegen2 2026-06-15 12:48:46 -07:00
George Hotz
3b426b1072 devec 2026-06-15 08:57:52 -07:00
George Hotz
ce2cdc3708
Merge branch 'master' into codegen2 2026-06-14 16:43:48 -07:00
George Hotz
333f062eee new expander 2026-06-14 13:54:13 -07:00
George Hotz
0d5bf3ca6d revert that 2026-06-14 13:28:28 -07:00
George Hotz
56bad940df disable that 2026-06-14 13:28:02 -07:00
George Hotz
f98deb9250 preprocess 2026-06-14 13:24:19 -07:00
George Hotz
bdfcb1cb98 test ops passes 2026-06-14 12:58:18 -07:00
George Hotz
a6fdb53a1e
Merge branch 'master' into codegen2 2026-06-14 10:09:00 -07:00
George Hotz
49deb9714b test_tiny passes 2026-06-14 09:36:51 -07:00
George Hotz
afab220947
Merge branch 'master' into codegen2 2026-06-14 08:52:36 -07:00
George Hotz
a7523b2596 simpler 2026-06-13 10:40:52 -07:00
George Hotz
21806848df improve new codegen 2026-06-12 20:08:20 -07:00
George Hotz
6fda6c704d
Merge branch 'master' into codegen2 2026-06-12 20:01:43 -07:00
George Hotz
3f7ec187df work 2026-06-12 19:24:56 -07:00
George Hotz
af9284e9b1 try for a full rewrite of codegen 2026-06-12 19:11:54 -07:00
54 changed files with 686 additions and 660 deletions

View file

@ -133,26 +133,46 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20 run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support - name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated tensor cores - name: Test emulated METAL tensor cores
env: env:
DEBUG: 2 DEV: 'PYTHON::METAL'
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: | run: |
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \ DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \ python3 -m pytest -nauto test/opt/test_tensor_cores.py
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \ - name: Test emulated AMD tensor cores
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \ env:
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \ DEV: 'PYTHON::gfx1100'
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: | run: |
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter: linter:
@ -247,6 +267,13 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines - name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py run: MAX_LINE_COUNT=25000 python sz.py
@ -311,6 +338,31 @@ jobs:
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot: testopenpilot:
name: openpilot Compile Tests name: openpilot Compile Tests
runs-on: *linux runs-on: *linux
@ -327,7 +379,7 @@ jobs:
llvm: 'true' llvm: 'true'
- name: Test openpilot model kernel count and gate usage - name: Test openpilot model kernel count and gate usage
run: | run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness) - name: Test openpilot CL compile fp32 (test correctness)
run: | run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -370,6 +422,7 @@ jobs:
with: with:
key: optim key: optim
deps: testing deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true' opencl: 'true'
#- name: Test Optimization Helpers #- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py # run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -378,7 +431,7 @@ jobs:
- name: Test Beam Search - name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff - name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20 run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu - name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training - name: Test Bert training

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need. Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python ```python
from tinygrad import Tensor, nn, Context from tinygrad import Tensor, nn
class LinearNet: class LinearNet:
def __init__(self): def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Context(TRAINING=1): with Tensor.train():
for i in range(10): for i in range(10):
optim.zero_grad() optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward() loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,14 +165,13 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network. Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64. We will be training for 1000 steps with a batch size of 64.
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training. We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager. Upon exit, the flag is restored to its previous value by the context manager.
```python ```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist() X_train, Y_train, X_test, Y_test = fetch_mnist()
with Context(TRAINING=1): with Tensor.train():
for step in range(1000): for step in range(1000):
# random sample a batch # random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64)) samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn, Context from tinygrad import Tensor, TinyJit, nn
import gymnasium as gym import gymnasium as gym
from tinygrad.helpers import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Context(TRAINING=1): with Tensor.train():
log_dist, value = model(x) log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float() action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def train_step(idxs:Tensor) -> Tensor: def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs] X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1: if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor: def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step() -> Tensor: def train_step() -> Tensor:
with Context(TRAINING=1): with Tensor.train():
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0 Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools import itertools
from typing import Callable from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context from tinygrad import nn, Tensor, dtypes, Device, TinyJit
from tinygrad.helpers import getenv, trange, partition from tinygrad.helpers import getenv, trange, partition
class Model: class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads) Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def microbatch(): def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0]) samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0 i = 0
eval_acc_pct = 0.0 eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Context(TRAINING=1): with Tensor.train():
st = time.monotonic() st = time.monotonic()
while i <= STEPS: while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"): if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os, math, time import os, math, time
import numpy as np import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0) if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def step(x:Tensor, y:Tensor) -> Tensor: def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y) _, loss = model(x, y)
optimizer.zero_grad() optimizer.zero_grad()
@ -204,3 +204,4 @@ if __name__ == "__main__":
top_k = 40 top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist())) print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF # much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad.helpers import getenv, trange from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4) optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def train_step(): def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int') if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0]) else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape)) batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs return batch, unpadded_bs
@Context(TRAINING=0) @Tensor.train(mode=False)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL, def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]: inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model # Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path from pathlib import Path
import multiprocessing import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset() if getenv("RESET_STEP", 1): _train_step.reset()
with Context(TRAINING=0): with Tensor.train(mode=False):
if not RUNMLPERF: if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True)) i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else: else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit @TinyJit
@Context(TRAINING=1) @Tensor.train()
def train_step(model, x, y): def train_step(model, x, y):
optim.zero_grad() optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step() optim.step()
return loss.realize() return loss.realize()
@Context(TRAINING=0) @Tensor.train(mode=False)
def eval_step(model, x, y): def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y) y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu return lr_cpu, grad_norm_cpu
@TinyJit @TinyJit
@Context(TRAINING=0) @Tensor.train(False)
def eval_step(tokens:Tensor): def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0) if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device) if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext() else: bench_log_manager = contextlib.nullcontext()
with Context(TRAINING=1): with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}" nm = f"train_{m}"
if nm in globals(): if nm in globals():

View file

@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None, def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None, x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]: grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8: if not fp8:
if ASM_GEMM: if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@ -47,14 +47,12 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)" assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d) x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T): if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale), out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*l_shape, w.shape[0]) mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
else: else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1]) x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None: if x_fp8 is None:
@ -128,8 +126,10 @@ class FlatTransformer:
# FeedForward # FeedForward
if SPLIT_W13: if SPLIT_W13:
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim) if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim) else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
else: else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2) self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std) self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@ -160,7 +160,7 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None): def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None: if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features) if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std) else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8 from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features)) w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
@ -216,15 +216,8 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"]) x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, x_w3]) saves.extend([*s, x_w3])
if FUSED_SILU_W13 and MXFP8: out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8 grad_amax_state=kwargs["grad_amax_xout"])
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, out]) saves.extend([*s, out])
else: else:
@ -254,30 +247,20 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None) for v in get_parameters(self): v.shard_(device, axis=None)
else: else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int, std:float=0.02): def _shard_fp8(name:str, axis:int):
w = getattr(self, name) getattr(self, name).shard_(device, axis=axis)
if MXFP8: scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
from extra.gemm.cdna_asm_gemm import quantize_mxfp8 self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
w_q, w_e8, _ = quantize_mxfp8(w_bf16) Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in _shard_fp8("wo", 2) # (n_layers, dim, in) shard in
if SPLIT_W13: if SPLIT_W13:
_shard_fp8("w1", 1) _shard_fp8("w1", 1)
_shard_fp8("w3", 1) _shard_fp8("w3", 1)
else: else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in _shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize() self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize()

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange, Context from tinygrad.helpers import trange
from tinygrad.nn import optim from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium! optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5) optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop # training loop
with Context(TRAINING=1): with Tensor.train():
for epoch in (t := trange(epochs)): for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0 loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps): for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal # - symbolic removal
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples") print("*** got samples")
with Context(TRAINING=1): with Tensor.train():
""" """
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i) losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float) # accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG) acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like(buffer=False))) acc = acc.after(acc.store(acc.zeros_like()))
if use_wmma: if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE) k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -2675,8 +2675,8 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]: def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis # 1x32 block scaling along the last axis
*batch, K = x.shape *batch, K = x.shape
scale_K = K // 32 scale_K, k_iters = K // 32, K // 128
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1) amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8) e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K) qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale x_scaled = x.float() * qscale

View file

@ -143,17 +143,14 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs): def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op) return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag): def make_cmdbuf(lin, devs, tag):
blob, patches = b'', [] blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src): for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s)) if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0) blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag) buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches]) stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops)) def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -214,11 +211,15 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call)) return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp: def prep_kernargs(call:UOp, prg:UOp) -> UOp:
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device) data, info = prg.arg
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs") patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \ dtypes.uint64) for i,gi in enumerate(info.globals)] \
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)] + [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([ pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering # bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -306,7 +307,7 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
for (_, lane), dep in latest.items(): deps[dep] += (lane,) for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps") if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),))) new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
return linear.replace(src=tuple(new_src)) return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)]) pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -531,9 +532,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack), (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack), (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# shrink on slice is shrink on base at offset # index on slice is index
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"), (UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))), lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
# getaddr # getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x) (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
@ -541,8 +542,8 @@ pm_resolve_patches = PatternMatcher([
# folders # folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store), (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast() (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store), fold_const_store),
]) + symbolic_simple ]) + symbolic_simple
# ***************** # *****************

View file

@ -1,104 +0,0 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2 step = THREADS_PER_WG // 2
while step: while step:
active = tid < step active = tid < step
other = lds[(tid + step).valid(active)].load() other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier()) lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
step //= 2 step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0]) amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -1,6 +1,6 @@
import numpy as np import numpy as np
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange, Context from tinygrad.helpers import trange
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step) if allow_jit: train_step = TinyJit(train_step)
with Context(TRAINING=1): with Tensor.train():
losses, accuracies = [], [] losses, accuracies = [], []
for i in (t := trange(steps, disable=None)): for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS)) samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,3 +55,4 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes) acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc) print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc return (acc, Y_test_pred) if return_predict else acc

View file

@ -25,7 +25,7 @@
import unittest import unittest
import numpy as np import numpy as np
import torch import torch
from tinygrad import Tensor, dtypes, nn, Context from tinygrad import Tensor, dtypes, nn
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.helpers import DEV from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these # we don't need more of these
def test_dropout_rate_one(self): def test_dropout_rate_one(self):
with Context(TRAINING=1): with Tensor.train():
out = Tensor.ones(100).dropout(1.0) out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100)) np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True) torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with Context(TRAINING=1): with Tensor.train():
Tensor.ones(10).dropout(-0.1) Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase): class TestInputValidation(unittest.TestCase):

View file

@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src) renderer=Device[Device.DEFAULT].renderer).src[2].src)
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD]) num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed" assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 1, "expected at least one load uop" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@unittest.skip("this is handled at higher level now") @unittest.skip("this is handled at higher level now")
def test_upcast_cse(self): def test_upcast_cse(self):

View file

@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow @slow
class TestNN(unittest.TestCase): class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True): def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
with Context(TRAINING=training): with Tensor.train(training):
szs = [4, 8, 16, 32] szs = [4, 8, 16, 32]
for sz in szs: for sz in szs:
# create in tinygrad # create in tinygrad

View file

@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1) vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0): with Context(SPLIT_REDUCEOP=0):
with Context(TRAINING=1): with Tensor.train():
losses = [] losses = []
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
vib = vi.bind(i) vib = vi.bind(i)

View file

@ -1,5 +1,5 @@
import unittest import unittest
from tinygrad import Tensor, Variable, GlobalCounters, Context from tinygrad import Tensor, Variable, GlobalCounters
from tinygrad.uop.ops import sym_infer from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from examples.gpt2 import Attention from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True) self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self): def test_attention_training(self):
with Context(TRAINING=1): with Tensor.train():
self.test_attention(dropout_p=0.0) self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# symbolic shape dropout is not supported # symbolic shape dropout is not supported

View file

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import torch import torch
import unittest, copy, mmap, random, math, array import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn, Context from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.helpers import getenv, temp, mv_address from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5) np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self): def test_dropout(self):
with Context(TRAINING=1): with Tensor.train():
n, rate = 1_000_000, 0.1 n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate) w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy()) non_zeros = np.count_nonzero(w.numpy())

View file

@ -1,67 +0,0 @@
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
def __init__(self):
self.x = Tensor(x_init.copy())
self.W = Tensor(W_init.copy())
self.m = Tensor(m_init.copy())
def forward(self):
out = self.x.matmul(self.W).relu()
out = out.log_softmax(1)
out = out.mul(self.m).add(self.m).sum()
return out
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest import unittest, math
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.optimizers import Lamb from tensorflow.keras.optimizers import Lamb
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337) np.random.seed(1337)
@ -173,5 +173,48 @@ class ExternalTestOptim(unittest.TestCase):
'warmup': steps_per_epoch * warmup_epochs, 'warmup': steps_per_epoch * warmup_epochs,
}, 1e-5, 1e-5, do_optim=False) }, 1e-5, 1e-5, do_optim=False)
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -1,6 +1,6 @@
import unittest, os import unittest, os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from tinygrad import Context from tinygrad import Tensor
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion from examples.mlperf.model_train import train_stable_diffusion
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion" if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp: with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp os.environ["UNET_CKPTDIR"] = tmp
with Context(TRAINING=1): with Tensor.train():
saved_ckpts = train_stable_diffusion() saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors" expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
if __name__=="__main__": if __name__=="__main__":
unittest.main() unittest.main()

View file

@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tinygrad import Tensor, Context from tinygrad import Tensor
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from test.helpers import slow from test.helpers import slow
from extra.models.efficientnet import EfficientNet from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img return img
def _infer(model: EfficientNet, img): def _infer(model: EfficientNet, img):
with Context(TRAINING=0): with Tensor.train(False):
out = model.forward(Tensor(img)).argmax(axis=-1) out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist() return out.tolist()

View file

@ -5,11 +5,10 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from extra.datasets import fetch_mnist from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y): def compare_tiny_torch(model, model_torch, X, Y):
with Context(TRAINING=1): with Tensor.train():
model_torch.train() model_torch.train()
model_state_dict = get_state_dict(model) model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters(): for k,v in model_torch.named_parameters():

View file

@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_mnist(self): def test_train_mnist(self):
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
with Context(TRAINING=1): with Tensor.train():
model = Model() model = Model()
optimizer = optim.Adam(get_parameters(model)) optimizer = optim.Adam(get_parameters(model))
BS = 32 BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self): def test_forward_cifar(self):
BS = 32 BS = 32
# with training batchnorm still though # with training batchnorm still though
with Context(TRAINING=1): with Tensor.train():
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit @TinyJit
def run(X): return model(X) def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_train_cifar(self): def test_train_cifar(self):
with Context(TRAINING=1): with Tensor.train():
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32 BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16") @unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self): def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16 dtypes.default_float = dtypes.float16
with Context(TRAINING=1): with Tensor.train():
model = SpeedyResNet(Tensor.ones((12,3,2,2))) model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay']) optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor'] initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_bert(self): def test_bert(self):
with Context(TRAINING=1): with Tensor.train():
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2, args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2} "max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny) model = BertForPretraining(**args_tiny)

View file

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this") #@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self): def test_fold_batchnorm(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(1,32,4,4) img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img) out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn)) check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self): def test_fold_conv_batchnorm_notrain(self):
with Context(TRAINING=0): with Tensor.train(False):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True) bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self): def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Context(TRAINING=0): with Tensor.train(False):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self): def test_fold_conv_batchnorm(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False): def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15) optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Context(TRAINING=1): with Tensor.train():
img = Tensor.ones(1,3,4,4) img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True) def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self): def test_fold_batchnorm_backward(self):
with Context(TRAINING=1): with Tensor.train():
x = Tensor.empty((2, 16, 8, 8)).contiguous() x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16) bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous() fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4) check_schedule(out, 4)
def test_adam_step_fusion(self): def test_adam_step_fusion(self):
with Context(TRAINING=1): with Tensor.train():
x = Tensor.empty(4, 64, 32) x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4) layer = nn.Linear(32, 32*4)
_realize_weights(layer) _realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self): def test_adam_conv_fuse(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self): def test_adam_2convs_fuse(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self): def test_sgd_conv_fuse(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3? check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self): def test_sgd_2convs_fuse(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7) check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self): def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11) check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self): def test_sgd_4convs_fuse(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self): def test_sgd_4convs_fuse_conv_bw(self):
with Context(TRAINING=1): with Tensor.train():
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0) self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self): def test_resnet_block(self):
with Context(TRAINING=0): with Tensor.train(False):
in_planes, planes = 64, 64 in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes) bn1 = nn.BatchNorm2d(planes)

View file

@ -16,17 +16,41 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp): def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, ( return UOp(Ops.LOAD, dtypes.float, (
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True), UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
)) ))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), ( return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True), UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n) def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()): def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0] load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -482,16 +506,6 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))", self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0") "(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase): class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self): def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid # test that INDEX with a constant True valid gets simplified to drop the valid

View file

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed) # tensor tests that pass on NULL backend (no copyout needed)
import numpy as np import numpy as np
import unittest import unittest
from tinygrad import Tensor, Device, dtypes, Context from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase): class TestTrainMode(unittest.TestCase):
def test_train_mode(self): def test_train_mode(self):
assert not Tensor.training assert not Tensor.training
@Context(TRAINING=1) @Tensor.train()
def f(): def f():
assert Tensor.training assert Tensor.training
f() f()

View file

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy() out.numpy()
def test_backprop_conv(self): def test_backprop_conv(self):
with Context(TRAINING=1): with Tensor.train():
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2) for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv)) optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self): def test_dropout_on_shard(self):
with Context(TRAINING=1): with Tensor.train():
X = Tensor.ones(256).to(devices_2) X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0] assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self): def test_dropout_on_shard_axis(self):
with Context(TRAINING=1): with Tensor.train():
X = Tensor.ones(512).shard(devices_2, axis=0) X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass def setUp(self): pass
def test_unsynced_backprop_conv_bn(self): def test_unsynced_backprop_conv_bn(self):
with Context(TRAINING=1): with Tensor.train():
from extra.lr_scheduler import OneCycleLR from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)] convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni) bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:]) return bn_ts[0].cat(*bn_ts[1:])
with Context(TRAINING=1): with Tensor.train():
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16) bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2) GPUS = (d1, d2)
with Context(TRAINING=1): with Tensor.train():
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS)) bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0) x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Context(TRAINING=is_training): with Tensor.train(is_training):
bns = [] bns = []
for _ in range(len(devices)): for _ in range(len(devices)):
bn = nn.BatchNorm2d(8) bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0) x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Context(TRAINING=1): with Tensor.train():
synced_bn = BatchNorm2d(8) synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices)) unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

View file

@ -1,19 +1,21 @@
from typing import cast from typing import cast
from dataclasses import replace from dataclasses import replace
import itertools import itertools
import functools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import AxisType, _align_left, _broadcast_shape, identity_element
from tinygrad.uop.render import pyrender from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
@ -23,7 +25,6 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.late.coalese import memory_coalesing
pm_index_is_shrink = PatternMatcher([ pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK # rewrite non-image INDEX to SHRINK
@ -44,6 +45,98 @@ pm_remove_vec_dtypes = PatternMatcher([
lambda x: x.replace(dtype=x.dtype.base.scalar().base)), lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink ])+pm_clean_up_group_sink
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
# BITCAST?
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
ctx[x.arg[0]] = len(ctx)
def fix_reduce(ctx, r:UOp):
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
expander2 = PatternMatcher([
(UPat(Ops.SINK, name="sink"), build_range_map),
(UPat(Ops.REDUCE, name="r"), fix_reduce),
(UPat(Ops.RANGE, name="r"),
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
def broadcast_binary(x:UOp):
shapes = [u.shape for u in x.src]
if all_same(shapes): return None
shaped_aligned = _align_left(*shapes)
broadcasted = _broadcast_shape(*shapes)
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
return x.replace(src=tuple(src_reshaped))
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
if b.shape == (): return None
# broadcasting needs to be already unpacked
if not all_same([x.shape for x in b.src]): return None
src = []
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape)
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
# INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp): def do_number_param(ctx:list[int], x:UOp):
if x.arg.slot != -1: return None if x.arg.slot != -1: return None
ctx[0] += 1 ctx[0] += 1
@ -53,14 +146,94 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param), (UPat(Ops.PARAM, name="x"), do_number_param),
]) ])
pm_no_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor) if SPEC: type_verify(ast, spec_tensor)
sink = ast
# preprocess. we need to simplify these
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
# first we optimize
if optimize:
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# do expander
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals (STAGE -> BUFFER)
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
# rewrite reduce after optimizations
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
# add loads
sink = graph_rewrite(sink, pm_move_regs, name="move to registers", walk=True)
# symbolic (note: this does POW decomp)
sink = graph_rewrite(sink, sym, name="post index symbolic")
# ***** make it rendererable (within spec, tighten) *****
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="*** decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions more")
# split ends
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# ***** this is where it gets large *****
# unbroadcast
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
# devectorizer
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
# ***** make it rendererable (outside spec, transform) *****
# final symbolic
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# put registers in slots
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
name_to_slot = {x:x.replace(arg=replace(x.arg, slot=num_params+i))
for i,x in enumerate(sorted([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot == -1]))}
sink = sink.substitute(name_to_slot, name="put variables in slots")
# remove all weakints
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
if SPEC: type_verify(sink, spec_program)
# return the rewritten sink
return sink
def old_full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor)
# preprocess # preprocess
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True) sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
@ -83,7 +256,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = apply_opts(sink, ren, beam=ast.arg.beam) sink = apply_opts(sink, ren, beam=ast.arg.beam)
# ** expander (expand_rewrite) ** # ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic") sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
# expand # expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
@ -118,23 +291,18 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# optional pre matcher # optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher") if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# floordiv+mod / dtype decomp (early) # decompositions
supported_ops = tuple(ren.code_for_op.keys()) supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops) pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
sink = graph_rewrite(sink, pm_decomp, name="early decompositions") pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes") sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# do memory coalesing (late) # GEP/STACK stuff
sink = memory_coalesing(sink, ren)
# instruction selection decompositions
pm_decomp = pm_decomp+\
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
# this is new style (TODO: this should all be removed)
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack") sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style") sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
@ -143,7 +311,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# final rules for the renderer (without sym) # final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
# this was the linearizer # this was the linearizer

View file

@ -1,73 +0,0 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
# TODO: this should handle images too, it's just memory coalesing
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
assert u.src[0].op is Ops.INDEX
buf, idx_u = u.src[0].src
if buf.addrspace == AddrSpace.REG: continue
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
# allowed lengths (copied in)
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# do the grouping
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]: def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False # can drop valid if idx is out of bound when valid is False
drop_stmt = [] drop_stmt = []
for i,stmt in enumerate(valid.split_uop(Ops.AND)): for stmt in valid.split_uop(Ops.AND):
if (res:=parse_valid(stmt)) is None: continue if (res:=parse_valid(stmt)) is None: continue
X, is_upper_bound, c = res X, is_upper_bound, c = res
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
drop_stmt.append(stmt) drop_stmt.append(stmt)
continue continue
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1] # if X <= c, check if it's out of bound when X = c+1
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1) # if X >= c, check if it's out of bound when X = c-1
if lo <= hi: test_value = c + 1 if is_upper_bound else c - 1
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype) for i,b in zip(idx.src, (width, height)):
for coord,b in zip(idx.src, (width, height)): if i.is_increasing():
rw = coord.substitute({X:fake}).simplify() rw = i.substitute({X:X.const_like(test_value)})
if rw.vmin >= b or rw.vmax < 0: if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt) drop_stmt.append(stmt)
break break
@ -162,8 +162,18 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# determine fold lengths # determine fold lengths
lengths = [] lengths = []
must_divide = True must_divide = True
# TODO: this belongs in coalese if ctx is not None and ctx.target.device == "DSP":
if isinstance(buf.dtype, ImageDType): lengths = [4] lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide # filter fold lengths that don't divide

View file

@ -101,12 +101,6 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates # for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = k.rngs[axis] in where_gate_rngs is_masked = k.rngs[axis] in where_gate_rngs
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
# upcasting a masked global axis moves that range out of the launch grid into each work-item
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
if DEBUG >= 4: print(f"upcasting masked axis : {axis}") if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis) to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0)) for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))

View file

@ -240,7 +240,6 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0) IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32) JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0) WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
TRAINING = ContextVar("TRAINING", 0)
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0) USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0) TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
@ -275,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward # allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0) USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast # don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1) DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True) @dataclass(frozen=True)
class Metadata: class Metadata:

View file

@ -6,7 +6,7 @@ from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up from tinygrad.helpers import resolve_pool_pads, round_up
@ -17,6 +17,9 @@ ReductionStr = Literal["mean", "sum", "none"]
class OpMixin(ElementwiseMixin, ReduceMixin): class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory") def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
def item(self) -> PyConst: def item(self) -> PyConst:
@ -31,6 +34,48 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
assert self.numel() == 1, "must have one element for item" assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)] return self.data()[(0,) * len(self.shape)]
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def __getitem__(self, indices) -> Self: def __getitem__(self, indices) -> Self:
""" """
Retrieves a sub-tensor using indexing. Retrieves a sub-tensor using indexing.
@ -162,6 +207,40 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self) return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
@classmethod @classmethod
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self: def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
""" """

View file

@ -1,24 +1,10 @@
from typing import TYPE_CHECKING, Callable, Self from typing import Self
from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype from tinygrad.dtype import ConstType, DType, DTypeLike
from tinygrad.helpers import argfix
from tinygrad.mixin.dtype import DTypeMixin from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
class CreationMixin(DTypeMixin, MovementMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
class CreationMixin(DTypeMixin):
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b)) def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self: def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
""" """
Creates an empty tensor with the same shape as `self`. Creates an empty tensor with the same shape as `self`.
@ -26,76 +12,9 @@ class CreationMixin(DTypeMixin, MovementMixin):
""" """
return self._wrap_uop(self._uop.empty_like(dtype, device)) return self._wrap_uop(self._uop.empty_like(dtype, device))
@classmethod def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self: """Creates a tensor with the same shape as `self`, filled with the given value."""
""" return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
def zeros_like(self, **kwargs) -> Self: def zeros_like(self, **kwargs) -> Self:
""" """
@ -110,23 +29,6 @@ class CreationMixin(DTypeMixin, MovementMixin):
""" """
return self.full_like(0, **kwargs) return self.full_like(0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
def ones_like(self, **kwargs) -> Self: def ones_like(self, **kwargs) -> Self:
""" """
Creates a tensor with the same shape as `self`, filled with ones. Creates a tensor with the same shape as `self`, filled with ones.

View file

@ -45,11 +45,7 @@ class ElementwiseMixin(CreationMixin):
""" """
return self.cast(dtypes.bool).ne(True) return self.cast(dtypes.bool).ne(True)
def contiguous(self, **kwargs) -> Self: def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError
"""
Returns a contiguous tensor.
"""
return self._wrap_uop(self._uop.contiguous(**kwargs))
def contiguous_backward(self) -> Self: def contiguous_backward(self) -> Self:
""" """

View file

@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
import math import math
from typing import Self, cast from typing import Self, cast
from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING from tinygrad.helpers import all_int, argfix, ceildiv, prod
from tinygrad.mixin import OpMixin from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device from tinygrad.device import canonicalize_device
@ -273,54 +273,3 @@ class RandMixin(OpMixin):
# Efraimidis-Spirakis # Efraimidis-Spirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1] indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
def dropout(self, p=0.5) -> Self:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Context(TRAINING=1):
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not TRAINING or p == 0: return self
if p == 1: return self.const_like(0)
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Self:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value

View file

@ -22,7 +22,6 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None), if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
# GPU stuff # GPU stuff

View file

@ -18,7 +18,7 @@ import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def add_ranges_to_store(ctx, x): def add_ranges_to_store(ctx, x):
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None
assert x.src[0].shape == x.src[1].shape, "bad store shape" assert x.src[0].shape == x.src[1].shape, "bad store shape"
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape] idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs) return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)

View file

@ -1,13 +1,14 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin.rand import RandMixin from tinygrad.mixin.rand import RandMixin
from tinygrad.schedule import create_linear_with_vars from tinygrad.schedule import create_linear_with_vars
@ -58,25 +59,19 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor" assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret return ret
# TODO: deprecate this, always use TRAINING class Tensor(RandMixin):
class TensorMeta(type):
@property
def training(cls) -> bool: return bool(TRAINING.value)
@training.setter
def training(cls, mode:bool): TRAINING.value = int(mode)
class Tensor(RandMixin, metaclass=TensorMeta):
""" """
A `Tensor` is a multi-dimensional matrix containing elements of a single data type. A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor" ```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn, Context from tinygrad import Tensor, dtypes, nn
import numpy as np import numpy as np
import math import math
np.set_printoptions(precision=4) np.set_printoptions(precision=4)
``` ```
""" """
__slots__ = "uop", "is_param", "grad" __slots__ = "uop", "is_param", "grad"
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None, def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None): device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
@ -130,9 +125,9 @@ class Tensor(RandMixin, metaclass=TensorMeta):
@suppress_finalizing @suppress_finalizing
def __del__(self): all_tensors.pop(weakref.ref(self), None) def __del__(self): all_tensors.pop(weakref.ref(self), None)
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor: def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor:
srcs = (self,)+x srcs = (self,)+x
new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs) new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs)
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,) if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
# directly create the Tensor # directly create the Tensor
ret = Tensor.__new__(Tensor) ret = Tensor.__new__(Tensor)
@ -153,6 +148,11 @@ class Tensor(RandMixin, metaclass=TensorMeta):
self.is_param = is_param self.is_param = is_param
return self return self
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
def __repr__(self): def __repr__(self):
ld = self.uop ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>" ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -264,7 +264,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
x = self.cast(self.dtype.base).contiguous() x = self.cast(self.dtype.base).contiguous()
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU") if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
return cast(Buffer, x.realize().uop.buffer).ensure_allocated() return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview() def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview: def data(self) -> memoryview:
@ -280,7 +279,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}" assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12) assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._data().cast(self.dtype.base.fmt, self.shape) return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions) # NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]: def tolist(self) -> PyConst|list[Any]:
@ -504,6 +503,25 @@ class Tensor(RandMixin, metaclass=TensorMeta):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff)) high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high) return Tensor._device_seeds[device], low.cat(high)
# ***** creation helper functions *****
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
# ***** toposort and backward pass ***** # ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor: def backward(self, gradient:Tensor|None=None) -> Tensor:
@ -530,7 +548,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
# ***** movement ops ***** # ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg) def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis) def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
@ -719,6 +737,14 @@ class Tensor(RandMixin, metaclass=TensorMeta):
if IMAGE: return self.image_dot(w, dtype) if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype) return super().dot(w, dtype)
# ***** unary ops *****
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
# ***** broadcasted elementwise ops ***** # ***** broadcasted elementwise ops *****
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
@ -778,6 +804,59 @@ class Tensor(RandMixin, metaclass=TensorMeta):
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec") fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos))) return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
# ***** functional nn ops *****
def dropout(self, p=0.5) -> Tensor:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `Tensor.training` is `True`.
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.const_like(0)
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
# ***** cast ops ***** # ***** cast ops *****
def bitcast(self, dtype:DTypeLike) -> Tensor: def bitcast(self, dtype:DTypeLike) -> Tensor:

View file

@ -454,8 +454,7 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
powers_of_two: dict[int, int] = {2**i:i for i in range(64)} powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache @functools.cache
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher: def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
# these are rewrites that make things simpler
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)] pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod # FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)) if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
@ -464,11 +463,6 @@ def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
return PatternMatcher(pat)
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())] lambda x,y: (x | y).logical_not())]
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) # rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)

View file

@ -75,7 +75,9 @@ def fold_divmod_general(d: UOp) -> UOp|None:
# divide_by_gcd: x//y -> (x//gcd)//(y//gcd) # divide_by_gcd: x//y -> (x//gcd)//(y//gcd)
gcd = UOp.gcd(*all_uops, y).simplify() gcd = UOp.gcd(*all_uops, y).simplify()
if not (gcd.op is Ops.CONST and gcd.arg==1): if not (gcd.op is Ops.CONST and gcd.arg==1):
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) x_div, y_div = x.divide_exact(gcd), y.divide_exact(gcd)
if x_div is None or y_div is None: return None
ret = x_div.alu(d.op, y_div)
return ret*gcd if d.op is Ops.FLOORMOD else ret return ret*gcd if d.op is Ops.FLOORMOD else ret
# factor_remainder: (d*x+y)//d -> x+y//d # factor_remainder: (d*x+y)//d -> x+y//d

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp: def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK) if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg) elif all_int(arg): return UOp.const(dtypes.weakint, arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg)) else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]): def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {} ret: dict[UOp, dict[UOp, None]] = {}
@ -306,9 +306,10 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END: Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form) # REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0: # no longer true
#case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce # these can mismatch if there's a horizonal reduce
return (self.dtype.count,) if self.dtype.count > 1 else () #return (self.dtype.count,) if self.dtype.count > 1 else ()
# TODO: disallow shape changing bitcast # TODO: disallow shape changing bitcast
case Ops.BITCAST: case Ops.BITCAST:
@ -473,12 +474,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs): def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs) return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs): def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): def __getitem__(self, idx):
# pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path # pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path
if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx) #if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
idx = self._normalize_indices(list(argfix(idx))) idx = self._normalize_indices(list(argfix(idx)))
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]): if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
# apply SHRINK for slices that aren't the full range # apply SHRINK for slices that aren't the full range
@ -919,6 +920,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff *** # *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int: def const_factor(self) -> int:
"""largest known int that divides self""" """largest known int that divides self"""
# TODO: for negatives it's not the largest # TODO: for negatives it's not the largest
@ -1184,9 +1191,8 @@ python_alu: dict[Ops, Callable] = {
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq} Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if any(isinstance(x, tuple) for x in operands): if dtype.count > 1:
count = max(len(x) for x in operands if isinstance(x, tuple)) return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count)])
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands) alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
@ -1667,8 +1673,6 @@ pm_lower_index_dtype = PatternMatcher([
lambda var,val: var.bind(val).cast(dtypes.weakint)), lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts # remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("slen", dtypes.ints).cast(),), name="shrink"),
lambda shrink,buf,idx,slen: shrink.replace(src=(buf,idx,slen))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)), lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for images # remove hanging casts for images
@ -1679,7 +1683,7 @@ pm_lower_index_dtype = PatternMatcher([
UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)), lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)),
gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)), gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"), (UPat((Ops.SINK, Ops.NOOP, Ops.END, Ops.AFTER, Ops.BUFFER), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))), lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
]) ])
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View file

@ -29,7 +29,14 @@ def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
def fold_const_alu(a:UOp) -> UOp|None: def fold_const_alu(a:UOp) -> UOp|None:
vals = [const_arg(s) for s in a.src] vals = [const_arg(s) for s in a.src]
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False)) if any(v is None for v in vals): return None
if any(isinstance(v, tuple) for v in vals):
out_len = prod(a.shape)
if not all(not isinstance(v, tuple) or len(v) in {1, out_len} for v in vals): return None
return a.const_like(tuple(exec_alu(a.op, a.dtype.scalar(),
[v[0] if isinstance(v, tuple) and len(v) == 1 else v[i] if isinstance(v, tuple) else v for v in vals], False)
for i in range(out_len)))
return a.const_like(exec_alu(a.op, a.dtype, vals, False))
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i") invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat) invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
@ -121,8 +128,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# TODO: combine this with "# rules for threefry" below # TODO: combine this with "# rules for threefry" below
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"), ((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None), lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding ** # ** constant folding **
@ -162,7 +167,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x), (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x), (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
# ** simple where folding ** # ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals # a conditional with the same results either way is a noop, also fold const conditionals
@ -170,9 +174,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), (UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# a.where(b.where(c, d), d) -> (a & b).where(c, d) # a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# STACK on INDEX CONST (TODO: remove all the GEP crap)
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
]) ])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********