Compare commits

..

29 commits

Author SHA1 Message Date
chenyu
687ade119e
IMAGE hand_coded_optimizations update (#16720) 2026-06-23 21:55:28 -04:00
George Hotz
0a8e61d0c5
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese

* move that stuff

* copy in allowed length logic

* mulitple buffers

* new coalese is better

* fine

* earlier

* fixes

* work

* work

* valid

* stack on index const
2026-06-23 18:03:48 -07:00
wozeparrot
dfea9e7994
llama: fused silu mul quantize mxfp8 (#16704) 2026-06-23 16:59:50 -07:00
chenyu
ce87d80911
better _drop_valid_stmts [pr] (#16719)
also dropped the unused is_increasing
2026-06-23 19:35:01 -04:00
George Hotz
5a2b3b7b06
early dtype decomp (#16718)
* early dtype decomp

* simplify

* cleanup

* that goes there

* doing too much

* stupid symbolic rules
2026-06-23 16:07:20 -07:00
Christopher Milan
116045cc8e
ci: remove tensorflow from testoptim (#16717) 2026-06-23 18:11:48 -04:00
nimlgen
7c1d0b6d9a
hcq2: use shrink(bitcast) (#16713)
* hcq2: use shrink(bitcast)

* x
2026-06-23 18:11:39 +03:00
George Hotz
c9dc1d63cc
small changes from new codegen (#16712)
* small changes from new codegen

* shrink/flatten
2026-06-22 17:44:15 -07:00
Christopher Milan
da98fae9e1
ci: try parallelizing tc tests (#16710) 2026-06-22 20:43:32 -04:00
chenyu
15988b5941
contiguous to mixin and cleanups [PR] (#16711) 2026-06-22 20:18:18 -04:00
Christopher Milan
cbfcf36e44
ci: remove generate_dataset and CL misc (#16709) 2026-06-22 18:01:07 -04:00
nimlgen
f9c8c697d6
hcq2: drop args after inner deps (#16708) 2026-06-22 23:26:11 +03:00
chenyu
0138480910
dropout and scaled_dot_product_attention to mixin (#16707) 2026-06-22 16:17:45 -04:00
chenyu
33b635d23a
Tensor.train -> TRAINING [PR] (#16705)
* Tensor.train -> TRAINING [PR]

* doc
2026-06-22 15:13:22 -04:00
chenyu
625d8bbd0d
TRAINING ContextVar (#16703) 2026-06-22 13:03:08 -04:00
wozeparrot
fe9b19b12d
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes

* clean: unused

* fix: batch
2026-06-22 10:54:35 -04:00
chenyu
267af9c601
full_like to CreationMixin [PR] (#16702) 2026-06-22 09:33:23 -04:00
chenyu
97da54b9d6
more method to CreationMixin [PR] (#16698) 2026-06-22 00:01:22 -04:00
chenyu
fd0dc40689
clean up CreationMixin and DTypeMixin [PR] (#16697) 2026-06-21 21:13:40 -04:00
chenyu
2d8b802958
contiguous in wino conv (#16696)
also fixed test_counters
2026-06-21 17:11:46 -04:00
chenyu
ba1d3baae8
masked_select and nonzero to mixin [PR] (#16695)
with a .data stub
2026-06-21 15:10:44 -04:00
chenyu
d80a41d559
some rand method to RandMixin [PR] (#16693) 2026-06-21 12:16:51 -04:00
wozeparrot
5164c21b44
gemm: keep shape thru mxfp8 quantize (#16692) 2026-06-20 22:28:53 -07:00
chenyu
58ff75272e
const_like and invalids to mixin [PR] (#16690)
* const_like and invalids to mixin [PR]

* empty_like

* einsum

* type
2026-06-21 00:02:29 -04:00
chenyu
b50da5c205
move Tensor.__getitem__ to mixin [PR] (#16689) 2026-06-20 22:01:45 -04:00
chenyu
4618d27129
final const cleanups [PR] (#16688) 2026-06-20 21:38:16 -04:00
chenyu
9ae0a93d0e
more const cleanups [PR] (#16682) 2026-06-20 20:41:43 -04:00
George Hotz
30830850a9
small changes from new codegen (#16681)
* small changes from new codegen

* revert that
2026-06-19 18:29:01 -07:00
chenyu
8b07cca9f7
invalid clone try 3+ [PR] (#16679) 2026-06-19 20:13:52 -04:00
65 changed files with 1176 additions and 985 deletions

View file

@ -133,46 +133,26 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores
- name: Test emulated tensor cores
env:
DEV: 'PYTHON::METAL'
DEBUG: 2
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD tensor cores
env:
DEV: 'PYTHON::gfx1100'
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter:
@ -267,13 +247,6 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests
uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py
@ -338,31 +311,6 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot:
name: openpilot Compile Tests
runs-on: *linux
@ -379,7 +327,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -422,7 +370,6 @@ jobs:
with:
key: optim
deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true'
#- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -431,7 +378,7 @@ jobs:
- name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python
from tinygrad import Tensor, nn
from tinygrad import Tensor, nn, Context
class LinearNet:
def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train():
with Context(TRAINING=1):
for i in range(10):
optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager.
```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train():
with Context(TRAINING=1):
for step in range(1000):
# random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym
from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train():
with Context(TRAINING=1):
log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit
def train_step() -> Tensor:
with Tensor.train():
with Context(TRAINING=1):
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools
from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition
class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0
eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
with Context(TRAINING=1):
st = time.monotonic()
while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass
@dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y)
optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False):
with Context(TRAINING=0):
if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(model, x, y):
optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step()
return loss.realize()
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu
@TinyJit
@Tensor.train(False)
@Context(TRAINING=0)
def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext()
with Tensor.train():
with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}"
if nm in globals():

View file

@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
if not fp8:
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@ -47,12 +47,14 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
mx_w_stored=True).reshape(*l_shape, w.shape[0])
else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None:
@ -126,10 +128,8 @@ class FlatTransformer:
# FeedForward
if SPLIT_W13:
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@ -160,7 +160,7 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
@ -216,8 +216,15 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax)
saves.extend([*s, x_w3])
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
if FUSED_SILU_W13 and MXFP8:
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax)
saves.extend([*s, out])
else:
@ -247,20 +254,30 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None)
else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int):
getattr(self, name).shard_(device, axis=axis)
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
def _shard_fp8(name:str, axis:int, std:float=0.02):
w = getattr(self, name)
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
if SPLIT_W13:
_shard_fp8("w1", 1)
_shard_fp8("w3", 1)
else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize()

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.helpers import trange, Context
from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
with Tensor.train():
with Context(TRAINING=1):
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal
from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples")
with Tensor.train():
with Context(TRAINING=1):
"""
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like()))
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis
rows, K = x.shape
scale_K, k_iters = K // 32, K // 128
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
*batch, K = x.shape
scale_K = K // 32
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
def mx_pack(e8:Tensor) -> Tensor:
rows, scale_K = e8.shape

View file

@ -143,14 +143,17 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag):
blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -211,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
data, info = prg.arg
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
dtypes.uint64) for i,gi in enumerate(info.globals)] \
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -307,7 +306,7 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -532,9 +531,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# index on slice is index
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
# shrink on slice is shrink on base at offset
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
# getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
@ -542,8 +541,8 @@ pm_resolve_patches = PatternMatcher([
# folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
fold_const_store),
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
]) + symbolic_simple
# *****************

View file

@ -0,0 +1,104 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2
while step:
active = tid < step
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
other = lds[(tid + step).valid(active)].load()
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -1,6 +1,6 @@
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.helpers import trange, Context
from tinygrad.engine.jit import TinyJit
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
if allow_jit: train_step = TinyJit(train_step)
with Tensor.train():
with Context(TRAINING=1):
losses, accuracies = [], []
for i in (t := trange(steps, disable=None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
@ -55,4 +55,3 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc

View file

@ -25,7 +25,7 @@
import unittest
import numpy as np
import torch
from tinygrad import Tensor, dtypes, nn
from tinygrad import Tensor, dtypes, nn, Context
from tinygrad.device import Device
from tinygrad.helpers import DEV
from tinygrad.renderer.nir import NIRRenderer
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
# we don't need more of these
def test_dropout_rate_one(self):
with Tensor.train():
with Context(TRAINING=1):
out = Tensor.ones(100).dropout(1.0)
np.testing.assert_allclose(out.numpy(), np.zeros(100))
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
with self.assertRaises(ValueError):
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
with self.assertRaises(ValueError):
with Tensor.train():
with Context(TRAINING=1):
Tensor.ones(10).dropout(-0.1)
class TestInputValidation(unittest.TestCase):

View file

@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src)
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
assert num_loads >= 1, "expected at least one load uop"
@unittest.skip("this is handled at higher level now")
def test_upcast_cse(self):

View file

@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow
class TestNN(unittest.TestCase):
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
with Tensor.train(training):
with Context(TRAINING=training):
szs = [4, 8, 16, 32]
for sz in szs:
# create in tinygrad

View file

@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
X_samp, Y_samp = X_train[samples], Y_train[samples]
vi = Variable('i', 0, samples.shape[0]-1)
with Context(SPLIT_REDUCEOP=0):
with Tensor.train():
with Context(TRAINING=1):
losses = []
for i in range(samples.shape[0]):
vib = vi.bind(i)

View file

@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, Variable, GlobalCounters
from tinygrad import Tensor, Variable, GlobalCounters, Context
from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes
from examples.gpt2 import Attention
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
self.test_attention(imin=4, imax=5, use_symbolic=True)
def test_attention_training(self):
with Tensor.train():
with Context(TRAINING=1):
self.test_attention(dropout_p=0.0)
with self.assertRaises(ValueError):
# symbolic shape dropout is not supported

View file

@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad import Tensor, Device, dtypes, nn, Context
from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_dropout(self):
with Tensor.train():
with Context(TRAINING=1):
n, rate = 1_000_000, 0.1
w = Tensor.ones(n).dropout(rate)
non_zeros = np.count_nonzero(w.numpy())

View file

@ -0,0 +1,67 @@
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
def __init__(self):
self.x = Tensor(x_init.copy())
self.W = Tensor(W_init.copy())
self.m = Tensor(m_init.copy())
def forward(self):
out = self.x.matmul(self.W).relu()
out = out.log_softmax(1)
out = out.mul(self.m).add(self.m).sum()
return out
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python
import unittest, math
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Lamb
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337)
@ -173,48 +173,5 @@ class ExternalTestOptim(unittest.TestCase):
'warmup': steps_per_epoch * warmup_epochs,
}, 1e-5, 1e-5, do_optim=False)
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -1,6 +1,6 @@
import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad import Context
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
with TemporaryDirectory(prefix="test-train") as tmp:
os.environ["UNET_CKPTDIR"] = tmp
with Tensor.train():
with Context(TRAINING=1):
saved_ckpts = train_stable_diffusion()
expected_ckpt = f"{tmp}/{num_steps}.safetensors"
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
if __name__=="__main__":
unittest.main()
unittest.main()

View file

@ -3,7 +3,7 @@ import ast, pathlib, unittest
import numpy as np
from PIL import Image
from tinygrad import Tensor
from tinygrad import Tensor, Context
from tinygrad.helpers import getenv
from test.helpers import slow
from extra.models.efficientnet import EfficientNet
@ -40,7 +40,7 @@ def preprocess(img, new=False):
return img
def _infer(model: EfficientNet, img):
with Tensor.train(False):
with Context(TRAINING=0):
out = model.forward(Tensor(img)).argmax(axis=-1)
return out.tolist()

View file

@ -5,10 +5,11 @@ import numpy as np
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from extra.datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y):
with Tensor.train():
with Context(TRAINING=1):
model_torch.train()
model_state_dict = get_state_dict(model)
for k,v in model_torch.named_parameters():

View file

@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_mnist(self):
from examples.beautiful_mnist import Model
with Tensor.train():
with Context(TRAINING=1):
model = Model()
optimizer = optim.Adam(get_parameters(model))
BS = 32
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
def test_forward_cifar(self):
BS = 32
# with training batchnorm still though
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@TinyJit
def run(X): return model(X)
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_train_cifar(self):
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
BS = 32
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
def test_train_cifar_hyp(self):
dtypes.default_float = dtypes.float16
with Tensor.train():
with Context(TRAINING=1):
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
initial_div_factor = hyp['opt']['initial_div_factor']
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
@slow
def test_bert(self):
with Tensor.train():
with Context(TRAINING=1):
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny)

View file

@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
#@unittest.skip("may want to reconsider this")
def test_fold_batchnorm(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(1,32,4,4)
bn = nn.BatchNorm2d(32, track_running_stats=False)
out = bn(img)
check_schedule(out, 3, nn.state.get_parameters(bn))
def test_fold_conv_batchnorm_notrain(self):
with Tensor.train(False):
with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=True)
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_notrain_no_running_stats(self):
with Tensor.train(False):
with Context(TRAINING=0):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(1,3,8,8)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self, adam=False):
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self):
with Tensor.train():
with Context(TRAINING=1):
x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16)
fw = bn(x).contiguous_backward().relu().contiguous()
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4)
def test_adam_step_fusion(self):
with Tensor.train():
with Context(TRAINING=1):
x = Tensor.empty(4, 64, 32)
layer = nn.Linear(32, 32*4)
_realize_weights(layer)
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
_realize_weights(c1)
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
_realize_weights(c1)
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 5) # TODO: 3?
def test_sgd_2convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,4,4)
c1 = nn.Conv2d(3,16,3,bias=False)
c2 = nn.Conv2d(16,32,2,bias=False)
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(opt.schedule_step(), 15)
def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train():
with Context(TRAINING=1):
img = Tensor.empty(2,3,16,16)
c1 = nn.Conv2d(3,4,3,bias=False)
c2 = nn.Conv2d(4,8,3,bias=False)
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
def test_resnet_block(self):
with Tensor.train(False):
with Context(TRAINING=0):
in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes)

View file

@ -16,41 +16,17 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -506,6 +482,16 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid

View file

@ -1,7 +1,7 @@
# tensor tests that pass on NULL backend (no copyout needed)
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
class TestTrainMode(unittest.TestCase):
def test_train_mode(self):
assert not Tensor.training
@Tensor.train()
@Context(TRAINING=1)
def f():
assert Tensor.training
f()

View file

@ -317,6 +317,19 @@ class TestTensorUOpScatterReduce(unittest.TestCase):
def test_mean_exclude_self(self):
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
class TestTensorUOpMaskedSelect(unittest.TestCase):
# only the fixed-size path is pure
def _check(self, t, mask, **kw):
self.assertIs(t.masked_select(mask, **kw).uop, t.uop.masked_select(mask.uop, **kw))
def test_masked_select_1d(self): self._check(_t(6), Tensor([True, False, True, False, True, False]), size=4)
def test_masked_select_2d(self):
self._check(_t(3, 3), Tensor([[True, False, True], [False, True, False], [False, False, True]]), size=6, fill_value=-1)
class TestTensorUOpNonzero(unittest.TestCase):
def _check(self, t, **kw): self.assertIs(t.nonzero(**kw).uop, t.uop.nonzero(**kw))
def test_nonzero_1d(self): self._check(_t(5), size=3)
def test_nonzero_2d(self): self._check(_t(2, 3), size=4)
class TestTensorUOpPool(unittest.TestCase):
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
@ -462,6 +475,10 @@ class TestTensorUOpCreation(unittest.TestCase):
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
def test_invalids(self):
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
def test_empty_like(self):
t = Tensor.empty(2, 3, dtype=dtypes.int8)
self.assertIs(_strip_unique(t.empty_like().uop), _strip_unique(t.uop.empty_like()))
self.assertIs(_strip_unique(t.empty_like(dtype=dtypes.float, device="NULL").uop), _strip_unique(t.uop.empty_like(dtypes.float, "NULL")))
def test_arange(self):
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
def test_arange_empty(self):

View file

@ -13,35 +13,26 @@ class TestWinograd(unittest.TestCase):
def test_forward_kernels(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
out = Tensor.conv2d(x,w)
self.assertEqual(len(out.schedule_linear().src), 2)
self.assertEqual(len(out.schedule_linear().src), 4)
def test_backward_kernels(self):
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = x.grad.schedule_linear(w.grad)
self.assertEqual(len(backward_schedule.src), 2)
self.assertEqual(len(backward_schedule.src), 4)
@unittest.skip("this requires optimizations")
def test_counters(self):
IC, OC, X, Y = 4,4,9,9
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
IC, OC, H = 64, 64, 28
x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
GlobalCounters.reset()
with Context(WINO=1):
Tensor.conv2d(x,w).realize()
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
ops_wino = GlobalCounters.global_ops
GlobalCounters.reset()
with Context(WINO=0):
Tensor.conv2d(x,w).realize()
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
# TODO: what's optimal on this?
self.assertLess(ops_ratio, 4.3)
self.assertLess(mem_ratio, 4)
with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
ops_normal = GlobalCounters.global_ops
print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
self.assertLess(ops_wino/ops_normal, 0.6)
def test_dtype(self):
IC, OC, X, Y = 4,4,9,9

View file

@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
# find the FUNCTION nodes
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
# the function bodies (src[0]) should have identical keys
self.assertEqual(c0.src[0].key, c1.src[0].key)
def test_precompile_symbolic_2d(self):

View file

@ -1,8 +1,9 @@
import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad import Tensor, GlobalCounters, Device
from tinygrad.dtype import dtypes, Invalid
from tinygrad.uop.ops import UOp, Ops, KernelInfo, ProgramInfo
class TestFunction(unittest.TestCase):
def test_simple(self):
@ -549,6 +550,36 @@ class TestFunctionTuple(unittest.TestCase):
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
def test_custom_kernel_program_invalids_not_captured(self):
# llama FP8 kernels are PROGRAM with bare-buffer sinks (no analyzable stores), so the invalids scratch
# still must not be captured as an input -- else it is read before the kernel writes it
src = "void k(float* restrict data0, float* restrict data1) { for (int i=0;i<4;i++) data0[i]=data1[i]*2.0f; }"
lib = Device["CPU"].compiler.compile(src)
def prog(C:UOp, A:UOp) -> UOp:
sink = UOp.sink(C.base, A.base, arg=KernelInfo(name="k"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="CPU"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)),
arg=ProgramInfo(name="k", global_size=(1, 1, 1), local_size=(1, 1, 1), globals=(0, 1)))
@function(precompile=True)
def f(a:Tensor):
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
return Tensor.custom_kernel(c, a, fxn=prog)[0]
a = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
np.testing.assert_allclose(f(a).numpy(), [2., 4., 6., 8.])
def test_invalid_store_into_realized_buffer_is_captured(self):
# only fresh invalids() scratch is skipped; a realized buffer is a real input even if an Invalid store
# writes into part of it (its other elements must be preserved), so it is still captured
state = Tensor([10., 20., 30., 40.], device="CPU").contiguous().realize()
@function(precompile=True, allow_implicit=True)
def f(a:Tensor):
after = state.uop.after(state.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float32, Invalid, shape=(2,))))
return Tensor(after).contiguous() + a
out = f(Tensor([1., 1., 1., 1.], device="CPU").contiguous().realize())
np.testing.assert_allclose(out.numpy(), [11., 21., 31., 41.])
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
devs = ("CPU:0", "CPU:1")
def my_kernel(C:UOp, A:UOp) -> UOp:

View file

@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
out.numpy()
def test_backprop_conv(self):
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
for p in get_parameters(conv): p.shard_(devices_2)
optim = nn.optim.Adam(get_parameters(conv))
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self):
with Tensor.train():
with Context(TRAINING=1):
X = Tensor.ones(256).to(devices_2)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
assert 96 < counts[0] < 160, counts[0]
def test_dropout_on_shard_axis(self):
with Tensor.train():
with Context(TRAINING=1):
X = Tensor.ones(512).shard(devices_2, axis=0)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
def setUp(self): pass
def test_unsynced_backprop_conv_bn(self):
with Tensor.train():
with Context(TRAINING=1):
from extra.lr_scheduler import OneCycleLR
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:])
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
bn = BatchNorm(16)
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
from examples.hlb_cifar10 import UnsyncedBatchNorm
GPUS = (d1, d2)
with Tensor.train():
with Context(TRAINING=1):
conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
with Tensor.train(is_training):
with Context(TRAINING=is_training):
bns = []
for _ in range(len(devices)):
bn = nn.BatchNorm2d(8)
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train():
with Context(TRAINING=1):
synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

View file

@ -13,16 +13,17 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_make_images
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.late.coalese import memory_coalesing
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
@ -52,12 +53,8 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param),
])
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_load_to_alu = PatternMatcher([
# NOTE: the PtrDType thing is temporary
(UPat(GroupOp.Elementwise|{Ops.STACK,Ops.GEP}, name="x"), lambda x:
x.replace(src=tuple([maybe_load(u) for u in x.src])) if not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
pm_no_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
@ -86,7 +83,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
@ -104,7 +101,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids
sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
@ -121,18 +118,23 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# decompositions
# floordiv+mod / dtype decomp (early)
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# GEP/STACK stuff
# do memory coalesing (late)
sink = memory_coalesing(sink, ren)
# instruction selection decompositions
pm_decomp = pm_decomp+\
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
# this is new style (TODO: this should all be removed)
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
@ -141,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
# this was the linearizer

View file

@ -0,0 +1,73 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
# TODO: this should handle images too, it's just memory coalesing
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
assert u.src[0].op is Ops.INDEX
buf, idx_u = u.src[0].src
if buf.addrspace == AddrSpace.REG: continue
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
# allowed lengths (copied in)
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# do the grouping
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in valid.split_uop(Ops.AND):
for i,stmt in enumerate(valid.split_uop(Ops.AND)):
if (res:=parse_valid(stmt)) is None: continue
X, is_upper_bound, c = res
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (width, height)):
if i.is_increasing():
rw = i.substitute({X:X.const_like(test_value)})
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
if lo <= hi:
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
for coord,b in zip(idx.src, (width, height)):
rw = coord.substitute({X:fake}).simplify()
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break
@ -162,18 +162,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# determine fold lengths
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
# TODO: this belongs in coalese
if isinstance(buf.dtype, ImageDType): lengths = [4]
lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide
@ -289,7 +279,6 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.PTRCAT, src=(UPat(name='x'),)), lambda x: x),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@ -357,6 +346,21 @@ pm_reduce = PatternMatcher([
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
# add loads
def add_load(idx:UOp):
if isinstance(idx.dtype, PtrDType): return None
assert isinstance(idx.src[0].dtype, PtrDType), f"param is not PtrDType {idx.src[0].dtype}"
return idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)
pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), add_load),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
])
# make images
pm_imageh_store = PatternMatcher([

View file

@ -101,6 +101,12 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = k.rngs[axis] in where_gate_rngs
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
# upcasting a masked global axis moves that range out of the launch grid into each work-item
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))

View file

@ -231,8 +231,7 @@ def _prepare_jit_inputs(args, kwargs):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
# TODO: drop the CONST branch once all CONST are deviceless
if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
input_uops = get_input_uops()
# collect buffer UOps (including MultiBuffer)

View file

@ -1,26 +1,29 @@
import functools, itertools, time
import functools, time
from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv, DEBUG
from tinygrad.dtype import Invalid
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
def add_to_ctx(ctx, x:UOp):
if x.buf_uop in ctx[1]: return None
ret = x.param_like(len(ctx[0]))
ctx[0].append(x)
return ret
pm_transform_unique_const = PatternMatcher([
# transform unique consts to LUNIQUE
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
])
pm_ctx = PatternMatcher([
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
])+pm_transform_unique_const
])
def invalid_outputs(uret:UOp) -> set[UOp]:
# invalids() returns fresh write-only scratch: a clone storing CONST(Invalid)
# don't capture it as an input; only skip fresh buffers, not realized ones
return {u.src[0].buf_uop for u in uret.backward_slice_with_self
if u.op is Ops.STORE and u.src[1].base.op is Ops.CONST and u.src[1].base.arg is Invalid
and not u.src[0].buf_uop.is_realized}
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
@ -63,7 +66,7 @@ class _function(Generic[ReturnType]):
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
uret = graph_rewrite(uret, pm_ctx, (call_uops, invalid_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
if not self.allow_implicit:
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]

View file

@ -240,6 +240,7 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
TRAINING = ContextVar("TRAINING", 0)
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)

View file

@ -1,13 +1,13 @@
from __future__ import annotations
import functools, itertools
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
import functools, itertools, string
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up
if TYPE_CHECKING:
@ -17,35 +17,59 @@ ReductionStr = Literal["mean", "sum", "none"]
class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
def item(self) -> PyConst:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
Returns the value of this tensor as a standard Python number.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
t = Tensor(42)
print(t.item())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)]
def __getitem__(self, indices) -> Self: return self._getitem(indices)
def __getitem__(self, indices) -> Self:
"""
Retrieves a sub-tensor using indexing.
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
Examples:
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(12).reshape(3, 4)
print(t.numpy())
```
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
```python exec="true" source="above" session="tensor" result="python"
print(t[1, 2].numpy())
```
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
```python exec="true" source="above" session="tensor" result="python"
print(t[0:2, ::2].numpy())
```
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
```python exec="true" source="above" session="tensor" result="python"
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
```
- `None` Indexing: Add a new dimension to the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(t[:, None].shape)
```
NOTE: Out-of-bounds indexing results in a value of `0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t[Tensor([4, 3, 2])].numpy())
```
"""
return self._getitem(indices)
def _getitem(self, indices, v=None) -> Self:
from tinygrad.uop.ops import UOp
@ -138,40 +162,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
@classmethod
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
"""
@ -367,7 +357,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
if not isinstance(y, type(self)): y = self.ufix(y)
x, y = (self, y) if not reverse else (y, self)
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
@ -423,6 +413,47 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
"""
Computes the gradient of the targets with respect to self.
@ -1147,6 +1178,68 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
# select from values for each True element in mask else select from self
return mask.where(values, self)
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
"""
Selects elements from `self` based on the boolean `mask`.
With `size=None` (default), output length equals the number of `True` values (not jittable).
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
print(t.numpy())
print(mask.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
```
"""
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
mask_cumsum = mask.cumsum()
if size is None:
counts = type(self).zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
counts = type(self).zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
return (type(self).arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Self:
"""
Returns the indices of the elements that are non-zero.
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 0, 2, 0, 3])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 0], [0, 2]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero(size=3, fill_value=-1).numpy())
```
"""
if self.ndim == 0:
return type(self).zeros(size if size is not None else int(self.ne(0).item()), 0, dtype=dtypes.int32, device=self.device)
mask = self.ne(0).flatten()
indices = type(self).stack(*[type(self).arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
for i, s in enumerate(self.shape)], dim=-1)
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
# ***** functional nn ops *****
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:

View file

@ -1,13 +1,101 @@
from typing import Self
from tinygrad.dtype import ConstType, DType
from typing import TYPE_CHECKING, Callable, Self
from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
from tinygrad.helpers import argfix
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
class CreationMixin:
def const_like(self, b: ConstType) -> Self: raise NotImplementedError
def cast(self, dtype: DType) -> Self: raise NotImplementedError
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
"""Creates a tensor with the same shape as `self`, filled with the given value."""
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
class CreationMixin(DTypeMixin, MovementMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return self._wrap_uop(self._uop.empty_like(dtype, device))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
def zeros_like(self, **kwargs) -> Self:
"""
@ -22,6 +110,23 @@ class CreationMixin:
"""
return self.full_like(0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
def ones_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with ones.

View file

@ -1,13 +1,36 @@
from typing import Self
from tinygrad.dtype import DType, dtypes
from typing import TYPE_CHECKING, Self
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class DTypeMixin:
@property
def dtype(self) -> DType: raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u:'UOp') -> Self: raise NotImplementedError
def cast(self, dtype:DType) -> Self: raise NotImplementedError
def cast(self, dtype:DTypeLike) -> Self:
"""
Casts `self` to the given `dtype`.
def bitcast(self, dtype:DType) -> Self: raise NotImplementedError
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else self._wrap_uop(self._uop.cast(dt))
def bitcast(self, dtype:DTypeLike) -> Self: raise NotImplementedError
def element_size(self) -> int:
"""

View file

@ -3,23 +3,17 @@ from typing import TYPE_CHECKING, Literal, Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class ElementwiseMixin(DTypeMixin, CreationMixin):
class ElementwiseMixin(CreationMixin):
# required to implement
def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
# great functions you get!
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))
@ -51,7 +45,11 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
"""
return self.cast(dtypes.bool).ne(True)
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError
def contiguous(self, **kwargs) -> Self:
"""
Returns a contiguous tensor.
"""
return self._wrap_uop(self._uop.contiguous(**kwargs))
def contiguous_backward(self) -> Self:
"""

View file

@ -1,6 +1,6 @@
from typing import cast
import math, dataclasses, itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
import math, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
from tinygrad.dtype import sum_acc_dtype
@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values()))
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
# TODO: is this okay here?
from tinygrad.function import pm_transform_unique_const
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from typing import Self
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import ceildiv, prod
import math
from typing import Self, cast
from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING
from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device
class RandMixin(OpMixin):
@ -39,3 +41,286 @@ class RandMixin(OpMixin):
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
out = cls._bits_to_rand(bits, shape, dtype)
return out.contiguous() if contiguous else out
@staticmethod
def _next_counter(device:str, num:int):
raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
@classmethod
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = cls._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return cls._rand(key, counter, shape, dt, contiguous=contiguous)
def rand_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: type(self).rand(*shape, dtype=dtype, device=dev, **kwargs))
return type(self).rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(to_dtype(dtype or self.dtype))
@classmethod
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return cls.empty(*shape, **kwargs).randn_like(dtype=dtype) # type: ignore[attr-defined]
@classmethod
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@classmethod
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * cls.randn(*shape, **kwargs) + mean
@classmethod
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * cls.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@classmethod
def scaled_uniform(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return cls.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@classmethod
def glorot_uniform(cls, *shape, **kwargs) -> Self:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.normal(*shape, mean=0.0, std=std, **kwargs)
@classmethod
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return cls.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# Efraimidis-Spirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
def dropout(self, p=0.5) -> Self:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Context(TRAINING=1):
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not TRAINING or p == 0: return self
if p == 1: return self.const_like(0)
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Self:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value

View file

@ -1,8 +1,7 @@
import string
from typing import Self, Sequence, cast
from typing import Self, Sequence
from tinygrad.uop import Ops
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
from tinygrad.helpers import make_tuple
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
@ -136,44 +135,3 @@ class ReduceMixin(DTypeMixin, MovementMixin):
```
"""
return self.bool().prod(axis, keepdim)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))

View file

@ -22,6 +22,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
# GPU stuff

View file

@ -18,7 +18,7 @@ import sys
sys.setrecursionlimit(10000)
def add_ranges_to_store(ctx, x):
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None
assert x.src[0].shape == x.src[1].shape, "bad store shape"
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
@ -543,9 +543,6 @@ to_define_global = PatternMatcher([
# remove device from local BUFFERIZE
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# remove UNIQUE/DEVICE to dedup CONST
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])

View file

@ -1,14 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.mixin.rand import RandMixin
from tinygrad.schedule import create_linear_with_vars
@ -59,19 +58,25 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret
class Tensor(RandMixin):
# TODO: deprecate this, always use TRAINING
class TensorMeta(type):
@property
def training(cls) -> bool: return bool(TRAINING.value)
@training.setter
def training(cls, mode:bool): TRAINING.value = int(mode)
class Tensor(RandMixin, metaclass=TensorMeta):
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn
from tinygrad import Tensor, dtypes, nn, Context
import numpy as np
import math
np.set_printoptions(precision=4)
```
"""
__slots__ = "uop", "is_param", "grad"
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
@ -125,9 +130,9 @@ class Tensor(RandMixin):
@suppress_finalizing
def __del__(self): all_tensors.pop(weakref.ref(self), None)
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor:
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
srcs = (self,)+x
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs)
new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
# directly create the Tensor
ret = Tensor.__new__(Tensor)
@ -136,34 +141,18 @@ class Tensor(RandMixin):
all_tensors[weakref.ref(ret)] = None
return ret
# alu and const_like are used by the mixins
# alu, _uop, _wrap_uop and const are used by the mixins
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
@property
def _uop(self) -> UOp: return self.uop
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
@staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
@staticmethod
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return Tensor(UOp.invalids(argfix(*shape), dtype, device))
def is_param_(self, is_param:bool=True) -> Tensor:
self.is_param = is_param
return self
class train(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
def __repr__(self):
ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
@ -275,6 +264,7 @@ class Tensor(RandMixin):
x = self.cast(self.dtype.base).contiguous()
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview:
@ -290,19 +280,7 @@ class Tensor(RandMixin):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
def item(self) -> PyConst:
"""
Returns the value of this tensor as a standard Python number.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor(42)
print(t.item())
```
"""
assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)]
return self._data().cast(self.dtype.base.fmt, self.shape)
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]:
@ -464,13 +442,6 @@ class Tensor(RandMixin):
"""
return Tensor(UOp.empty(argfix(*shape), dtype, device))
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return Tensor(self.uop.empty_like(dtype, device))
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
"""
@ -533,261 +504,6 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high)
@staticmethod
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
# ***** creation helper functions *****
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
return Tensor(stacked.multi(self.uop.axis))
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
# ***** random functions *****
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)
@staticmethod
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype)
@staticmethod
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * Tensor.randn(*shape, **kwargs) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# EfraimidisSpirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor:
@ -814,49 +530,9 @@ class Tensor(RandMixin):
# ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def __getitem__(self, indices) -> Tensor:
"""
Retrieves a sub-tensor using indexing.
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
Examples:
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(12).reshape(3, 4)
print(t.numpy())
```
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
```python exec="true" source="above" session="tensor" result="python"
print(t[1, 2].numpy())
```
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
```python exec="true" source="above" session="tensor" result="python"
print(t[0:2, ::2].numpy())
```
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
```python exec="true" source="above" session="tensor" result="python"
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
```
- `None` Indexing: Add a new dimension to the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(t[:, None].shape)
```
NOTE: Out-of-bounds indexing results in a value of `0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t[Tensor([4, 3, 2])].numpy())
```
"""
return super().__getitem__(indices)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
@ -887,68 +563,6 @@ class Tensor(RandMixin):
def __delitem__(self, indices) -> None:
raise TypeError("Tensor does not support deleting items")
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
"""
Selects elements from `self` based on the boolean `mask`.
With `size=None` (default), output length equals the number of `True` values (not jittable).
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
print(t.numpy())
print(mask.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
```
"""
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
mask_cumsum = mask.cumsum()
if size is None:
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
"""
Returns the indices of the elements that are non-zero.
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 0, 2, 0, 3])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 0], [0, 2]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.nonzero(size=3, fill_value=-1).numpy())
```
"""
if self.ndim == 0:
return Tensor.zeros(size if size is not None else int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
mask = (self != 0).flatten()
indices = Tensor.stack(*[Tensor.arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
for i, s in enumerate(self.shape)], dim=-1)
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
# ***** reduce ops *****
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
@ -1054,11 +668,11 @@ class Tensor(RandMixin):
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
# compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
@ -1105,14 +719,6 @@ class Tensor(RandMixin):
if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype)
# ***** unary ops *****
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
# ***** broadcasted elementwise ops *****
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
@ -1172,80 +778,8 @@ class Tensor(RandMixin):
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
# ***** functional nn ops *****
def dropout(self, p=0.5) -> Tensor:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `Tensor.training` is `True`.
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
print(t.dropout().numpy())
```
"""
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
if not Tensor.training or p == 0: return self
if p == 1: return self.const_like(0)
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
q = self
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
# ***** cast ops *****
def cast(self, dtype:DTypeLike) -> Tensor:
"""
Casts `self` to the given `dtype`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else self._apply_uop(UOp.cast, dtype=dt)
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
Bitcasts `self` to the given `dtype` of the same itemsize.

View file

@ -454,7 +454,8 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
# these are rewrites that make things simpler
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
@ -463,6 +464,11 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
return PatternMatcher(pat)
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)

View file

@ -560,12 +560,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@staticmethod
@ -797,7 +791,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
return self.src[0].addrspace
if self.op in GroupOp.Movement: return self.src[0].addrspace
if self.op in {Ops.STACK, Ops.PTRCAT, Ops.WMMA} or self.op in GroupOp.Elementwise:
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise:
ad = [x.addrspace for x in self.src if x.addrspace is not None]
if not len(ad) or not all_same(ad): return None
return ad[0]
@ -925,12 +919,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int:
"""largest known int that divides self"""
# TODO: for negatives it's not the largest
@ -1064,9 +1052,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret
def placeholder_like(self, slot:int):
def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.max_shard_shape, self.dtype, slot)
return UOp.placeholder(self.max_shard_shape, self.dtype, slot, addrspace)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
@ -1196,8 +1184,9 @@ python_alu: dict[Ops, Callable] = {
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
if any(isinstance(x, tuple) for x in operands):
count = max(len(x) for x in operands if isinstance(x, tuple))
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count)])
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
@ -1678,6 +1667,8 @@ pm_lower_index_dtype = PatternMatcher([
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("slen", dtypes.ints).cast(),), name="shrink"),
lambda shrink,buf,idx,slen: shrink.replace(src=(buf,idx,slen))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for images

View file

@ -1,5 +1,5 @@
from typing import cast
from tinygrad.dtype import dtypes, Invalid
from tinygrad.dtype import dtypes
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
from tinygrad.helpers import strip_parens
@ -77,8 +77,6 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),

View file

@ -126,9 +126,6 @@ spec_tensor = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
# CONST with a UNIQUE and DEVICE
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
@ -152,8 +149,7 @@ spec_tensor = PatternMatcher([
# movement ops
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"),
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"), lambda x: x.src[1].shape == x.src[2].shape),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering

View file

@ -121,6 +121,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# TODO: combine this with "# rules for threefry" below
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding **
@ -160,6 +162,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
# ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals
@ -167,6 +170,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# STACK on INDEX CONST (TODO: remove all the GEP crap)
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********

View file

@ -121,7 +121,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)