Compare commits

29 commits

Author SHA1 Message Date
chenyu
687ade119e
IMAGE hand_coded_optimizations update (#16720) 2026-06-23 21:55:28 -04:00
George Hotz
0a8e61d0c5
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese

* move that stuff

* copy in allowed length logic

* mulitple buffers

* new coalese is better

* fine

* earlier

* fixes

* work

* work

* valid

* stack on index const
2026-06-23 18:03:48 -07:00
wozeparrot
dfea9e7994
llama: fused silu mul quantize mxfp8 (#16704) 2026-06-23 16:59:50 -07:00
chenyu
ce87d80911
better _drop_valid_stmts [pr] (#16719)
also dropped the unused is_increasing
2026-06-23 19:35:01 -04:00
George Hotz
5a2b3b7b06
early dtype decomp (#16718)
* early dtype decomp

* simplify

* cleanup

* that goes there

* doing too much

* stupid symbolic rules
2026-06-23 16:07:20 -07:00
Christopher Milan
116045cc8e
ci: remove tensorflow from testoptim (#16717) 2026-06-23 18:11:48 -04:00
nimlgen
7c1d0b6d9a
hcq2: use shrink(bitcast) (#16713)
* hcq2: use shrink(bitcast)

* x
2026-06-23 18:11:39 +03:00
George Hotz
c9dc1d63cc
small changes from new codegen (#16712)
* small changes from new codegen

* shrink/flatten
2026-06-22 17:44:15 -07:00
Christopher Milan
da98fae9e1
ci: try parallelizing tc tests (#16710) 2026-06-22 20:43:32 -04:00
chenyu
15988b5941
contiguous to mixin and cleanups [PR] (#16711) 2026-06-22 20:18:18 -04:00
Christopher Milan
cbfcf36e44
ci: remove generate_dataset and CL misc (#16709) 2026-06-22 18:01:07 -04:00
nimlgen
f9c8c697d6
hcq2: drop args after inner deps (#16708) 2026-06-22 23:26:11 +03:00
chenyu
0138480910
dropout and scaled_dot_product_attention to mixin (#16707) 2026-06-22 16:17:45 -04:00
chenyu
33b635d23a
Tensor.train -> TRAINING [PR] (#16705)
* Tensor.train -> TRAINING [PR]

* doc
2026-06-22 15:13:22 -04:00
chenyu
625d8bbd0d
TRAINING ContextVar (#16703) 2026-06-22 13:03:08 -04:00
wozeparrot
fe9b19b12d
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes

* clean: unused

* fix: batch
2026-06-22 10:54:35 -04:00
chenyu
267af9c601
full_like to CreationMixin [PR] (#16702) 2026-06-22 09:33:23 -04:00
chenyu
97da54b9d6
more method to CreationMixin [PR] (#16698) 2026-06-22 00:01:22 -04:00
chenyu
fd0dc40689
clean up CreationMixin and DTypeMixin [PR] (#16697) 2026-06-21 21:13:40 -04:00
chenyu
2d8b802958
contiguous in wino conv (#16696)
also fixed test_counters
2026-06-21 17:11:46 -04:00
chenyu
ba1d3baae8
masked_select and nonzero to mixin [PR] (#16695)
with a .data stub
2026-06-21 15:10:44 -04:00
chenyu
d80a41d559
some rand method to RandMixin [PR] (#16693) 2026-06-21 12:16:51 -04:00
wozeparrot
5164c21b44
gemm: keep shape thru mxfp8 quantize (#16692) 2026-06-20 22:28:53 -07:00
chenyu
58ff75272e
const_like and invalids to mixin [PR] (#16690)
* const_like and invalids to mixin [PR]

* empty_like

* einsum

* type
2026-06-21 00:02:29 -04:00
chenyu
b50da5c205
move Tensor.__getitem__ to mixin [PR] (#16689) 2026-06-20 22:01:45 -04:00
chenyu
4618d27129
final const cleanups [PR] (#16688) 2026-06-20 21:38:16 -04:00
chenyu
9ae0a93d0e
more const cleanups [PR] (#16682) 2026-06-20 20:41:43 -04:00
George Hotz
30830850a9
small changes from new codegen (#16681)
* small changes from new codegen

* revert that
2026-06-19 18:29:01 -07:00
chenyu
8b07cca9f7
invalid clone try 3+ [PR] (#16679) 2026-06-19 20:13:52 -04:00
65 changed files with 1160 additions and 975 deletions

View file

@@ -133,46 +133,26 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20 run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support - name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores - name: Test emulated tensor cores
env: env:
DEV: 'PYTHON::METAL' DEBUG: 2
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: | run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
python3 -m pytest -nauto test/opt/test_tensor_cores.py ::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
- name: Test emulated AMD tensor cores ::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
env: 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
DEV: 'PYTHON::gfx1100' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: | run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter: linter:
@@ -267,13 +247,6 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines - name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py run: MAX_LINE_COUNT=25000 python sz.py
@@ -338,31 +311,6 @@ jobs:
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot: testopenpilot:
name: openpilot Compile Tests name: openpilot Compile Tests
runs-on: *linux runs-on: *linux
@@ -379,7 +327,7 @@ jobs:
llvm: 'true' llvm: 'true'
- name: Test openpilot model kernel count and gate usage - name: Test openpilot model kernel count and gate usage
run: | run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness) - name: Test openpilot CL compile fp32 (test correctness)
run: | run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@@ -422,7 +370,6 @@ jobs:
with: with:
key: optim key: optim
deps: testing deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true' opencl: 'true'
#- name: Test Optimization Helpers #- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py # run: DEBUG=1 python3 extra/optimization/test_helpers.py
@@ -431,7 +378,7 @@ jobs:
- name: Test Beam Search - name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff - name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20 run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu - name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training - name: Test Bert training
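Note on the consolidated "Test emulated tensor cores" step in the first hunk of this workflow file: with GNU `parallel`, `--link` pairs the n-th item of each `:::` input source, `--tagstring '[{1}]'` prefixes each output line with the first item, and `-k` keeps output in input order. A minimal Python sketch of the resulting pairing follows. The helper name `linked_commands` is hypothetical, and the sketch only builds the command strings; it does not run anything and is not how the CI itself is implemented.

```python
# Tags and environment strings copied from the `parallel` invocation in the hunk above.
TAGS = ("metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf "
        "sm_75 sm_80_half sm_80_tf32").split()
ENVS = [
    "DEV=PYTHON::METAL",
    "DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0",
    "DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0",
    "DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3",
    "DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0",
    "DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3",
    "DEV=PYTHON::sm_75 HALF=1",
    "DEV=PYTHON::sm_80 HALF=1",
    "DEV=PYTHON::sm_80 ALLOW_TF32=1",
]
TEMPLATE = "{2} python3 ./extra/gemm/simple_matmul.py"

def linked_commands(tags, envs, template=TEMPLATE):
    """Pair the n-th tag with the n-th env string, as `--link` does for two sources.

    Simplification: this sketch insists on equal-length sources.
    """
    if len(tags) != len(envs):
        raise ValueError("this sketch expects one env string per tag")
    # {1} is the tag (used only for the output prefix), {2} is the env string.
    return [(f"[{tag}]", template.replace("{2}", env)) for tag, env in zip(tags, envs)]

if __name__ == "__main__":
    for prefix, cmd in linked_commands(TAGS, ENVS):
        print(prefix, cmd)
```

Each of the nine tags maps to exactly one environment configuration, so the single step covers the configurations that the removed per-backend steps ran one after another.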

View file

@@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need. Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python ```python
from tinygrad import Tensor, nn from tinygrad import Tensor, nn, Context
class LinearNet: class LinearNet:
def __init__(self): def __init__(self):
@@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train(): with Context(TRAINING=1):
for i in range(10): for i in range(10):
optim.zero_grad() optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward() loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network. Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64. We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training. We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager. Upon exit, the flag is restored to its previous value by the context manager.
```python ```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist() X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train(): with Context(TRAINING=1):
for step in range(1000): for step in range(1000):
# random sample a batch # random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64)) samp = np.random.randint(0, X_train.shape[0], size=(64))
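Note on the quickstart text in this hunk, which says the flag is set inside the `with` block and restored to its previous value on exit: a minimal standard-library sketch of that save-and-restore behaviour. `TRAINING_FLAG` and `training_mode` are hypothetical names used only for illustration; this is an assumption-laden stand-in, not tinygrad's `Context` implementation.

```python
import contextlib
import contextvars

# Hypothetical stand-in for a TRAINING-style flag (not tinygrad's implementation).
TRAINING_FLAG = contextvars.ContextVar("TRAINING_FLAG", default=0)

@contextlib.contextmanager
def training_mode(value: int):
    """Set the flag for the duration of the block, then restore the previous value."""
    token = TRAINING_FLAG.set(value)
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions, so the old value always comes back.
        TRAINING_FLAG.reset(token)

def demo() -> list:
    seen = [TRAINING_FLAG.get()]          # value before entering
    with training_mode(1):
        seen.append(TRAINING_FLAG.get())  # value inside the training block
        with training_mode(0):            # nested override, e.g. for an eval step
            seen.append(TRAINING_FLAG.get())
        seen.append(TRAINING_FLAG.get())  # inner block restored the outer value
    seen.append(TRAINING_FLAG.get())      # outer block restored the default
    return seen
```

Because the previous value is restored rather than reset to a fixed default, nested blocks compose: an evaluation block inside a training loop hands control back with the training value intact.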

View file

@@ -1,6 +1,6 @@
from typing import Tuple from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym import gymnasium as gym
from tinygrad.helpers import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
@@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train(): with Context(TRAINING=1):
log_dist, value = model(x) log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float() action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor: def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs] X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1: if len(GPUS) > 1:
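Note on the decorator form used in this hunk (`@Context(TRAINING=1)` in place of `@Tensor.train()`): a context manager that is also usable as a decorator enters on every call of the wrapped function and exits when the call returns. The sketch below shows that pattern with `contextlib.ContextDecorator`. The `Flag` class and its `state` dictionary are hypothetical and only illustrate the mechanism; they are not tinygrad's `Context`.

```python
import contextlib

class Flag(contextlib.ContextDecorator):
    """Hypothetical flag setter usable as `with Flag(...)` or as `@Flag(...)`."""
    state = {"TRAINING": 0}  # illustrative shared state, not tinygrad's

    def __init__(self, **overrides: int):
        self.overrides = overrides

    def __enter__(self):
        # Remember the values being overridden so __exit__ can put them back.
        self.saved = {k: Flag.state[k] for k in self.overrides}
        Flag.state.update(self.overrides)
        return self

    def __exit__(self, *exc):
        Flag.state.update(self.saved)
        return False  # never swallow exceptions

@Flag(TRAINING=1)
def train_step() -> int:
    # The flag is active for the duration of each call.
    return Flag.state["TRAINING"]

@Flag(TRAINING=0)
def eval_step() -> int:
    return Flag.state["TRAINING"]
```

With this pattern, a function decorated for evaluation can be called from inside a training block and the training value is back in place once it returns.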

View file

@@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor: def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step() -> Tensor: def train_step() -> Tensor:
with Tensor.train(): with Context(TRAINING=1):
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0 Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@@ -1,6 +1,6 @@
import itertools import itertools
from typing import Callable from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition from tinygrad.helpers import getenv, trange, partition
class Model: class Model:
@@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads) Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def microbatch(): def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0]) samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None for t in params: t.grad = None

View file

@@ -359,7 +359,7 @@ def train_cifar():
i = 0 i = 0
eval_acc_pct = 0.0 eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train(): with Context(TRAINING=1):
st = time.monotonic() st = time.monotonic()
while i <= STEPS: while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"): if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os, math, time import os, math, time
import numpy as np import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
@@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0) if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor: def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y) _, loss = model(x, y)
optimizer.zero_grad() optimizer.zero_grad()
@@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40 top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist())) print(decode(y[0].tolist()))

View file

@@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF # much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4) optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(): def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int') if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0]) else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape)) batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs return batch, unpadded_bs
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL, def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]: inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model # Eval is divided into 5 jits, one per model

View file

@@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path from pathlib import Path
import multiprocessing import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset() if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False): with Context(TRAINING=0):
if not RUNMLPERF: if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True)) i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else: else:
@@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(model, x, y): def train_step(model, x, y):
optim.zero_grad() optim.zero_grad()
@@ -795,7 +795,7 @@ def train_unet3d():
optim.step() optim.step()
return loss.realize() return loss.realize()
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_step(model, x, y): def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y) y_hat, y = Tensor(y_hat), Tensor(y)
@@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu return lr_cpu, grad_norm_cpu
@TinyJit @TinyJit
@Tensor.train(False) @Context(TRAINING=0)
def eval_step(tokens:Tensor): def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0) if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device) if is_mp: tokens = tokens.shard(device)
@@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext() else: bench_log_manager = contextlib.nullcontext()
with Tensor.train(): with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}" nm = f"train_{m}"
if nm in globals(): if nm in globals():

View file

@@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None, def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None, x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]: grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
if not fp8: if not fp8:
if ASM_GEMM: if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@@ -47,12 +47,14 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)" assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1])) if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T): if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale), out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0]) mx_w_stored=True).reshape(*l_shape, w.shape[0])
else: else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1]) x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None: if x_fp8 is None:
@@ -126,10 +128,8 @@ class FlatTransformer:
# FeedForward # FeedForward
if SPLIT_W13: if SPLIT_W13:
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim) self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02) self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
else: else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2) self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std) self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@@ -160,7 +160,7 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None): def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None: if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features) if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize() else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8 from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features)) w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
@@ -216,8 +216,15 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"]) x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, x_w3]) saves.extend([*s, x_w3])
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"], if FUSED_SILU_W13 and MXFP8:
grad_amax_state=kwargs["grad_amax_xout"]) from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, out]) saves.extend([*s, out])
else: else:
@@ -247,20 +254,30 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None) for v in get_parameters(self): v.shard_(device, axis=None)
else: else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int): def _shard_fp8(name:str, axis:int, std:float=0.02):
getattr(self, name).shard_(device, axis=axis) w = getattr(self, name)
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None if MXFP8:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) from extra.gemm.cdna_asm_gemm import quantize_mxfp8
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) w_q, w_e8, _ = quantize_mxfp8(w_bf16)
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in _shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
if SPLIT_W13: if SPLIT_W13:
_shard_fp8("w1", 1) _shard_fp8("w1", 1)
_shard_fp8("w3", 1) _shard_fp8("w3", 1)
else: else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in _shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize() self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize()
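
Note on `sstd`: the value passed for `wo` and `w2` shrinks the init standard deviation as the layer count grows. A small worked example of the expression from this hunk, assuming a hypothetical `n_layers = 32` (the diff does not state the actual depth):

```python
import math

def scaled_std(n_layers: int, base_std: float = 0.02) -> float:
    # Same expression as in the hunk: 0.02 / sqrt(2 * n_layers)
    return base_std / math.sqrt(2 * n_layers)

# n_layers = 32 is an assumed value, chosen only for illustration
sstd = scaled_std(32)
print(sstd)
```

With 32 layers the divisor is `sqrt(64) = 8`, so the output projections start at one eighth of the default 0.02.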


@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
- from tinygrad.helpers import trange
+ from tinygrad.helpers import trange, Context
from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
- with Tensor.train():
+ with Context(TRAINING=1):
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):
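
Note on the `Tensor.train()` to `Context(TRAINING=1)` change: the training flag is now set through the same scoped-override mechanism as other flags. The sketch below is a hypothetical stand-in written with the standard library only; `FLAGS`, `context`, and `dropout_active` are illustrative names and are not tinygrad's implementation.

```python
from contextlib import contextmanager

# Hypothetical global flag table; tinygrad's real Context/ContextVar differ in detail.
FLAGS = {"TRAINING": 0}

@contextmanager
def context(**overrides):
    # Remember the old values, apply the overrides, and restore them on exit
    saved = {k: FLAGS[k] for k in overrides}
    FLAGS.update(overrides)
    try:
        yield
    finally:
        FLAGS.update(saved)

def dropout_active() -> bool:
    # Code that used to consult a Tensor-level training switch reads the flag instead
    return bool(FLAGS["TRAINING"])

with context(TRAINING=1):
    print(dropout_active())
print(dropout_active())
```

Restoring in `finally` means the flag goes back to its previous value even if the training loop raises.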


@ -5,7 +5,7 @@
# - symbolic removal
from examples.beautiful_mnist import Model
- from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
+ from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples")
- with Tensor.train():
+ with Context(TRAINING=1):
"""
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)


@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
- acc = acc.after(acc.store(acc.zeros_like()))
+ acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)


@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis
- rows, K = x.shape
- scale_K, k_iters = K // 32, K // 128
- amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
+ *batch, K = x.shape
+ scale_K = K // 32
+ amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
- qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
+ qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
- return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
+ return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
def mx_pack(e8:Tensor) -> Tensor:
rows, scale_K = e8.shape
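
Note on the block scaling in `quantize_mxfp8`: each run of 32 values shares one biased exponent `e8`, and the values are multiplied by `2^(127 - e8)` before clamping. The following is a plain-Python sketch of that arithmetic only. It works on lists, stops before the final cast to FP8, and ignores the straight-through-estimator trick; `quantize_block` and `quantize_row` are illustrative names.

```python
import math

BLOCK = 32       # elements sharing one scale ("1x32 block scaling" in the diff)
FP8_MAX = 448.0  # clamp bound used in the diff

def quantize_block(block):
    # Shared exponent: floor(log2(amax)) + 127, clamped to [0, 254]
    amax = max(abs(v) for v in block)
    e8 = math.floor(math.log2(max(amax, 1e-38))) + 127
    e8 = min(max(e8, 0), 254)
    qscale = 2.0 ** (127 - e8)
    # Scale, then clamp; the real function casts this to FP8 afterwards
    scaled = [min(max(v * qscale, -FP8_MAX), FP8_MAX) for v in block]
    return scaled, e8

def quantize_row(row):
    assert len(row) % BLOCK == 0, "row length must be a multiple of 32"
    out, exps = [], []
    for i in range(0, len(row), BLOCK):
        scaled, e8 = quantize_block(row[i:i + BLOCK])
        out.extend(scaled)
        exps.append(e8)
    return out, exps

values, exps = quantize_row([3.0] + [0.0] * 31 + [0.0] * 32)
print(values[0], exps)
```

The `1e-38` floor keeps `log2` finite for an all-zero block, which then lands on the minimum exponent 0.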


@ -143,14 +143,17 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
+ def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
+ dt = dtype or val.dtype
+ return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag):
blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
- stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
- return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
+ return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
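
Note on the blob-plus-patches pattern in `make_cmdbuf`: constants are packed up front, non-constant slots are packed as zero and their byte offsets recorded, and each recorded offset is written later. A byte-level analogy using only `struct` is sketched below. It assumes every word is a little-endian `uint32` and uses `None` as a hypothetical marker for a late value; the real code works on UOps, not bytes.

```python
import struct

def build_blob(words):
    # words: ints are constants, None marks a value known only later (illustrative convention)
    blob, patches = b"", []
    for w in words:
        if w is None:
            patches.append(len(blob))  # remember where the placeholder sits
        blob += struct.pack("<I", 0 if w is None else w)
    return bytearray(blob), patches

def apply_patch(buf, off, val):
    # Overwrite one 4-byte window in place, like a store into a shrunk and bitcast view
    struct.pack_into("<I", buf, off, val)

buf, patches = build_blob([1, None, 3])
apply_patch(buf, patches[0], 0xDEADBEEF)
print(struct.unpack("<3I", buf))
```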
@ -211,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
- data, info = prg.arg
- patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
- dtypes.uint64) for i,gi in enumerate(info.globals)] \
- + [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
- buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
- kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
- return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
+ (data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
+ buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
+ patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
+ + [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
+ return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
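
Note on the kernargs layout implied by the offsets in `prep_kernargs`: global buffer addresses sit at `i*8` as 64-bit values, followed by variables at `len(globals)*8 + i*4` as 32-bit values. A byte-level sketch with `struct` is below. It assumes little-endian packing and sizes the buffer tightly, whereas the diff allocates `data.kernargs_alloc_size` bytes; `pack_kernargs` is an illustrative name.

```python
import struct

def pack_kernargs(global_addrs, var_vals):
    # Tight buffer for illustration; the real buffer may be larger
    buf = bytearray(len(global_addrs) * 8 + len(var_vals) * 4)
    for i, addr in enumerate(global_addrs):
        struct.pack_into("<Q", buf, i * 8, addr)  # uint64 address slot
    for i, v in enumerate(var_vals):
        struct.pack_into("<I", buf, len(global_addrs) * 8 + i * 4, v)  # uint32 variable slot
    return bytes(buf)

args = pack_kernargs([0x1000, 0x2000], [7])
print(len(args), args.hex())
```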
pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -307,7 +306,7 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
- new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
+ new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -532,9 +531,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
- # index on slice is index
- (UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
- lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
+ # shrink on slice is shrink on base at offset
+ (UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
+ lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
# getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
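
Note on the "shrink on slice is shrink on base at offset" rule: it relies on offsets composing additively. The same identity can be checked on plain `bytes`; here `shrink` is an illustrative helper taking `(offset, size)`, matching the argument order the rule appears to use.

```python
def shrink(buf, off, size):
    # Take `size` bytes starting at `off`
    return buf[off:off + size]

base = bytes(range(64))
view = base[16:48]                    # a slice of the base at offset 16

via_slice = shrink(view, 4, 8)        # shrink applied to the slice
via_base = shrink(base, 16 + 4, 8)    # shrink applied to the base at the summed offset
print(via_slice == via_base)
```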
@ -542,8 +541,8 @@ pm_resolve_patches = PatternMatcher([
# folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
- (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
- fold_const_store),
+ (UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
+ .store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
]) + symbolic_simple
# *****************


@ -0,0 +1,104 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
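
Note on the scale packing at the end of `_custom_silu_mul_quantize_mxfp8`: four consecutive `e8` bytes are combined into one `uint32` by shifting byte `sb` left by `8*sb`, and the packed word is stored at a transposed index `col4 * rows + row`. The index arithmetic in plain Python, with illustrative function names:

```python
PACK = 4

def pack_e8(e8_bytes):
    # Byte sb lands in bits [8*sb, 8*sb + 8); summing equals OR-ing since the fields do not overlap
    assert len(e8_bytes) == PACK
    return sum((b & 0xFF) << (8 * sb) for sb, b in enumerate(e8_bytes))

def si_index(super_idx, rows, sk4):
    # super_idx walks row-major over (rows, sk4); the store target is column-major
    row, col4 = super_idx // sk4, super_idx % sk4
    return col4 * rows + row

print(hex(pack_e8([0x7F, 0x80, 0x81, 0x82])))
print(si_index(5, rows=4, sk4=2))
```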
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
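
Note on the backward formulas: `sprime = sig * (1 + w1 * (1 - sig))` is the derivative of `silu(w1) = w1 * sig`, so the two gradients are `ga * sprime * w3` and `ga * silu(w1)`. The scalar sketch below checks the derivative against a finite difference. It leaves out the `qscale` factor the kernel applies to the incoming gradient, and Python's `2.0 ** x` overflows for very negative inputs, so it is meant for moderate values only.

```python
import math

LOG2E = 1.4426950408889634

def sigmoid(x):
    # exp(-x) written as exp2(-x * log2(e)), as in the kernel
    return 1.0 / (1.0 + 2.0 ** (-x * LOG2E))

def silu(x):
    return x * sigmoid(x)

def silu_grad(x):
    sig = sigmoid(x)
    return sig * (1.0 + x * (1.0 - sig))

def silu_mul_bwd(ga, w1, w3):
    # Gradients of silu(w1) * w3 with respect to w1 and w3, given upstream gradient ga
    return ga * silu_grad(w1) * w3, ga * silu(w1)

x, h = 0.7, 1e-6
numeric = (silu(x + h) - silu(x - h)) / (2 * h)
print(abs(numeric - silu_grad(x)) < 1e-8)
```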
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out


@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2
while step:
active = tid < step
- other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
- lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
+ other = lds[(tid + step).valid(active)].load()
+ lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
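
Note on the loop above: it is a halving tree reduction over local memory. In each round the threads with `tid < step` fold in the element `step` slots away, then everyone waits at a barrier. A sequential sketch of the same access pattern is below, assuming the workgroup size is a power of two; `tree_max` is an illustrative name.

```python
def tree_max(vals):
    # Assumes len(vals) is a power of two, otherwise halving would skip elements
    lds = list(vals)
    step = len(lds) // 2
    while step:
        # On the GPU all tids below `step` run this at once, followed by a barrier.
        # Running them in order is equivalent: reads hit [step, 2*step), writes hit [0, step).
        for tid in range(step):
            lds[tid] = max(lds[tid], lds[tid + step])
        step //= 2
    return lds[0]

print(tree_max([3, 9, 1, 7, 2, 8, 5, 4]))
```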


@ -1,6 +1,6 @@
import numpy as np
from tinygrad.tensor import Tensor
- from tinygrad.helpers import trange
+ from tinygrad.helpers import trange, Context
from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step)
- with Tensor.train():
+ with Context(TRAINING=1):
losses, accuracies = [], []
for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,4 +55,3 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc


@ -25,7 +25,7 @@
import unittest
import numpy as np
import torch
- from tinygrad import Tensor, dtypes, nn
+ from tinygrad import Tensor, dtypes, nn, Context
from tinygrad.device import Device
from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these
def test_dropout_rate_one(self):
- with Tensor.train():
+ with Context(TRAINING=1):
out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError):
- with Tensor.train():
+ with Context(TRAINING=1):
Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase):


@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src)
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed"
- assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
+ assert num_loads >= 1, "expected at least one load uop"
@unittest.skip("this is handled at higher level now")
def test_upcast_cse(self):


@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow
class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
- with Tensor.train(training):
+ with Context(TRAINING=training):
szs = [4, 8, 16, 32]
for sz in szs:
# create in tinygrad


@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0):
- with Tensor.train():
+ with Context(TRAINING=1):
losses = []
for i in range(samples.shape[0]):
vib = vi.bind(i)


@ -1,5 +1,5 @@
import unittest
- from tinygrad import Tensor, Variable, GlobalCounters
+ from tinygrad import Tensor, Variable, GlobalCounters, Context
from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes
from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self):
- with Tensor.train():
+ with Context(TRAINING=1):
self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError):
# symbolic shape dropout is not supported


@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest, copy, mmap, random, math, array
- from tinygrad import Tensor, Device, dtypes, nn
+ from tinygrad import Tensor, Device, dtypes, nn, Context
from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self):
- with Tensor.train():
+ with Context(TRAINING=1):
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy())


@ -0,0 +1,67 @@
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
def __init__(self):
self.x = Tensor(x_init.copy())
self.W = Tensor(W_init.copy())
self.m = Tensor(m_init.copy())
def forward(self):
out = self.x.matmul(self.W).relu()
out = out.log_softmax(1)
out = out.mul(self.m).add(self.m).sum()
return out
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
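
Note on the reference schedule in `_test_lr`: it is a linear warmup to `base_lr` followed by a half-cosine decay to `end_lr`. The same two list comprehensions as a standalone function, which makes the endpoints easy to inspect without an optimizer (`expected_lrs` is an illustrative name):

```python
import math

def expected_lrs(base_lr, end_lr, warmup_steps, decay_steps):
    # Linear warmup: step i (0-based) gets (i+1)/warmup_steps of base_lr
    lrs = [(i + 1) / warmup_steps * base_lr for i in range(warmup_steps)]
    # Cosine decay: starts just below base_lr and reaches end_lr on the last step
    lrs += [end_lr + (base_lr - end_lr) * (1 + math.cos((i + 1) / decay_steps * math.pi)) / 2
            for i in range(decay_steps)]
    return lrs

lrs = expected_lrs(3e-4, 8e-5, 3, 5)
print(lrs[0], lrs[2], lrs[-1])
```

The last warmup step equals `base_lr` and the last decay step equals `end_lr`, because the cosine term is -1 when the phase reaches pi.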
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()


@ -1,5 +1,5 @@
#!/usr/bin/env python
- import unittest, math
+ import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Lamb
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor
- from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
+ from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
- from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
+ from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337)
@ -173,48 +173,5 @@ class ExternalTestOptim(unittest.TestCase):
'warmup': steps_per_epoch * warmup_epochs,
}, 1e-5, 1e-5, do_optim=False)
- class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
- # only tests the lr
- def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
- net = TinyNet()
- optim = AdamW([net.W], lr=0.0)
- tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
- lr = []
- for _ in range(warmup_steps+decay_steps):
- lr.append(optim.lr.item())
- tiny_lr.step()
- # reimplemented in python
- expected = []
- for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
- for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
- np.testing.assert_allclose(lr, expected, rtol=1e-5)
- def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
- def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
- def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
- class TestLambdaLRLinearWarmup(unittest.TestCase):
- def test_linear_lr_warmup(self):
- BS, BASE_LR = 304, 2.5e-7
- lr = BS * BASE_LR
- # Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
- optimizer = AdamW([Tensor([1.])])
- lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
- lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
- lrs = {}
- # with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
- for i in range(1200):
- lr_scheduler.step()
- if i in {0, 499, 998, 999, 1000, 1199}:
- lrs[i] = optimizer.lr.item()
- np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
- np.testing.assert_equal(lrs[999], lrs[1000])
- np.testing.assert_equal(lrs[999], lrs[1199])
- np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
- np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()


@ -1,6 +1,6 @@
import unittest, os
from tempfile import TemporaryDirectory
- from tinygrad import Tensor
+ from tinygrad import Context
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp
- with Tensor.train():
+ with Context(TRAINING=1):
saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
if __name__=="__main__":
unittest.main()


@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np
from PIL import Image
- from tinygrad import Tensor
+ from tinygrad import Tensor, Context
from tinygrad.helpers import getenv
from test.helpers import slow
from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img
def _infer(model: EfficientNet, img):
- with Tensor.train(False):
+ with Context(TRAINING=0):
out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist()


@ -5,10 +5,11 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
+ from tinygrad.helpers import Context
from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y):
- with Tensor.train():
+ with Context(TRAINING=1):
model_torch.train()
model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters():


@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_mnist(self):
from examples.beautiful_mnist import Model
- with Tensor.train():
+ with Context(TRAINING=1):
model = Model()
optimizer = optim.Adam(get_parameters(model))
BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self):
BS = 32
# with training batchnorm still though
- with Tensor.train():
+ with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit
def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_cifar(self):
- with Tensor.train():
+ with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16
- with Tensor.train():
+ with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow @slow
def test_bert(self): def test_bert(self):
with Tensor.train(): with Context(TRAINING=1):
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2, args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2} "max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny) model = BertForPretraining(**args_tiny)

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this") #@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self): def test_fold_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,32,4,4) img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img) out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn)) check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self): def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True) bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self): def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False): with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)]) check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self): def test_fold_conv_batchnorm(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(1,3,8,8) img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False): def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15) optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.ones(1,3,4,4) img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False) bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True) def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self): def test_fold_batchnorm_backward(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty((2, 16, 8, 8)).contiguous() x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16) bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous() fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4) check_schedule(out, 4)
def test_adam_step_fusion(self): def test_adam_step_fusion(self):
with Tensor.train(): with Context(TRAINING=1):
x = Tensor.empty(4, 64, 32) x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4) layer = nn.Linear(32, 32*4)
_realize_weights(layer) _realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self): def test_adam_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13) check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self): def test_adam_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self): def test_sgd_conv_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3) c1 = nn.Conv2d(3,32,3)
_realize_weights(c1) _realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3? check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self): def test_sgd_2convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7) check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self): def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4) img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False) c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False) c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11) check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self): def test_sgd_4convs_fuse(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15) check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self): def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train(): with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16) img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False) c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0) self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self): def test_resnet_block(self):
with Tensor.train(False): with Context(TRAINING=0):
in_planes, planes = 64, 64 in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes) bn1 = nn.BatchNorm2d(planes)

@ -16,41 +16,17 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp): def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, ( return UOp(Ops.LOAD, dtypes.float, (
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True), UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
)) ))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), ( return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True), UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n) def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()): def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0] load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -506,6 +482,16 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))", self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0") "(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase): class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self): def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid # test that INDEX with a constant True valid gets simplified to drop the valid
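
The comment in `test_drop_non_monotonic_window` claims that any `gidx0` outside the window `645 <= gidx0 < 653` pushes the x index out of the image, which is why the gate can be dropped. A minimal brute-force sketch of that claim follows, using plain Python integers in place of the symbolic UOps. The helper names `idx_x`, `in_bounds_gidx0` and `partially_in_bounds` are hypothetical, and the width of 48 is read off the `(1, 48, 4)` image shape in the test.

```python
# Brute-force check of the index expression from test_drop_non_monotonic_window.
# Plain ints stand in for the symbolic UOps; this does not call tinygrad.
WIDTH = 48  # image width, taken from the (1, 48, 4) shape used in the test

def idx_x(gidx0: int, r12: int) -> int:
  # same expression as the test's x index
  return r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888

def in_bounds_gidx0() -> list[int]:
  # gidx0 values for which every r12 in range(3) lands inside the image
  return [g for g in range(1064) if all(0 <= idx_x(g, r) < WIDTH for r in range(3))]

def partially_in_bounds() -> list[int]:
  # gidx0 values where only some r12 values land inside the image
  out = []
  for g in range(1064):
    hits = [0 <= idx_x(g, r) < WIDTH for r in range(3)]
    if any(hits) and not all(hits): out.append(g)
  return out
```

Under these assumptions the in-bounds set is exactly the window from the `valid` expression, and no `gidx0` is in bounds for only some `r12`, so the bounds check on the index alone carries the same information as the gate.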

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed) # tensor tests that pass on NULL backend (no copyout needed)
import numpy as np import numpy as np
import unittest import unittest
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase): class TestTrainMode(unittest.TestCase):
def test_train_mode(self): def test_train_mode(self):
assert not Tensor.training assert not Tensor.training
@Tensor.train() @Context(TRAINING=1)
def f(): def f():
assert Tensor.training assert Tensor.training
f() f()
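
The hunks in this compare view replace `Tensor.train()` with `Context(TRAINING=1)`, both as a `with` block and, in `test_train_mode`, as a decorator. The following is a minimal stand-alone sketch of that usage pattern in plain Python. It is not tinygrad's implementation: `FLAGS`, `FlagContext` and `is_training` are hypothetical names, and the only behavior assumed from the diff is that the context object can wrap either a block or a function and that the previous value comes back on exit.

```python
import contextlib

# Hypothetical flag table standing in for a library's context variables.
FLAGS = {"TRAINING": 0}

class FlagContext(contextlib.ContextDecorator):
  """Temporarily override entries in FLAGS; works as `with` or as a decorator."""
  def __init__(self, **overrides):
    self.overrides = overrides
    self._saved = []  # stack, so the same instance can be re-entered
  def __enter__(self):
    # remember the current values of the keys being overridden
    self._saved.append({k: FLAGS[k] for k in self.overrides})
    FLAGS.update(self.overrides)
    return self
  def __exit__(self, *exc):
    # restore the values saved by the matching __enter__
    FLAGS.update(self._saved.pop())
    return False

def is_training() -> bool: return bool(FLAGS["TRAINING"])

@FlagContext(TRAINING=1)
def decorated() -> bool: return is_training()

def run_demo() -> list[bool]:
  results = [is_training()]          # outside any context
  with FlagContext(TRAINING=1):
    results.append(is_training())    # inside training context
    with FlagContext(TRAINING=0):
      results.append(is_training())  # nested override
    results.append(is_training())    # nested override undone
  results.append(decorated())        # decorator form
  results.append(is_training())      # everything restored
  return results
```

A single mechanism for all flags is one plausible reason for the migration: the training flag is scoped and restored the same way as options such as `WINO` or `NOOPT` that appear elsewhere in this diff.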

@ -317,6 +317,19 @@ class TestTensorUOpScatterReduce(unittest.TestCase):
def test_mean_exclude_self(self): def test_mean_exclude_self(self):
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False) self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
class TestTensorUOpMaskedSelect(unittest.TestCase):
# only the fixed-size path is pure
def _check(self, t, mask, **kw):
self.assertIs(t.masked_select(mask, **kw).uop, t.uop.masked_select(mask.uop, **kw))
def test_masked_select_1d(self): self._check(_t(6), Tensor([True, False, True, False, True, False]), size=4)
def test_masked_select_2d(self):
self._check(_t(3, 3), Tensor([[True, False, True], [False, True, False], [False, False, True]]), size=6, fill_value=-1)
class TestTensorUOpNonzero(unittest.TestCase):
def _check(self, t, **kw): self.assertIs(t.nonzero(**kw).uop, t.uop.nonzero(**kw))
def test_nonzero_1d(self): self._check(_t(5), size=3)
def test_nonzero_2d(self): self._check(_t(2, 3), size=4)
class TestTensorUOpPool(unittest.TestCase): class TestTensorUOpPool(unittest.TestCase):
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d()) def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1)) def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
@ -462,6 +475,10 @@ class TestTensorUOpCreation(unittest.TestCase):
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3))) self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
def test_invalids(self): def test_invalids(self):
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8))) self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
def test_empty_like(self):
t = Tensor.empty(2, 3, dtype=dtypes.int8)
self.assertIs(_strip_unique(t.empty_like().uop), _strip_unique(t.uop.empty_like()))
self.assertIs(_strip_unique(t.empty_like(dtype=dtypes.float, device="NULL").uop), _strip_unique(t.uop.empty_like(dtypes.float, "NULL")))
def test_arange(self): def test_arange(self):
self.assertIs(Tensor.arange(5).uop, UOp.arange(5)) self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
def test_arange_empty(self): def test_arange_empty(self):

@ -13,35 +13,26 @@ class TestWinograd(unittest.TestCase):
def test_forward_kernels(self): def test_forward_kernels(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
out = Tensor.conv2d(x,w) out = Tensor.conv2d(x,w)
self.assertEqual(len(out.schedule_linear().src), 2) self.assertEqual(len(out.schedule_linear().src), 4)
def test_backward_kernels(self): def test_backward_kernels(self):
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize() x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
out = Tensor.conv2d(x,w, padding=1) out = Tensor.conv2d(x,w, padding=1)
out.mean().backward() out.mean().backward()
backward_schedule = x.grad.schedule_linear(w.grad) backward_schedule = x.grad.schedule_linear(w.grad)
self.assertEqual(len(backward_schedule.src), 2) self.assertEqual(len(backward_schedule.src), 4)
@unittest.skip("this requires optimizations")
def test_counters(self): def test_counters(self):
IC, OC, X, Y = 4,4,9,9 IC, OC, H = 64, 64, 28
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize() x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
GlobalCounters.reset() GlobalCounters.reset()
with Context(WINO=1): with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
Tensor.conv2d(x,w).realize() ops_wino = GlobalCounters.global_ops
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
GlobalCounters.reset() GlobalCounters.reset()
with Context(WINO=0): with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
Tensor.conv2d(x,w).realize() ops_normal = GlobalCounters.global_ops
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
self.assertLess(ops_wino/ops_normal, 0.6)
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
# TODO: what's optimal on this?
self.assertLess(ops_ratio, 4.3)
self.assertLess(mem_ratio, 4)
def test_dtype(self): def test_dtype(self):
IC, OC, X, Y = 4,4,9,9 IC, OC, X, Y = 4,4,9,9
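
The new `test_counters` asserts that the Winograd path uses fewer than 0.6 times the operations of the direct convolution. As a rough reference point, the textbook multiplication count for Winograd F(m×m, r×r) is (m+r-1)² per output tile, against m²·r² for direct convolution. The sketch below computes only that ratio. Which tile size tinygrad's `WINO` path uses is not stated in this diff, and `GlobalCounters.global_ops` presumably also counts the transform additions, so the measured ratio is expected to sit above these figures. `winograd_mul_ratio` is a hypothetical helper name.

```python
from fractions import Fraction

def winograd_mul_ratio(m: int, r: int) -> Fraction:
  """Multiplications per output tile: Winograd F(m x m, r x r) over direct convolution."""
  direct = (m*m) * (r*r)    # m^2 outputs, each needing r^2 multiplies
  wino = (m + r - 1) ** 2   # one multiply per element of the transformed tile
  return Fraction(wino, direct)

if __name__ == "__main__":
  for m in (2, 4):
    print(f"F({m}x{m}, 3x3): {float(winograd_mul_ratio(m, 3)):.2f}")
```

Both common 3×3 tile sizes give a multiplication-only ratio below the 0.6 bound in the test, which leaves headroom for the transform overhead.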

@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
# find the FUNCTION nodes # find the FUNCTION nodes
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION) c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION) c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
# the function bodies (src[0]) should have identical keys — unique consts must not leak through # the function bodies (src[0]) should have identical keys
self.assertEqual(c0.src[0].key, c1.src[0].key) self.assertEqual(c0.src[0].key, c1.src[0].key)
def test_precompile_symbolic_2d(self): def test_precompile_symbolic_2d(self):

@ -1,8 +1,9 @@
import numpy as np import numpy as np
import unittest import unittest
from tinygrad.function import function from tinygrad.function import function
from tinygrad import Tensor, GlobalCounters from tinygrad import Tensor, GlobalCounters, Device
from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.dtype import dtypes, Invalid
from tinygrad.uop.ops import UOp, Ops, KernelInfo, ProgramInfo
class TestFunction(unittest.TestCase): class TestFunction(unittest.TestCase):
def test_simple(self): def test_simple(self):
@ -549,6 +550,36 @@ class TestFunctionTuple(unittest.TestCase):
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize() f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.]) np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
def test_custom_kernel_program_invalids_not_captured(self):
# llama FP8 kernels are PROGRAM with bare-buffer sinks (no analyzable stores), so the invalids scratch
# still must not be captured as an input -- else it is read before the kernel writes it
src = "void k(float* restrict data0, float* restrict data1) { for (int i=0;i<4;i++) data0[i]=data1[i]*2.0f; }"
lib = Device["CPU"].compiler.compile(src)
def prog(C:UOp, A:UOp) -> UOp:
sink = UOp.sink(C.base, A.base, arg=KernelInfo(name="k"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="CPU"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)),
arg=ProgramInfo(name="k", global_size=(1, 1, 1), local_size=(1, 1, 1), globals=(0, 1)))
@function(precompile=True)
def f(a:Tensor):
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
return Tensor.custom_kernel(c, a, fxn=prog)[0]
a = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
np.testing.assert_allclose(f(a).numpy(), [2., 4., 6., 8.])
def test_invalid_store_into_realized_buffer_is_captured(self):
# only fresh invalids() scratch is skipped; a realized buffer is a real input even if an Invalid store
# writes into part of it (its other elements must be preserved), so it is still captured
state = Tensor([10., 20., 30., 40.], device="CPU").contiguous().realize()
@function(precompile=True, allow_implicit=True)
def f(a:Tensor):
after = state.uop.after(state.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float32, Invalid, shape=(2,))))
return Tensor(after).contiguous() + a
out = f(Tensor([1., 1., 1., 1.], device="CPU").contiguous().realize())
np.testing.assert_allclose(out.numpy(), [11., 21., 31., 41.])
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2): def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
devs = ("CPU:0", "CPU:1") devs = ("CPU:0", "CPU:1")
def my_kernel(C:UOp, A:UOp) -> UOp: def my_kernel(C:UOp, A:UOp) -> UOp:

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy() out.numpy()
def test_backprop_conv(self): def test_backprop_conv(self):
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2) for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv)) optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self): def test_dropout_on_shard(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(256).to(devices_2) X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0] assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self): def test_dropout_on_shard_axis(self):
with Tensor.train(): with Context(TRAINING=1):
X = Tensor.ones(512).shard(devices_2, axis=0) X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy() output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True) unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass def setUp(self): pass
def test_unsynced_backprop_conv_bn(self): def test_unsynced_backprop_conv_bn(self):
with Tensor.train(): with Context(TRAINING=1):
from extra.lr_scheduler import OneCycleLR from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)] convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni) bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:]) return bn_ts[0].cat(*bn_ts[1:])
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16) bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2) GPUS = (d1, d2)
with Tensor.train(): with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3) conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS)) bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0) x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Tensor.train(is_training): with Context(TRAINING=is_training):
bns = [] bns = []
for _ in range(len(devices)): for _ in range(len(devices)):
bn = nn.BatchNorm2d(8) bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0) x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train(): with Context(TRAINING=1):
synced_bn = BatchNorm2d(8) synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices)) unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
@ -23,6 +23,7 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.late.coalese import memory_coalesing
pm_index_is_shrink = PatternMatcher([ pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK # rewrite non-image INDEX to SHRINK
@ -52,6 +53,10 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param), (UPat(Ops.PARAM, name="x"), do_number_param),
]) ])
pm_no_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
@ -78,7 +83,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = apply_opts(sink, ren, beam=ast.arg.beam) sink = apply_opts(sink, ren, beam=ast.arg.beam)
# ** expander (expand_rewrite) ** # ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic") sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand # expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
@@ -113,18 +118,23 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
 # optional pre matcher
 if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
-# decompositions
+# floordiv+mod / dtype decomp (early)
 supported_ops = tuple(ren.code_for_op.keys())
-pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
-pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
-sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
+pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
+sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
 sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
-sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
-# GEP/STACK stuff
+# do memory coalesing (late)
+sink = memory_coalesing(sink, ren)
+# instruction selection decompositions
+pm_decomp = pm_decomp+\
+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
+sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
+# this is new style (TODO: this should all be removed)
 sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
-# this is new style
 sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
 sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
@@ -133,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
 # final rules for the renderer (without sym)
 extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
-pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
+pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
 sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
 # this was the linearizer


@@ -0,0 +1,73 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
# TODO: this should handle images too, it's just memory coalesing
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
assert u.src[0].op is Ops.INDEX
buf, idx_u = u.src[0].src
if buf.addrspace == AddrSpace.REG: continue
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
# allowed lengths (copied in)
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# do the grouping
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")
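
The grouping step in `memory_coalesing` can be illustrated without tinygrad. The following is a minimal sketch under stated assumptions: offsets are plain integers (the real pass also keys on the op, the buffer, a symbolic base, and the valid gate), and the divisibility test is an integer modulus standing in for the symbolic `offset.divides(l)` check. `group_offsets` is a hypothetical name, not part of the commit.

```python
import itertools

def group_offsets(offsets, lengths, must_divide=True):
  """Split integer offsets into consecutive runs, then into chunks of allowed lengths."""
  # length 1 is always allowed: worst case, the access is not folded
  lengths = sorted(set(lengths) | {1}, reverse=True)
  # consecutive offsets share the same (value - position) key, as in the diff above
  runs = [[x for _, x in grp]
          for _, grp in itertools.groupby(enumerate(sorted(set(offsets))), lambda t: t[1] - t[0])]
  chunks = []
  for run in runs:
    while run:
      # pick the largest allowed length that fits the run and (optionally) divides the start offset
      length = next(l for l in lengths if l <= len(run) and (not must_divide or run[0] % l == 0))
      chunks.append(run[:length])
      run = run[length:]
  return chunks

print(group_offsets([0, 1, 2, 3, 4, 5, 8, 9], [4, 2]))
print(group_offsets([1, 2, 3, 4], [4, 2]))
print(group_offsets([1, 2, 3, 4], [4, 2], must_divide=False))
```

With `must_divide=True`, a run that starts at a misaligned offset is peeled one element at a time until an allowed length divides the start; the DSP branch in the diff turns this requirement off.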


@@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
 def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
 # can drop valid if idx is out of bound when valid is False
 drop_stmt = []
-for stmt in valid.split_uop(Ops.AND):
+for i,stmt in enumerate(valid.split_uop(Ops.AND)):
 if (res:=parse_valid(stmt)) is None: continue
 X, is_upper_bound, c = res
@@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
 drop_stmt.append(stmt)
 continue
-# if X <= c, check if it's out of bound when X = c+1
-# if X >= c, check if it's out of bound when X = c-1
-test_value = c + 1 if is_upper_bound else c - 1
-for i,b in zip(idx.src, (width, height)):
-if i.is_increasing():
-rw = i.substitute({X:X.const_like(test_value)})
+# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
+lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
+if lo <= hi:
+fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
+for coord,b in zip(idx.src, (width, height)):
+rw = coord.substitute({X:fake}).simplify()
 if rw.vmin >= b or rw.vmax < 0:
 drop_stmt.append(stmt)
 break
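
The change above replaces a single test value (`c+1` or `c-1`) with a check over the whole range of values that violate the bound. A minimal sketch of that idea follows, assuming the coordinate is an ordinary Python function of an integer and using brute-force enumeration where the real code relies on symbolic `vmin`/`vmax`. `can_drop_valid` is a hypothetical helper, not a tinygrad function.

```python
def can_drop_valid(coord, c, is_upper_bound, vmin, vmax, bound):
  """True if coord(x) falls outside [0, bound) for every x that violates the valid statement."""
  # values of X on the wrong side of the bound: [c+1, vmax] for X <= c, [vmin, c-1] for X >= c
  lo, hi = (c + 1, vmax) if is_upper_bound else (vmin, c - 1)
  if lo > hi: return False  # the statement can never be violated, so this check proves nothing
  values = [coord(x) for x in range(lo, hi + 1)]
  return min(values) >= bound or max(values) < 0

# X <= 3 with X in [0, 7] and width 16: 4*X is at least 16 for every violating X
print(can_drop_valid(lambda x: 4 * x, 3, True, 0, 7, 16))
# a non-monotonic coordinate: X = 4 gives 16 (out of bound) but X = 5 gives 0 (in bound)
print(can_drop_valid(lambda x: (4 * x) % 20, 3, True, 0, 7, 16))
```

The second call shows why testing only `X = c+1` is not enough on its own for coordinates that are not increasing, which is the case the old code guarded with `is_increasing()`.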
@@ -162,18 +162,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
 # determine fold lengths
 lengths = []
 must_divide = True
-if ctx is not None and ctx.target.device == "DSP":
-lengths = [128,64,32,16,8,4]
-must_divide = False
-elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
-pass
-elif buf.addrspace == AddrSpace.REG:
-pass
-elif isinstance(buf.dtype, ImageDType):
-lengths = [4]
-elif ctx is not None and ctx.supports_float4:
-# TODO: a better way to get this than ctx
-lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
+# TODO: this belongs in coalese
+if isinstance(buf.dtype, ImageDType): lengths = [4]
 lengths.append(1) # worst case, it's not folded
 # filter fold lengths that don't divide


@@ -101,6 +101,12 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
 # for Schedule, we check if the range is used in INDEX gates or WHERE gates
 is_masked = k.rngs[axis] in where_gate_rngs
 if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
+# upcasting a masked global axis moves that range out of the launch grid into each work-item
+# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
+if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
+global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
+global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
+if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
 if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
 to_upcast.append(axis)
 for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
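
The occupancy check added above is plain integer arithmetic, so it can be sketched on its own. This sketch assumes concrete integer shapes and a set of global axis indices; `keep_masked_upcast` and its parameters are hypothetical names, and the default of 4096 mirrors the `OCCUPANCY_FLOOR` default in the diff.

```python
from math import prod

def keep_masked_upcast(full_shape, global_axes, to_upcast, axis, occupancy_floor=4096):
  """True if upcasting the global `axis` still leaves at least `occupancy_floor` global work-items."""
  # work already pulled out of the launch grid by earlier global upcasts, times this axis
  global_upcast = prod(full_shape[i] for i in to_upcast if i in global_axes) * full_shape[axis]
  # global work-items that remain in the launch grid after the upcast
  global_items_after = prod(full_shape[i] for i in global_axes) // global_upcast
  return global_items_after >= occupancy_floor

# 64*64*7 global items; upcasting the size-7 axis leaves 4096, which meets the floor
print(keep_masked_upcast((64, 64, 7), {0, 1, 2}, [], 2))
# 32*64*7 global items; upcasting the size-7 axis leaves 2048, so the upcast is skipped
print(keep_masked_upcast((32, 64, 7), {0, 1, 2}, [], 2))
```

In the diff this test only runs when `IMAGE` is set and the axis is `AxisType.GLOBAL`; other masked axes are upcast as before.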


@@ -231,8 +231,7 @@ def _prepare_jit_inputs(args, kwargs):
 it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
 tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
 def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
-# TODO: drop the CONST branch once all CONST are deviceless
-if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
+if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
 if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
 input_uops = get_input_uops()
 # collect buffer UOps (including MultiBuffer)


@@ -1,26 +1,29 @@
-import functools, itertools, time
+import functools, time
 from typing import Generic, TypeVar, Callable, cast, overload
 from tinygrad.helpers import Context, dedup, getenv, DEBUG
+from tinygrad.dtype import Invalid
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict
 def add_to_ctx(ctx, x:UOp):
+if x.buf_uop in ctx[1]: return None
 ret = x.param_like(len(ctx[0]))
 ctx[0].append(x)
 return ret
-pm_transform_unique_const = PatternMatcher([
-# transform unique consts to LUNIQUE
-(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
-lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
-])
 pm_ctx = PatternMatcher([
 (UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
 (UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
 lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
-])+pm_transform_unique_const
+])
+def invalid_outputs(uret:UOp) -> set[UOp]:
+# invalids() returns fresh write-only scratch: a clone storing CONST(Invalid)
+# don't capture it as an input; only skip fresh buffers, not realized ones
+return {u.src[0].buf_uop for u in uret.backward_slice_with_self
+if u.op is Ops.STORE and u.src[1].base.op is Ops.CONST and u.src[1].base.arg is Invalid
+and not u.src[0].buf_uop.is_realized}
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
@@ -63,7 +66,7 @@ class _function(Generic[ReturnType]):
 # the BUFFERs that are left are the implicit inputs
 num_explicit = len(call_uops)
-uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
+uret = graph_rewrite(uret, pm_ctx, (call_uops, invalid_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
 name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
 if not self.allow_implicit:
 implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]


@@ -240,6 +240,7 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
 IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
 JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
 WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
+TRAINING = ContextVar("TRAINING", 0)
 USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)


@@ -1,13 +1,13 @@
 from __future__ import annotations
-import functools, itertools
-from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
+import functools, itertools, string
+from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.movement import MovementMixin
 from tinygrad.mixin.reduce import ReduceMixin
 from tinygrad.uop import Ops
 from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
 from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
-from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
+from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
 from tinygrad.helpers import resolve_pool_pads, round_up
 if TYPE_CHECKING:
@@ -17,35 +17,59 @@ ReductionStr = Literal["mean", "sum", "none"]
 class OpMixin(ElementwiseMixin, ReduceMixin):
-@staticmethod
-def const(dtype, b): raise NotImplementedError
-@classmethod
-def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
-device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
+def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
+def item(self) -> PyConst:
 """
-Creates a tensor with the given shape, filled with the given value.
-You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
-Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
+Returns the value of this tensor as a standard Python number.
 ```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.full((2, 3), 42).numpy())
-```
-```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.full((2, 3), False).numpy())
+t = Tensor(42)
+print(t.item())
 ```
 """
-# TODO: enable this check
-# if not buffer: assert device is None, "buffer=False does not support device specification"
-from tinygrad.uop.ops import UOp
-new_shape = argfix(shape)
-dt = to_dtype(dtype) if dtype is not None else None
-val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
-val = val.reshape((1,)*len(new_shape)).expand(new_shape)
-return val.clone(device=device) if buffer else val
-def __getitem__(self, indices) -> Self: return self._getitem(indices)
+assert self.numel() == 1, "must have one element for item"
+return self.data()[(0,) * len(self.shape)]
+def __getitem__(self, indices) -> Self:
+"""
+Retrieves a sub-tensor using indexing.
+Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
+Examples:
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor.arange(12).reshape(3, 4)
+print(t.numpy())
+```
+- Int Indexing: Select an element or sub-tensor using integers for each dimension.
+```python exec="true" source="above" session="tensor" result="python"
+print(t[1, 2].numpy())
+```
+- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
+```python exec="true" source="above" session="tensor" result="python"
+print(t[0:2, ::2].numpy())
+```
+- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
+```python exec="true" source="above" session="tensor" result="python"
+print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
+```
+- `None` Indexing: Add a new dimension to the tensor.
+```python exec="true" source="above" session="tensor" result="python"
+print(t[:, None].shape)
+```
+NOTE: Out-of-bounds indexing results in a value of `0`.
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor([1, 2, 3])
+print(t[Tensor([4, 3, 2])].numpy())
+```
+"""
+return self._getitem(indices)
 def _getitem(self, indices, v=None) -> Self:
 from tinygrad.uop.ops import UOp
@@ -138,40 +162,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
 vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
 return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
-@classmethod
-def zeros(cls, *shape, **kwargs) -> Self:
-"""
-Creates a tensor with the given shape, filled with zeros.
-You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
-Additionally, all other keyword arguments are passed to the constructor of the tensor.
-```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.zeros(2, 3).numpy())
-```
-```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
-```
-"""
-return cls.full(argfix(*shape), 0.0, **kwargs)
-@classmethod
-def ones(cls, *shape, **kwargs) -> Self:
-"""
-Creates a tensor with the given shape, filled with ones.
-You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
-Additionally, all other keyword arguments are passed to the constructor of the tensor.
-```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.ones(2, 3).numpy())
-```
-```python exec="true" source="above" session="tensor" result="python"
-print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
-```
-"""
-return cls.full(argfix(*shape), 1.0, **kwargs)
 @classmethod
 def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
 """
@@ -367,7 +357,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
 if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
 raise NotImplementedError(f"{mode=} is not supported")
-def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
+def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
 if not isinstance(y, type(self)): y = self.ufix(y)
 x, y = (self, y) if not reverse else (y, self)
 # ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
@@ -423,6 +413,47 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
 def __matmul__(self, x:Self) -> Self: return self.matmul(x)
 def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
+@classmethod
+def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
+"""
+Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
+See: https://pytorch.org/docs/stable/generated/torch.einsum.html
+```python exec="true" source="above" session="tensor" result="python"
+x = Tensor([[1, 2], [3, 4]])
+y = Tensor([[5, 6], [7, 8]])
+print(Tensor.einsum("ij,ij->", x, y).numpy())
+```
+"""
+xs, formula = list(argfix(*operands)), formula.replace(" ", "")
+# expand ellipsis to letters, determine output
+if "..." in formula:
+ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
+ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
+for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
+lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
+formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
+lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
+inputs = lhs.split(",")
+if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
+# trace: take diagonal when letter repeats in single input
+for i, (s, x) in enumerate(zip(inputs, xs)):
+for c in set(s):
+while s.count(c) > 1:
+j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
+perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
+x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
+s = s[:k] + s[k+1:]
+inputs[i], xs[i] = s, x
+# check sizes and build sorted alphabet
+sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
+alpha = sorted(sz)
+# align all tensors to alphabet, multiply, sum non-output, permute to output order
+xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
+for s, x in zip(inputs, xs)]
+return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
 def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
 """
 Computes the gradient of the targets with respect to self.
@@ -1147,6 +1178,68 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
 # select from values for each True element in mask else select from self
 return mask.where(values, self)
+def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
+"""
+Selects elements from `self` based on the boolean `mask`.
+With `size=None` (default), output length equals the number of `True` values (not jittable).
+With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
+print(t.numpy())
+print(mask.numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.masked_select(mask).numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.masked_select(mask, size=6, fill_value=-1).numpy())
+```
+"""
+if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
+x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
+mask_cumsum = mask.cumsum()
+if size is None:
+counts = type(self).zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
+return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
+counts = type(self).zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
+return (type(self).arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
+def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Self:
+"""
+Returns the indices of the elements that are non-zero.
+With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
+With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor([1, 0, 2, 0, 3])
+print(t.numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.nonzero().numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor([[1, 0], [0, 2]])
+print(t.numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.nonzero().numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(t.nonzero(size=3, fill_value=-1).numpy())
+```
+"""
+if self.ndim == 0:
+return type(self).zeros(size if size is not None else int(self.ne(0).item()), 0, dtype=dtypes.int32, device=self.device)
+mask = self.ne(0).flatten()
+indices = type(self).stack(*[type(self).arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
+for i, s in enumerate(self.shape)], dim=-1)
+return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
+size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
 # ***** functional nn ops *****
 def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:

@@ -1,13 +1,101 @@
-from typing import Self
-from tinygrad.dtype import ConstType, DType
+from typing import TYPE_CHECKING, Callable, Self
+from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
+from tinygrad.helpers import argfix
+from tinygrad.mixin.dtype import DTypeMixin
+from tinygrad.mixin.movement import MovementMixin
-class CreationMixin:
-def const_like(self, b: ConstType) -> Self: raise NotImplementedError
-def cast(self, dtype: DType) -> Self: raise NotImplementedError
-def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
-"""Creates a tensor with the same shape as `self`, filled with the given value."""
-return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
+if TYPE_CHECKING:
+from tinygrad.uop.ops import sint, UOp
+class CreationMixin(DTypeMixin, MovementMixin):
+@staticmethod
+def const(dtype, b): raise NotImplementedError
+def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
+def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
+from tinygrad.uop.ops import UOp
+assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
+if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
+return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
+def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
+"""
+Creates an empty tensor with the same shape as `self`.
+If `dtype` is not specified, the dtype of `self` is used.
+"""
+return self._wrap_uop(self._uop.empty_like(dtype, device))
+@classmethod
+def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
+"""
+Creates a tensor with the given shape, filled with Invalid.
+This is an alternative to Tensor.empty when you want an "anonymous" buffer.
+Eventually Tensor.empty will be replaced by this.
+"""
+return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
+@classmethod
+def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
+device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
+"""
+Creates a tensor with the given shape, filled with the given value.
+You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
+```python exec="true" source="above" session="tensor" result="python"
+print(Tensor.full((2, 3), 42).numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(Tensor.full((2, 3), False).numpy())
+```
+"""
+# TODO: enable this check
+# if not buffer: assert device is None, "buffer=False does not support device specification"
+from tinygrad.uop.ops import UOp
+new_shape = argfix(shape)
+dt = to_dtype(dtype) if dtype is not None else None
+val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
+val = val.reshape((1,)*len(new_shape)).expand(new_shape)
+return val.clone(device=device) if buffer else val
+def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
+"""
+Creates a tensor with the same shape as `self`, filled with the given value.
+If `dtype` is not specified, the dtype of `self` is used.
+You can pass in the `device` keyword argument to control device of the tensor.
+Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
+```python exec="true" source="above" session="tensor" result="python"
+t = Tensor.ones(2, 3)
+print(Tensor.full_like(t, 42).numpy())
+```
+"""
+if isinstance(self.device, tuple):
+if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
+return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
+return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
+@classmethod
+def zeros(cls, *shape, **kwargs) -> Self:
+"""
+Creates a tensor with the given shape, filled with zeros.
+You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
+Additionally, all other keyword arguments are passed to the constructor of the tensor.
+```python exec="true" source="above" session="tensor" result="python"
+print(Tensor.zeros(2, 3).numpy())
+```
+```python exec="true" source="above" session="tensor" result="python"
+print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
+```
+"""
+return cls.full(argfix(*shape), 0.0, **kwargs)
 def zeros_like(self, **kwargs) -> Self:
 """
@ -22,6 +110,23 @@ class CreationMixin:
""" """
return self.full_like(0, **kwargs) return self.full_like(0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
def ones_like(self, **kwargs) -> Self: def ones_like(self, **kwargs) -> Self:
""" """
Creates a tensor with the same shape as `self`, filled with ones. Creates a tensor with the same shape as `self`, filled with ones.

@ -1,13 +1,36 @@
from typing import Self from typing import TYPE_CHECKING, Self
from tinygrad.dtype import DType, dtypes from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class DTypeMixin: class DTypeMixin:
@property @property
def dtype(self) -> DType: raise NotImplementedError def dtype(self) -> DType: raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u:'UOp') -> Self: raise NotImplementedError
def cast(self, dtype:DType) -> Self: raise NotImplementedError def cast(self, dtype:DTypeLike) -> Self:
"""
Casts `self` to the given `dtype`.
def bitcast(self, dtype:DType) -> Self: raise NotImplementedError ```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else self._wrap_uop(self._uop.cast(dt))
def bitcast(self, dtype:DTypeLike) -> Self: raise NotImplementedError
def element_size(self) -> int: def element_size(self) -> int:
""" """

@ -3,23 +3,17 @@ from typing import TYPE_CHECKING, Literal, Self
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
from tinygrad.helpers import argfix, polyN from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin from tinygrad.mixin.creation import CreationMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from tinygrad.uop.ops import UOp from tinygrad.uop.ops import UOp
class ElementwiseMixin(DTypeMixin, CreationMixin): class ElementwiseMixin(CreationMixin):
# required to implement # required to implement
def alu(self, op: Ops, *src: Self) -> Self: def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
# great functions you get! # great functions you get!
def ufix(self, x: 'Self|ConstType|UOp') -> Self: def ufix(self, x: 'Self|ConstType|UOp') -> Self:
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x)) return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))
@ -51,7 +45,11 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
""" """
return self.cast(dtypes.bool).ne(True) return self.cast(dtypes.bool).ne(True)
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError def contiguous(self, **kwargs) -> Self:
"""
Returns a contiguous tensor.
"""
return self._wrap_uop(self._uop.contiguous(**kwargs))
def contiguous_backward(self) -> Self: def contiguous_backward(self) -> Self:
""" """

@ -1,6 +1,6 @@
from typing import cast from typing import cast
import math, dataclasses, itertools import math, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort from tinygrad.helpers import argsort
from tinygrad.dtype import sum_acc_dtype from tinygrad.dtype import sum_acc_dtype
@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM} params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args))) g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values())) grads = compute_gradient(fxn, root_grad, set(params.values()))
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {} fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads] grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True) bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs)) bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
# TODO: is this okay here?
from tinygrad.function import pm_transform_unique_const
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward) bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)} gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args))) return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))

@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Self import math
from tinygrad.dtype import DType, dtypes from typing import Self, cast
from tinygrad.helpers import ceildiv, prod from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING
from tinygrad.mixin import OpMixin from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device
class RandMixin(OpMixin): class RandMixin(OpMixin):
@ -39,3 +41,286 @@ class RandMixin(OpMixin):
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4)) bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
out = cls._bits_to_rand(bits, shape, dtype) out = cls._bits_to_rand(bits, shape, dtype)
return out.contiguous() if contiguous else out return out.contiguous() if contiguous else out
@staticmethod
def _next_counter(device:str, num:int):
raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
@classmethod
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = cls._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return cls._rand(key, counter, shape, dt, contiguous=contiguous)
def rand_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: type(self).rand(*shape, dtype=dtype, device=dev, **kwargs))
return type(self).rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(to_dtype(dtype or self.dtype))
@classmethod
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return cls.empty(*shape, **kwargs).randn_like(dtype=dtype) # type: ignore[attr-defined]
@classmethod
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@classmethod
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * cls.randn(*shape, **kwargs) + mean
@classmethod
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * cls.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@classmethod
def scaled_uniform(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return cls.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@classmethod
def glorot_uniform(cls, *shape, **kwargs) -> Self:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.normal(*shape, mean=0.0, std=std, **kwargs)
@classmethod
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return cls.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# Efraimidis-Spirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
def dropout(self, p=0.5) -> Self:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Context(TRAINING=1):
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not TRAINING or p == 0: return self
if p == 1: return self.const_like(0)
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Self:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
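
Reviewer note: the `randn_like` expression in this file's diff follows the Box-Muller transform it links to. A minimal sketch of that one step in plain Python may make the formula easier to check. It uses the standard library instead of tinygrad ops, and `box_muller` is a hypothetical helper name for illustration, not something defined in the repository.

```python
import math
import random

def box_muller(u1: float, u2: float) -> float:
    """Map two uniform samples in [0, 1) to one standard normal sample."""
    # Same form as the diffed expression: cos(2*pi*u1) * sqrt(-2*log(1 - u2)).
    # Using 1 - u2 keeps the log argument in (0, 1], so log never receives 0.
    return math.cos(2 * math.pi * u1) * math.sqrt(-2 * math.log(1 - u2))

if __name__ == "__main__":
    rng = random.Random(42)  # fixed seed so repeated runs agree
    samples = [box_muller(rng.random(), rng.random()) for _ in range(10_000)]
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    # For a standard normal, the mean should be close to 0 and the variance close to 1.
    print(f"mean={mean:.3f} var={var:.3f}")
```

The diffed code draws both uniforms in a single `rand_like` call on a stacked tensor, computes in float32, and casts to the requested dtype at the end. The sketch only mirrors the arithmetic, not the batching or the dtype handling.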

@ -1,8 +1,7 @@
import string from typing import Self, Sequence
from typing import Self, Sequence, cast
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts from tinygrad.helpers import make_tuple
from tinygrad.mixin.dtype import DTypeMixin from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin from tinygrad.mixin.movement import MovementMixin
@ -136,44 +135,3 @@ class ReduceMixin(DTypeMixin, MovementMixin):
``` ```
""" """
return self.bool().prod(axis, keepdim) return self.bool().prod(axis, keepdim)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))

@ -22,6 +22,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None), if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
# GPU stuff # GPU stuff

@ -18,7 +18,7 @@ import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def add_ranges_to_store(ctx, x): def add_ranges_to_store(ctx, x):
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None
assert x.src[0].shape == x.src[1].shape, "bad store shape" assert x.src[0].shape == x.src[1].shape, "bad store shape"
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape] idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs) return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
@ -543,9 +543,6 @@ to_define_global = PatternMatcher([
# remove device from local BUFFERIZE # remove device from local BUFFERIZE
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))), (UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# remove UNIQUE/DEVICE to dedup CONST
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works # renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range), (UPat(Ops.RANGE, name="r"), renumber_range),
]) ])

@ -1,14 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin.rand import RandMixin from tinygrad.mixin.rand import RandMixin
from tinygrad.schedule import create_linear_with_vars from tinygrad.schedule import create_linear_with_vars
@ -59,19 +58,25 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor" assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret return ret
class Tensor(RandMixin): # TODO: deprecate this, always use TRAINING
class TensorMeta(type):
@property
def training(cls) -> bool: return bool(TRAINING.value)
@training.setter
def training(cls, mode:bool): TRAINING.value = int(mode)
class Tensor(RandMixin, metaclass=TensorMeta):
""" """
A `Tensor` is a multi-dimensional matrix containing elements of a single data type. A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor" ```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn from tinygrad import Tensor, dtypes, nn, Context
import numpy as np import numpy as np
import math import math
np.set_printoptions(precision=4) np.set_printoptions(precision=4)
``` ```
""" """
__slots__ = "uop", "is_param", "grad" __slots__ = "uop", "is_param", "grad"
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None, def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None): device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
@ -125,9 +130,9 @@ class Tensor(RandMixin):
@suppress_finalizing @suppress_finalizing
def __del__(self): all_tensors.pop(weakref.ref(self), None) def __del__(self): all_tensors.pop(weakref.ref(self), None)
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor: def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
srcs = (self,)+x srcs = (self,)+x
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs) new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,) if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
# directly create the Tensor # directly create the Tensor
ret = Tensor.__new__(Tensor) ret = Tensor.__new__(Tensor)
@ -136,34 +141,18 @@ class Tensor(RandMixin):
all_tensors[weakref.ref(ret)] = None all_tensors[weakref.ref(ret)] = None
return ret return ret
# alu and const_like are used by the mixins # alu, _uop, _wrap_uop and const are used by the mixins
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src) def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
@property @property
def _uop(self) -> UOp: return self.uop def _uop(self) -> UOp: return self.uop
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u) def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
@staticmethod @staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b)) def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
@staticmethod
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return Tensor(UOp.invalids(argfix(*shape), dtype, device))
def is_param_(self, is_param:bool=True) -> Tensor: def is_param_(self, is_param:bool=True) -> Tensor:
self.is_param = is_param self.is_param = is_param
return self return self
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
def __repr__(self): def __repr__(self):
ld = self.uop ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>" ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -275,6 +264,7 @@ class Tensor(RandMixin):
x = self.cast(self.dtype.base).contiguous() x = self.cast(self.dtype.base).contiguous()
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU") if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
return cast(Buffer, x.realize().uop.buffer).ensure_allocated() return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview() def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview: def data(self) -> memoryview:
@ -290,19 +280,7 @@ class Tensor(RandMixin):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}" assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12) assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape) return self._data().cast(self.dtype.base.fmt, self.shape)
def item(self) -> PyConst:
"""
Returns the value of this tensor as a standard Python number.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor(42)
print(t.item())
```
"""
assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)]
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions) # NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]: def tolist(self) -> PyConst|list[Any]:
@ -464,13 +442,6 @@ class Tensor(RandMixin):
""" """
return Tensor(UOp.empty(argfix(*shape), dtype, device)) return Tensor(UOp.empty(argfix(*shape), dtype, device))
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return Tensor(self.uop.empty_like(dtype, device))
@staticmethod @staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
""" """
@ -533,261 +504,6 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff)) high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high) return Tensor._device_seeds[device], low.cat(high)
@staticmethod
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
# ***** creation helper functions *****
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
return Tensor(stacked.multi(self.uop.axis))
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
# ***** random functions *****
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)
@staticmethod
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype)
@staticmethod
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * Tensor.randn(*shape, **kwargs) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# EfraimidisSpirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass ***** # ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor: def backward(self, gradient:Tensor|None=None) -> Tensor:
@ -814,49 +530,9 @@ class Tensor(RandMixin):
# ***** movement ops ***** # ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg) def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis) def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def __getitem__(self, indices) -> Tensor:
"""
Retrieves a sub-tensor using indexing.
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
Examples:
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(12).reshape(3, 4)
print(t.numpy())
```
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
```python exec="true" source="above" session="tensor" result="python"
print(t[1, 2].numpy())
```
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
```python exec="true" source="above" session="tensor" result="python"
print(t[0:2, ::2].numpy())
```
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
```python exec="true" source="above" session="tensor" result="python"
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
```
- `None` Indexing: Add a new dimension to the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(t[:, None].shape)
```
NOTE: Out-of-bounds indexing results in a value of `0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t[Tensor([4, 3, 2])].numpy())
```
"""
return super().__getitem__(indices)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw) # raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
@ -887,68 +563,6 @@ class Tensor(RandMixin):
def __delitem__(self, indices) -> None: def __delitem__(self, indices) -> None:
raise TypeError("Tensor does not support deleting items") raise TypeError("Tensor does not support deleting items")
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
"""
Selects elements from `self` based on the boolean `mask`.
With `size=None` (default), output length equals the number of `True` values (not jittable).
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
print(t.numpy())
print(mask.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
```
"""
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
mask_cumsum = mask.cumsum()
if size is None:
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
"""
Returns the indices of the elements that are non-zero.
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 0, 2, 0, 3])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 0], [0, 2]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero(size=3, fill_value=-1).numpy())
```
"""
if self.ndim == 0:
return Tensor.zeros(size if size is not None else int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
mask = (self != 0).flatten()
indices = Tensor.stack(*[Tensor.arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
for i, s in enumerate(self.shape)], dim=-1)
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
# ***** reduce ops ***** # ***** reduce ops *****
def keccak(self, cfg:str|tuple[int, int]="sha3_256"): def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
@ -1054,11 +668,11 @@ class Tensor(RandMixin):
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB # compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx) dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW)) ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
@ -1105,14 +719,6 @@ class Tensor(RandMixin):
if IMAGE: return self.image_dot(w, dtype) if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype) return super().dot(w, dtype)
# ***** unary ops *****
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
# ***** broadcasted elementwise ops ***** # ***** broadcasted elementwise ops *****
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
@ -1172,80 +778,8 @@ class Tensor(RandMixin):
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec") fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos))) return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
# ***** functional nn ops *****
def dropout(self, p=0.5) -> Tensor:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `Tensor.training` is `True`.
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.const_like(0)
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
# ***** cast ops ***** # ***** cast ops *****
def cast(self, dtype:DTypeLike) -> Tensor:
"""
Casts `self` to the given `dtype`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else self._apply_uop(UOp.cast, dtype=dt)
def bitcast(self, dtype:DTypeLike) -> Tensor: def bitcast(self, dtype:DTypeLike) -> Tensor:
""" """
Bitcasts `self` to the given `dtype` of the same itemsize. Bitcasts `self` to the given `dtype` of the same itemsize.


@ -454,7 +454,8 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
powers_of_two: dict[int, int] = {2**i:i for i in range(64)} powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache @functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher: def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
# these are rewrites that make things simpler
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)] pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod # FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)) if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
@ -463,6 +464,11 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
return PatternMatcher(pat)
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())] lambda x,y: (x | y).logical_not())]
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) # rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
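
Two of the rewrites in this hunk can be checked on plain Python integers. This is a minimal sketch, not tinygrad code: `floormod_pow2` and `max_via_cmplt_where` are hypothetical helper names, and it assumes Python's `%` (floor modulo) and `&` (two's-complement semantics on ints) match the integer behavior the diff comments describe.

```python
# Standalone check of two rewrites from the hunk above, using plain Python ints.
# Nothing here calls tinygrad; both helpers are hypothetical illustrations.

def floormod_pow2(x: int, c: int) -> int:
  """x % c rewritten as x & (c-1), valid only when c is a power of two."""
  assert c > 0 and c & (c - 1) == 0, "c must be a power of two"
  return x & (c - 1)

def max_via_cmplt_where(a: int, b: int) -> int:
  """MAX expressed as a compare plus select: (a < b).where(b, a)."""
  return b if a < b else a

# The AND form agrees with floor modulo for negative and positive inputs.
for x in range(-64, 65):
  for c in (1, 2, 8, 64):
    assert floormod_pow2(x, c) == x % c

# The compare-and-select form agrees with the builtin max.
for a in range(-4, 5):
  for b in range(-4, 5):
    assert max_via_cmplt_where(a, b) == max(a, b)
```

The loops are exhaustive only over a small range; they show that the identities hold there and are not a proof for every dtype a backend supports.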


@ -560,12 +560,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=()) ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod @staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs): def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs) return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@staticmethod @staticmethod
@ -925,12 +919,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff *** # *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int: def const_factor(self) -> int:
"""largest known int that divides self""" """largest known int that divides self"""
# TODO: for negatives it's not the largest # TODO: for negatives it's not the largest
@ -1064,9 +1052,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace)) ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ())) if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret return ret
def placeholder_like(self, slot:int): def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):
assert all_int(self.shape), "no placeholder-like on symbolic shape" assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.max_shard_shape, self.dtype, slot) return UOp.placeholder(self.max_shard_shape, self.dtype, slot, addrspace)
# set is store+end+after # set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp: def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
@ -1196,8 +1184,9 @@ python_alu: dict[Ops, Callable] = {
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq} Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1: if any(isinstance(x, tuple) for x in operands):
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)]) count = max(len(x) for x in operands if isinstance(x, tuple))
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count)])
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands) alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
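
Review note on the `exec_alu` hunk above: the vector path is now chosen by the operands (any tuple present) rather than by `dtype.count`, and the lane count comes from the longest tuple operand. The sketch below is a simplified, standalone illustration of that dispatch, not tinygrad's code. `ALU` and `exec_alu_sketch` are hypothetical names, and dtype handling, truncation, and `Invalid` propagation are left out.

```python
import operator

# Hypothetical stand-in for the python_alu table; only two ops are modelled.
ALU = {"ADD": operator.add, "MUL": operator.mul}

def exec_alu_sketch(op, operands):
  # If any operand is a tuple, treat it as a vector and apply the op lane by lane.
  if any(isinstance(x, tuple) for x in operands):
    # Lane count is taken from the longest tuple operand, as in the new hunk.
    count = max(len(x) for x in operands if isinstance(x, tuple))
    # Tuples are indexed per lane; scalars are broadcast to every lane.
    return tuple(exec_alu_sketch(op, [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count))
  return ALU[op](*operands)

vec_result = exec_alu_sketch("ADD", [(1, 2, 3), 10])   # vector + broadcast scalar
scalar_result = exec_alu_sketch("MUL", [4, 5])         # plain scalar path
```

Like the hunk, the sketch assumes all tuple operands have the same length; a shorter tuple would raise `IndexError` when indexed.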
@ -1678,6 +1667,8 @@ pm_lower_index_dtype = PatternMatcher([
lambda var,val: var.bind(val).cast(dtypes.weakint)), lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts # remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("slen", dtypes.ints).cast(),), name="shrink"),
lambda shrink,buf,idx,slen: shrink.replace(src=(buf,idx,slen))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)), lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for images # remove hanging casts for images


@ -1,5 +1,5 @@
from typing import cast from typing import cast
from tinygrad.dtype import dtypes, Invalid from tinygrad.dtype import dtypes
from tinygrad.uop import Ops, GroupOp from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
from tinygrad.helpers import strip_parens from tinygrad.helpers import strip_parens
@ -77,8 +77,6 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY, sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH} Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"), (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"), (UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"), (UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),


@ -126,9 +126,6 @@ spec_tensor = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
# CONST with a UNIQUE and DEVICE
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER # BUFFER
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"), (UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)), lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
@ -152,8 +149,7 @@ spec_tensor = PatternMatcher([
# movement ops # movement ops
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True), (UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"), (UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"), lambda x: x.src[1].shape == x.src[2].shape),
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)), (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering # REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering


@ -121,6 +121,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# TODO: combine this with "# rules for threefry" below # TODO: combine this with "# rules for threefry" below
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"), ((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None), lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding ** # ** constant folding **
@ -160,6 +162,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x), (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x), (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
# ** simple where folding ** # ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals # a conditional with the same results either way is a noop, also fold const conditionals
@ -167,6 +170,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), (UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# a.where(b.where(c, d), d) -> (a & b).where(c, d) # a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# STACK on INDEX CONST (TODO: remove all the GEP crap)
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
]) ])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
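
Review note on the new `(x & mask) // c` rule in this file: its guard requires `c` to be a positive power of two and `mask` to keep every bit at or above `log2(c)`, so the mask can only clear bits that the division discards anyway. The sketch below checks that identity by brute force on plain Python integers. It assumes `//` is floor division, as it is for Python ints; with truncating division the identity fails for negative `x` (for example `x = -1`, `mask = -4`, `c = 4`). `mask_div_foldable` is a hypothetical helper name, not a tinygrad function.

```python
def mask_div_foldable(mask: int, c: int) -> bool:
  # Mirrors the rule's guard: c is a positive power of two,
  # and mask has every bit at or above log2(c) set.
  return c > 0 and c & (c - 1) == 0 and mask | (c - 1) == -1

checked = 0
for c in (1, 2, 4, 8, 16):
  for mask in (-1, -2, -4, -8, -16, -3):
    if not mask_div_foldable(mask, c): continue
    for x in range(-64, 65):
      # The mask only clears bits below log2(c), which floor division drops.
      assert (x & mask) // c == x // c
      checked += 1
```

A mask that clears a bit at or above `log2(c)` is rejected by the guard. With `mask = -4` and `c = 2`, `x = 2` gives `(2 & -4) // 2 == 0` but `2 // 2 == 1`.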


@ -121,7 +121,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()): for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE # always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u) if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u) if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST # exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u) if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)