mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
clean_load
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a214c4499 |
||
|
|
3e16109eb6 | ||
|
|
f79a7fc7c6 |
||
|
|
3526f8272b | ||
|
|
e143904deb |
65 changed files with 985 additions and 1176 deletions
89
.github/workflows/test.yml
vendored
89
.github/workflows/test.yml
vendored
|
|
@ -133,26 +133,46 @@ jobs:
|
||||||
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
|
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
|
||||||
- name: Test IMAGE support
|
- name: Test IMAGE support
|
||||||
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
|
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
|
||||||
- name: Test emulated tensor cores
|
- name: Test emulated METAL tensor cores
|
||||||
env:
|
env:
|
||||||
DEBUG: 2
|
DEV: 'PYTHON::METAL'
|
||||||
N: 64
|
|
||||||
CNT: 1
|
|
||||||
SHOULD_USE_TC: 1
|
|
||||||
run: |
|
run: |
|
||||||
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
|
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
|
||||||
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
|
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||||
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
|
- name: Test emulated AMD tensor cores
|
||||||
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
env:
|
||||||
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
DEV: 'PYTHON::gfx1100'
|
||||||
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
|
|
||||||
- name: Run additional tensor core tests
|
|
||||||
run: |
|
run: |
|
||||||
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||||
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||||
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||||
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||||
|
- name: Test emulated AMD MFMA tensor cores
|
||||||
|
env:
|
||||||
|
DEV: 'PYTHON::gfx950'
|
||||||
|
run: |
|
||||||
|
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||||
|
- name: Test emulated AMD RDNA4 tensor cores
|
||||||
|
env:
|
||||||
|
DEV: 'PYTHON::gfx1201'
|
||||||
|
run: |
|
||||||
|
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||||
|
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||||
|
- name: Test emulated CUDA tensor cores
|
||||||
|
run: |
|
||||||
|
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||||
|
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
|
||||||
|
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||||
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||||
|
- name: Test device flop counts
|
||||||
|
run: |
|
||||||
|
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||||
|
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||||
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||||
|
|
||||||
linter:
|
linter:
|
||||||
|
|
@ -247,6 +267,13 @@ jobs:
|
||||||
run: python3 test/external/external_benchmark_schedule.py
|
run: python3 test/external/external_benchmark_schedule.py
|
||||||
- name: Run process replay tests
|
- name: Run process replay tests
|
||||||
uses: ./.github/actions/process-replay
|
uses: ./.github/actions/process-replay
|
||||||
|
- name: Regen dataset on test_tiny
|
||||||
|
run: |
|
||||||
|
test/external/process_replay/reset.py
|
||||||
|
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
|
||||||
|
python extra/optimization/extract_dataset.py
|
||||||
|
gzip -c /tmp/sops > extra/datasets/sops.gz
|
||||||
|
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
|
||||||
- name: Repo line count < 25000 lines
|
- name: Repo line count < 25000 lines
|
||||||
run: MAX_LINE_COUNT=25000 python sz.py
|
run: MAX_LINE_COUNT=25000 python sz.py
|
||||||
|
|
||||||
|
|
@ -311,6 +338,31 @@ jobs:
|
||||||
- name: Run process replay tests
|
- name: Run process replay tests
|
||||||
uses: ./.github/actions/process-replay
|
uses: ./.github/actions/process-replay
|
||||||
|
|
||||||
|
testgpumisc:
|
||||||
|
name: CL Misc tests
|
||||||
|
runs-on: *linux
|
||||||
|
timeout-minutes: 10
|
||||||
|
steps:
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
- name: Setup Environment
|
||||||
|
uses: ./.github/actions/setup-tinygrad
|
||||||
|
with:
|
||||||
|
key: gen-dataset
|
||||||
|
deps: testing
|
||||||
|
opencl: 'true'
|
||||||
|
- name: Generate Dataset
|
||||||
|
run: DEV=CL extra/optimization/generate_dataset.sh
|
||||||
|
- name: Run Kernel Count Test
|
||||||
|
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
|
||||||
|
- name: Run fused optimizer tests
|
||||||
|
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
|
||||||
|
- name: Upload artifact
|
||||||
|
uses: actions/upload-artifact@v7
|
||||||
|
with:
|
||||||
|
name: sops.gz
|
||||||
|
path: /tmp/sops.gz
|
||||||
|
|
||||||
testopenpilot:
|
testopenpilot:
|
||||||
name: openpilot Compile Tests
|
name: openpilot Compile Tests
|
||||||
runs-on: *linux
|
runs-on: *linux
|
||||||
|
|
@ -327,7 +379,7 @@ jobs:
|
||||||
llvm: 'true'
|
llvm: 'true'
|
||||||
- name: Test openpilot model kernel count and gate usage
|
- name: Test openpilot model kernel count and gate usage
|
||||||
run: |
|
run: |
|
||||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||||
- name: Test openpilot CL compile fp32 (test correctness)
|
- name: Test openpilot CL compile fp32 (test correctness)
|
||||||
run: |
|
run: |
|
||||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||||
|
|
@ -370,6 +422,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
key: optim
|
key: optim
|
||||||
deps: testing
|
deps: testing
|
||||||
|
pydeps: "tensorflow==2.19"
|
||||||
opencl: 'true'
|
opencl: 'true'
|
||||||
#- name: Test Optimization Helpers
|
#- name: Test Optimization Helpers
|
||||||
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||||
|
|
@ -378,7 +431,7 @@ jobs:
|
||||||
- name: Test Beam Search
|
- name: Test Beam Search
|
||||||
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
||||||
- name: Test MLPerf stuff
|
- name: Test MLPerf stuff
|
||||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||||
- name: DEV=NULL beautiful_mnist_multigpu
|
- name: DEV=NULL beautiful_mnist_multigpu
|
||||||
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
||||||
- name: Test Bert training
|
- name: Test Bert training
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
|
||||||
Throw in an optimizer, a data loader, and some compute, and you have all you need.
|
Throw in an optimizer, a data loader, and some compute, and you have all you need.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from tinygrad import Tensor, nn, Context
|
from tinygrad import Tensor, nn
|
||||||
|
|
||||||
class LinearNet:
|
class LinearNet:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
|
||||||
|
|
||||||
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
|
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
loss = model(x).sparse_categorical_crossentropy(y).backward()
|
loss = model(x).sparse_categorical_crossentropy(y).backward()
|
||||||
|
|
|
||||||
|
|
@ -165,14 +165,13 @@ from extra.datasets import fetch_mnist
|
||||||
Now we have everything we need to start training our neural network.
|
Now we have everything we need to start training our neural network.
|
||||||
We will be training for 1000 steps with a batch size of 64.
|
We will be training for 1000 steps with a batch size of 64.
|
||||||
|
|
||||||
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
|
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
|
||||||
Upon exit, the flag is restored to its previous value by the context manager.
|
Upon exit, the flag is restored to its previous value by the context manager.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from tinygrad import Context
|
|
||||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
for step in range(1000):
|
for step in range(1000):
|
||||||
# random sample a batch
|
# random sample a batch
|
||||||
samp = np.random.randint(0, X_train.shape[0], size=(64))
|
samp = np.random.randint(0, X_train.shape[0], size=(64))
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import time
|
import time
|
||||||
from tinygrad import Tensor, TinyJit, nn, Context
|
from tinygrad import Tensor, TinyJit, nn
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from tinygrad.helpers import trange
|
from tinygrad.helpers import trange
|
||||||
import numpy as np # TODO: remove numpy import
|
import numpy as np # TODO: remove numpy import
|
||||||
|
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
log_dist, value = model(x)
|
log_dist, value = model(x)
|
||||||
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()
|
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ if __name__ == "__main__":
|
||||||
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def train_step(idxs:Tensor) -> Tensor:
|
def train_step(idxs:Tensor) -> Tensor:
|
||||||
X, Y = X_train[idxs], Y_train[idxs]
|
X, Y = X_train[idxs], Y_train[idxs]
|
||||||
if len(GPUS) > 1:
|
if len(GPUS) > 1:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
|
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
|
||||||
from tinygrad.helpers import getenv, colored, trange
|
from tinygrad.helpers import getenv, colored, trange
|
||||||
from tinygrad.nn.datasets import mnist
|
from tinygrad.nn.datasets import mnist
|
||||||
|
|
||||||
|
|
@ -19,7 +19,7 @@ class Model:
|
||||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
|
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||||
from typing import List, Callable
|
from typing import List, Callable
|
||||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
|
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||||
from tinygrad.helpers import getenv, colored, trange
|
from tinygrad.helpers import getenv, colored, trange
|
||||||
from tinygrad.nn.datasets import mnist
|
from tinygrad.nn.datasets import mnist
|
||||||
|
|
||||||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def train_step() -> Tensor:
|
def train_step() -> Tensor:
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||||
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
|
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
|
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
|
||||||
from tinygrad.helpers import getenv, trange, partition
|
from tinygrad.helpers import getenv, trange, partition
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
|
|
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
||||||
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
|
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def microbatch():
|
def microbatch():
|
||||||
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
|
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
|
||||||
for t in params: t.grad = None
|
for t in params: t.grad = None
|
||||||
|
|
|
||||||
|
|
@ -359,7 +359,7 @@ def train_cifar():
|
||||||
i = 0
|
i = 0
|
||||||
eval_acc_pct = 0.0
|
eval_acc_pct = 0.0
|
||||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
st = time.monotonic()
|
st = time.monotonic()
|
||||||
while i <= STEPS:
|
while i <= STEPS:
|
||||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
|
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os, math, time
|
import os, math, time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
|
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -177,7 +177,7 @@ if __name__ == "__main__":
|
||||||
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
|
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def step(x:Tensor, y:Tensor) -> Tensor:
|
def step(x:Tensor, y:Tensor) -> Tensor:
|
||||||
_, loss = model(x, y)
|
_, loss = model(x, y)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
@ -204,3 +204,4 @@ if __name__ == "__main__":
|
||||||
top_k = 40
|
top_k = 40
|
||||||
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
||||||
print(decode(y[0].tolist()))
|
print(decode(y[0].tolist()))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# much taken from https://github.com/cloneofsimo/minRF
|
# much taken from https://github.com/cloneofsimo/minRF
|
||||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
|
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
||||||
from tinygrad.helpers import getenv, trange
|
from tinygrad.helpers import getenv, trange
|
||||||
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
||||||
|
|
||||||
|
|
@ -135,7 +135,7 @@ if __name__ == "__main__":
|
||||||
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
|
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def train_step():
|
def train_step():
|
||||||
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
|
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
|
||||||
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])
|
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])
|
||||||
|
|
|
||||||
|
|
@ -358,7 +358,7 @@ def eval_stable_diffusion():
|
||||||
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
|
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
|
||||||
return batch, unpadded_bs
|
return batch, unpadded_bs
|
||||||
|
|
||||||
@Context(TRAINING=0)
|
@Tensor.train(mode=False)
|
||||||
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
|
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
|
||||||
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
|
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
|
||||||
# Eval is divided into 5 jits, one per model
|
# Eval is divided into 5 jits, one per model
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
|
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
|
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
|
||||||
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
||||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
||||||
|
|
@ -614,7 +614,7 @@ def train_retinanet():
|
||||||
|
|
||||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||||
|
|
||||||
with Context(TRAINING=0):
|
with Tensor.train(mode=False):
|
||||||
if not RUNMLPERF:
|
if not RUNMLPERF:
|
||||||
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
|
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
|
||||||
else:
|
else:
|
||||||
|
|
@ -784,7 +784,7 @@ def train_unet3d():
|
||||||
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
|
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def train_step(model, x, y):
|
def train_step(model, x, y):
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
|
|
||||||
|
|
@ -795,7 +795,7 @@ def train_unet3d():
|
||||||
optim.step()
|
optim.step()
|
||||||
return loss.realize()
|
return loss.realize()
|
||||||
|
|
||||||
@Context(TRAINING=0)
|
@Tensor.train(mode=False)
|
||||||
def eval_step(model, x, y):
|
def eval_step(model, x, y):
|
||||||
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
|
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
|
||||||
y_hat, y = Tensor(y_hat), Tensor(y)
|
y_hat, y = Tensor(y_hat), Tensor(y)
|
||||||
|
|
@ -1490,7 +1490,7 @@ def train_llama3():
|
||||||
return lr_cpu, grad_norm_cpu
|
return lr_cpu, grad_norm_cpu
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
@Context(TRAINING=0)
|
@Tensor.train(False)
|
||||||
def eval_step(tokens:Tensor):
|
def eval_step(tokens:Tensor):
|
||||||
if is_dp: tokens = tokens.to(None).shard(device, 0)
|
if is_dp: tokens = tokens.to(None).shard(device, 0)
|
||||||
if is_mp: tokens = tokens.shard(device)
|
if is_mp: tokens = tokens.shard(device)
|
||||||
|
|
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
|
||||||
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
|
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
|
||||||
else: bench_log_manager = contextlib.nullcontext()
|
else: bench_log_manager = contextlib.nullcontext()
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
|
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
|
||||||
nm = f"train_{m}"
|
nm = f"train_{m}"
|
||||||
if nm in globals():
|
if nm in globals():
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
||||||
|
|
||||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||||
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
|
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||||
if not fp8:
|
if not fp8:
|
||||||
if ASM_GEMM:
|
if ASM_GEMM:
|
||||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||||
|
|
@ -47,14 +47,12 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
||||||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||||
if MXFP8:
|
if MXFP8:
|
||||||
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
||||||
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
|
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||||
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
|
||||||
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
|
|
||||||
if can_use_asm_gemm(x_q, w.T):
|
if can_use_asm_gemm(x_q, w.T):
|
||||||
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
||||||
mx_w_stored=True).reshape(*l_shape, w.shape[0])
|
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
|
||||||
else:
|
else:
|
||||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
|
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
|
||||||
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
||||||
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
||||||
if x_fp8 is None:
|
if x_fp8 is None:
|
||||||
|
|
@ -128,8 +126,10 @@ class FlatTransformer:
|
||||||
|
|
||||||
# FeedForward
|
# FeedForward
|
||||||
if SPLIT_W13:
|
if SPLIT_W13:
|
||||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
|
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
|
||||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
|
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
|
||||||
|
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
|
||||||
|
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
|
||||||
else:
|
else:
|
||||||
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
|
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
|
||||||
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
|
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
|
||||||
|
|
@ -160,7 +160,7 @@ class FlatTransformer:
|
||||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
||||||
if w is None:
|
if w is None:
|
||||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
|
||||||
if MXFP8:
|
if MXFP8:
|
||||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||||
|
|
@ -216,15 +216,8 @@ class FlatTransformer:
|
||||||
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
||||||
amaxs.append(new_amax)
|
amaxs.append(new_amax)
|
||||||
saves.extend([*s, x_w3])
|
saves.extend([*s, x_w3])
|
||||||
if FUSED_SILU_W13 and MXFP8:
|
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||||
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
|
grad_amax_state=kwargs["grad_amax_xout"])
|
||||||
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
|
|
||||||
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
|
|
||||||
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
|
|
||||||
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
|
|
||||||
else:
|
|
||||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
|
||||||
grad_amax_state=kwargs["grad_amax_xout"])
|
|
||||||
amaxs.append(new_amax)
|
amaxs.append(new_amax)
|
||||||
saves.extend([*s, out])
|
saves.extend([*s, out])
|
||||||
else:
|
else:
|
||||||
|
|
@ -254,30 +247,20 @@ class FlatTransformer:
|
||||||
for v in get_parameters(self): v.shard_(device, axis=None)
|
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||||
else:
|
else:
|
||||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||||
def _shard_fp8(name:str, axis:int, std:float=0.02):
|
def _shard_fp8(name:str, axis:int):
|
||||||
w = getattr(self, name)
|
getattr(self, name).shard_(device, axis=axis)
|
||||||
if MXFP8:
|
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||||
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
|
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||||
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
|
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||||
w.replace(w_q)
|
|
||||||
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
|
||||||
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
|
||||||
else:
|
|
||||||
w.shard_(device, axis=axis)
|
|
||||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
|
||||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
|
||||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
|
||||||
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
|
||||||
sstd = 0.02 / math.sqrt(2 * self.n_layers)
|
|
||||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||||
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
|
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||||
if SPLIT_W13:
|
if SPLIT_W13:
|
||||||
_shard_fp8("w1", 1)
|
_shard_fp8("w1", 1)
|
||||||
_shard_fp8("w3", 1)
|
_shard_fp8("w3", 1)
|
||||||
else:
|
else:
|
||||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||||
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
|
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||||
self.attention_norm.shard_(device, axis=None).realize()
|
self.attention_norm.shard_(device, axis=None).realize()
|
||||||
self.ffn_norm.shard_(device, axis=None).realize()
|
self.ffn_norm.shard_(device, axis=None).realize()
|
||||||
self.norm.weight.shard_(device, axis=None).realize()
|
self.norm.weight.shard_(device, axis=None).realize()
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torchvision.utils import make_grid, save_image
|
from torchvision.utils import make_grid, save_image
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import trange, Context
|
from tinygrad.helpers import trange
|
||||||
from tinygrad.nn import optim
|
from tinygrad.nn import optim
|
||||||
from tinygrad.nn.datasets import mnist
|
from tinygrad.nn.datasets import mnist
|
||||||
|
|
||||||
|
|
@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||||
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||||
# training loop
|
# training loop
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
for epoch in (t := trange(epochs)):
|
for epoch in (t := trange(epochs)):
|
||||||
loss_g, loss_d = 0.0, 0.0
|
loss_g, loss_d = 0.0, 0.0
|
||||||
for _ in range(n_steps):
|
for _ in range(n_steps):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
# - symbolic removal
|
# - symbolic removal
|
||||||
|
|
||||||
from examples.beautiful_mnist import Model
|
from examples.beautiful_mnist import Model
|
||||||
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
|
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
|
||||||
from tinygrad.nn.datasets import mnist
|
from tinygrad.nn.datasets import mnist
|
||||||
from tinygrad.helpers import trange
|
from tinygrad.helpers import trange
|
||||||
|
|
||||||
|
|
@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||||
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
||||||
print("*** got samples")
|
print("*** got samples")
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
"""
|
"""
|
||||||
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
|
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
|
||||||
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)
|
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||||
|
|
||||||
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
||||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||||
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
|
acc = acc.after(acc.store(acc.zeros_like()))
|
||||||
|
|
||||||
if use_wmma:
|
if use_wmma:
|
||||||
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
||||||
|
|
|
||||||
|
|
@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
||||||
|
|
||||||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||||
# 1x32 block scaling along the last axis
|
# 1x32 block scaling along the last axis
|
||||||
*batch, K = x.shape
|
rows, K = x.shape
|
||||||
scale_K = K // 32
|
scale_K, k_iters = K // 32, K // 128
|
||||||
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
|
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
|
||||||
x_scaled = x.float() * qscale
|
x_scaled = x.float() * qscale
|
||||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
||||||
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
|
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
|
||||||
|
|
||||||
def mx_pack(e8:Tensor) -> Tensor:
|
def mx_pack(e8:Tensor) -> Tensor:
|
||||||
rows, scale_K = e8.shape
|
rows, scale_K = e8.shape
|
||||||
|
|
|
||||||
|
|
@ -143,17 +143,14 @@ def make_getaddr(u, device=None):
|
||||||
def make_ins(op, *srcs):
|
def make_ins(op, *srcs):
|
||||||
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
||||||
|
|
||||||
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
|
|
||||||
dt = dtype or val.dtype
|
|
||||||
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
|
|
||||||
|
|
||||||
def make_cmdbuf(lin, devs, tag):
|
def make_cmdbuf(lin, devs, tag):
|
||||||
blob, patches = b'', []
|
blob, patches = b'', []
|
||||||
for s in (s for ins in lin.src for s in ins.src):
|
for s in (s for ins in lin.src for s in ins.src):
|
||||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
|
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||||
|
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||||
|
|
||||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||||
|
|
||||||
|
|
@ -214,11 +211,15 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
||||||
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
||||||
|
|
||||||
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
||||||
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
|
data, info = prg.arg
|
||||||
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
|
||||||
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
|
dtypes.uint64) for i,gi in enumerate(info.globals)] \
|
||||||
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||||
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
|
|
||||||
|
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||||
|
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
|
||||||
|
|
||||||
|
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
|
||||||
|
|
||||||
pm_prep_runtime = PatternMatcher([
|
pm_prep_runtime = PatternMatcher([
|
||||||
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
||||||
|
|
@ -306,7 +307,7 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||||
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
||||||
|
|
||||||
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
||||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
|
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||||
return linear.replace(src=tuple(new_src))
|
return linear.replace(src=tuple(new_src))
|
||||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||||
|
|
||||||
|
|
@ -531,9 +532,9 @@ pm_resolve_patches = PatternMatcher([
|
||||||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||||
|
|
||||||
# shrink on slice is shrink on base at offset
|
# index on slice is index
|
||||||
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
|
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
|
||||||
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
|
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
|
||||||
|
|
||||||
# getaddr
|
# getaddr
|
||||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||||
|
|
@ -541,8 +542,8 @@ pm_resolve_patches = PatternMatcher([
|
||||||
|
|
||||||
# folders
|
# folders
|
||||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||||
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
|
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||||
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
|
fold_const_store),
|
||||||
]) + symbolic_simple
|
]) + symbolic_simple
|
||||||
|
|
||||||
# *****************
|
# *****************
|
||||||
|
|
|
||||||
|
|
@ -1,104 +0,0 @@
|
||||||
import functools
|
|
||||||
from tinygrad import Tensor, dtypes
|
|
||||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
|
||||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
|
||||||
|
|
||||||
BLK = 32
|
|
||||||
PACK = 4
|
|
||||||
LOG2E = 1.4426950408889634
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
n_elems = rows * K
|
|
||||||
n_super = n_elems // (BLK * PACK)
|
|
||||||
sk4 = scale_K // PACK
|
|
||||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
|
||||||
nwg = n_super // THREADS_PER_WG
|
|
||||||
|
|
||||||
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
|
|
||||||
fp8_out = fp8_out.reshape(n_elems)
|
|
||||||
e8_out = e8_out.reshape(rows * scale_K)
|
|
||||||
si_out = si_out.reshape(sk4 * rows)
|
|
||||||
|
|
||||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
|
||||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
|
||||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
|
||||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
|
||||||
|
|
||||||
super_idx = wg * THREADS_PER_WG + tid
|
|
||||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
|
||||||
|
|
||||||
w1 = x_w1[idx].cast(dtypes.float)
|
|
||||||
w3 = x_w3[idx].cast(dtypes.float)
|
|
||||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
|
||||||
act = w1 * sig * w3
|
|
||||||
abs_a = (act < 0.0).where(-act, act)
|
|
||||||
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
|
|
||||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
|
||||||
qscale = (127.0 - e8f).exp2()
|
|
||||||
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
|
||||||
e8u8 = e8f.cast(dtypes.uint8)
|
|
||||||
|
|
||||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
|
||||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
|
||||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
|
||||||
row, col4 = super_idx // sk4, super_idx % sk4
|
|
||||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
|
||||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
n_elems = rows * K
|
|
||||||
VEC = 8
|
|
||||||
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
|
|
||||||
nwg = n_elems // (THREADS_PER_WG * VEC)
|
|
||||||
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
|
|
||||||
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
|
|
||||||
|
|
||||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
|
||||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
|
||||||
lane = UOp.range(VEC, 2, AxisType.UNROLL)
|
|
||||||
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
|
|
||||||
|
|
||||||
e8v = e8[idx // BLK].cast(dtypes.float)
|
|
||||||
qscale = (127.0 - e8v).exp2()
|
|
||||||
ga = grad_aq[idx].cast(dtypes.float) * qscale
|
|
||||||
w1 = x_w1[idx].cast(dtypes.float)
|
|
||||||
w3 = x_w3[idx].cast(dtypes.float)
|
|
||||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
|
||||||
s = w1 * sig
|
|
||||||
sprime = sig * (1.0 + w1 * (1.0 - sig))
|
|
||||||
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
|
|
||||||
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
|
|
||||||
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
|
|
||||||
|
|
||||||
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
|
|
||||||
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
|
|
||||||
device = x_w1.device
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
axis = x_w1.axis if isinstance(device, tuple) else None
|
|
||||||
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
|
||||||
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
|
||||||
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
|
|
||||||
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
|
|
||||||
fxn=_custom_silu_mul_bwd_mxfp8)
|
|
||||||
return (None, None, None, gx1.uop, gx3.uop)
|
|
||||||
|
|
||||||
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
|
||||||
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
|
|
||||||
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
|
|
||||||
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
|
|
||||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
|
|
||||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
|
|
||||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
|
|
||||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
|
|
||||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
|
|
||||||
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
|
|
||||||
return fp8_out, e8_out, si_out
|
|
||||||
|
|
@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
|
||||||
step = THREADS_PER_WG // 2
|
step = THREADS_PER_WG // 2
|
||||||
while step:
|
while step:
|
||||||
active = tid < step
|
active = tid < step
|
||||||
other = lds[(tid + step).valid(active)].load()
|
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||||
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
|
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||||
step //= 2
|
step //= 2
|
||||||
|
|
||||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import trange, Context
|
from tinygrad.helpers import trange
|
||||||
from tinygrad.engine.jit import TinyJit
|
from tinygrad.engine.jit import TinyJit
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -22,7 +22,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
|
||||||
|
|
||||||
if allow_jit: train_step = TinyJit(train_step)
|
if allow_jit: train_step = TinyJit(train_step)
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
losses, accuracies = [], []
|
losses, accuracies = [], []
|
||||||
for i in (t := trange(steps, disable=None)):
|
for i in (t := trange(steps, disable=None)):
|
||||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||||
|
|
@ -55,3 +55,4 @@ def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=Fal
|
||||||
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
|
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
|
||||||
print("test set accuracy is %f" % acc)
|
print("test set accuracy is %f" % acc)
|
||||||
return (acc, Y_test_pred) if return_predict else acc
|
return (acc, Y_test_pred) if return_predict else acc
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tinygrad import Tensor, dtypes, nn, Context
|
from tinygrad import Tensor, dtypes, nn
|
||||||
from tinygrad.device import Device
|
from tinygrad.device import Device
|
||||||
from tinygrad.helpers import DEV
|
from tinygrad.helpers import DEV
|
||||||
from tinygrad.renderer.nir import NIRRenderer
|
from tinygrad.renderer.nir import NIRRenderer
|
||||||
|
|
@ -101,7 +101,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
|
||||||
# we don't need more of these
|
# we don't need more of these
|
||||||
|
|
||||||
def test_dropout_rate_one(self):
|
def test_dropout_rate_one(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
out = Tensor.ones(100).dropout(1.0)
|
out = Tensor.ones(100).dropout(1.0)
|
||||||
np.testing.assert_allclose(out.numpy(), np.zeros(100))
|
np.testing.assert_allclose(out.numpy(), np.zeros(100))
|
||||||
|
|
||||||
|
|
@ -109,7 +109,7 @@ class TestDropoutProbabilityEdgeCases(unittest.TestCase):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
|
torch.nn.functional.dropout(torch.ones(10), -0.1, True)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
Tensor.ones(10).dropout(-0.1)
|
Tensor.ones(10).dropout(-0.1)
|
||||||
|
|
||||||
class TestInputValidation(unittest.TestCase):
|
class TestInputValidation(unittest.TestCase):
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||||
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
||||||
assert num_loads <= 4, "more load uops than needed"
|
assert num_loads <= 4, "more load uops than needed"
|
||||||
assert num_loads >= 1, "expected at least one load uop"
|
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||||
|
|
||||||
@unittest.skip("this is handled at higher level now")
|
@unittest.skip("this is handled at higher level now")
|
||||||
def test_upcast_cse(self):
|
def test_upcast_cse(self):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from test.helpers import not_support_multi_device, needs_second_gpu, slow
|
||||||
@slow
|
@slow
|
||||||
class TestNN(unittest.TestCase):
|
class TestNN(unittest.TestCase):
|
||||||
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
|
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
|
||||||
with Context(TRAINING=training):
|
with Tensor.train(training):
|
||||||
szs = [4, 8, 16, 32]
|
szs = [4, 8, 16, 32]
|
||||||
for sz in szs:
|
for sz in szs:
|
||||||
# create in tinygrad
|
# create in tinygrad
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class TestStunning(unittest.TestCase):
|
||||||
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
||||||
vi = Variable('i', 0, samples.shape[0]-1)
|
vi = Variable('i', 0, samples.shape[0]-1)
|
||||||
with Context(SPLIT_REDUCEOP=0):
|
with Context(SPLIT_REDUCEOP=0):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
losses = []
|
losses = []
|
||||||
for i in range(samples.shape[0]):
|
for i in range(samples.shape[0]):
|
||||||
vib = vi.bind(i)
|
vib = vi.bind(i)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, Variable, GlobalCounters, Context
|
from tinygrad import Tensor, Variable, GlobalCounters
|
||||||
from tinygrad.uop.ops import sym_infer
|
from tinygrad.uop.ops import sym_infer
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from examples.gpt2 import Attention
|
from examples.gpt2 import Attention
|
||||||
|
|
@ -63,7 +63,7 @@ class TestSymbolicOps(unittest.TestCase):
|
||||||
self.test_attention(imin=4, imax=5, use_symbolic=True)
|
self.test_attention(imin=4, imax=5, use_symbolic=True)
|
||||||
|
|
||||||
def test_attention_training(self):
|
def test_attention_training(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
self.test_attention(dropout_p=0.0)
|
self.test_attention(dropout_p=0.0)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
# symbolic shape dropout is not supported
|
# symbolic shape dropout is not supported
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import unittest, copy, mmap, random, math, array
|
import unittest, copy, mmap, random, math, array
|
||||||
from tinygrad import Tensor, Device, dtypes, nn, Context
|
from tinygrad import Tensor, Device, dtypes, nn
|
||||||
from tinygrad.helpers import getenv, temp, mv_address
|
from tinygrad.helpers import getenv, temp, mv_address
|
||||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||||
from hypothesis import given, settings, strategies as strat
|
from hypothesis import given, settings, strategies as strat
|
||||||
|
|
@ -203,7 +203,7 @@ class TestTinygrad(unittest.TestCase):
|
||||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||||
|
|
||||||
def test_dropout(self):
|
def test_dropout(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
n, rate = 1_000_000, 0.1
|
n, rate = 1_000_000, 0.1
|
||||||
w = Tensor.ones(n).dropout(rate)
|
w = Tensor.ones(n).dropout(rate)
|
||||||
non_zeros = np.count_nonzero(w.numpy())
|
non_zeros = np.count_nonzero(w.numpy())
|
||||||
|
|
|
||||||
67
test/external/external_test_lr_schedule.py
vendored
67
test/external/external_test_lr_schedule.py
vendored
|
|
@ -1,67 +0,0 @@
|
||||||
import unittest, math
|
|
||||||
import numpy as np
|
|
||||||
from tinygrad.tensor import Tensor
|
|
||||||
from tinygrad.nn.optim import AdamW
|
|
||||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
|
|
||||||
|
|
||||||
np.random.seed(1337)
|
|
||||||
x_init = np.random.randn(1,4).astype(np.float32)
|
|
||||||
W_init = np.random.randn(4,4).astype(np.float32)
|
|
||||||
m_init = np.random.randn(1,4).astype(np.float32)
|
|
||||||
|
|
||||||
class TinyNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.x = Tensor(x_init.copy())
|
|
||||||
self.W = Tensor(W_init.copy())
|
|
||||||
self.m = Tensor(m_init.copy())
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
out = self.x.matmul(self.W).relu()
|
|
||||||
out = out.log_softmax(1)
|
|
||||||
out = out.mul(self.m).add(self.m).sum()
|
|
||||||
return out
|
|
||||||
|
|
||||||
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
|
|
||||||
# only tests the lr
|
|
||||||
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
|
|
||||||
net = TinyNet()
|
|
||||||
optim = AdamW([net.W], lr=0.0)
|
|
||||||
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
|
|
||||||
lr = []
|
|
||||||
for _ in range(warmup_steps+decay_steps):
|
|
||||||
lr.append(optim.lr.item())
|
|
||||||
tiny_lr.step()
|
|
||||||
# reimplemented in python
|
|
||||||
expected = []
|
|
||||||
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
|
|
||||||
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
|
|
||||||
np.testing.assert_allclose(lr, expected, rtol=1e-5)
|
|
||||||
|
|
||||||
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
|
|
||||||
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
|
|
||||||
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
|
|
||||||
|
|
||||||
class TestLambdaLRLinearWarmup(unittest.TestCase):
|
|
||||||
def test_linear_lr_warmup(self):
|
|
||||||
BS, BASE_LR = 304, 2.5e-7
|
|
||||||
lr = BS * BASE_LR
|
|
||||||
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
|
|
||||||
optimizer = AdamW([Tensor([1.])])
|
|
||||||
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
|
|
||||||
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
|
|
||||||
lrs = {}
|
|
||||||
|
|
||||||
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
|
|
||||||
for i in range(1200):
|
|
||||||
lr_scheduler.step()
|
|
||||||
if i in {0, 499, 998, 999, 1000, 1199}:
|
|
||||||
lrs[i] = optimizer.lr.item()
|
|
||||||
|
|
||||||
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
|
|
||||||
np.testing.assert_equal(lrs[999], lrs[1000])
|
|
||||||
np.testing.assert_equal(lrs[999], lrs[1199])
|
|
||||||
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
|
|
||||||
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
49
test/external/external_test_optim.py
vendored
49
test/external/external_test_optim.py
vendored
|
|
@ -1,5 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import unittest
|
import unittest, math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.optimizers import Lamb
|
from tensorflow.keras.optimizers import Lamb
|
||||||
|
|
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
|
||||||
from extra.lr_scheduler import LRSchedulerGroup
|
from extra.lr_scheduler import LRSchedulerGroup
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
|
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
|
||||||
|
|
||||||
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
|
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
|
||||||
|
|
||||||
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
|
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
|
||||||
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
|
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
|
||||||
|
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
|
|
@ -173,5 +173,48 @@ class ExternalTestOptim(unittest.TestCase):
|
||||||
'warmup': steps_per_epoch * warmup_epochs,
|
'warmup': steps_per_epoch * warmup_epochs,
|
||||||
}, 1e-5, 1e-5, do_optim=False)
|
}, 1e-5, 1e-5, do_optim=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
|
||||||
|
# only tests the lr
|
||||||
|
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
|
||||||
|
net = TinyNet()
|
||||||
|
optim = AdamW([net.W], lr=0.0)
|
||||||
|
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
|
||||||
|
lr = []
|
||||||
|
for _ in range(warmup_steps+decay_steps):
|
||||||
|
lr.append(optim.lr.item())
|
||||||
|
tiny_lr.step()
|
||||||
|
# reimplemented in python
|
||||||
|
expected = []
|
||||||
|
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
|
||||||
|
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
|
||||||
|
np.testing.assert_allclose(lr, expected, rtol=1e-5)
|
||||||
|
|
||||||
|
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
|
||||||
|
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
|
||||||
|
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
|
||||||
|
|
||||||
|
class TestLambdaLRLinearWarmup(unittest.TestCase):
|
||||||
|
def test_linear_lr_warmup(self):
|
||||||
|
BS, BASE_LR = 304, 2.5e-7
|
||||||
|
lr = BS * BASE_LR
|
||||||
|
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
|
||||||
|
optimizer = AdamW([Tensor([1.])])
|
||||||
|
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
|
||||||
|
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
|
||||||
|
lrs = {}
|
||||||
|
|
||||||
|
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
|
||||||
|
for i in range(1200):
|
||||||
|
lr_scheduler.step()
|
||||||
|
if i in {0, 499, 998, 999, 1000, 1199}:
|
||||||
|
lrs[i] = optimizer.lr.item()
|
||||||
|
|
||||||
|
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
|
||||||
|
np.testing.assert_equal(lrs[999], lrs[1000])
|
||||||
|
np.testing.assert_equal(lrs[999], lrs[1199])
|
||||||
|
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
|
||||||
|
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest, os
|
import unittest, os
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from tinygrad import Context
|
from tinygrad import Tensor
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from examples.mlperf.model_train import train_stable_diffusion
|
from examples.mlperf.model_train import train_stable_diffusion
|
||||||
|
|
||||||
|
|
@ -14,10 +14,10 @@ class TestTrain(unittest.TestCase):
|
||||||
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
|
if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
|
||||||
with TemporaryDirectory(prefix="test-train") as tmp:
|
with TemporaryDirectory(prefix="test-train") as tmp:
|
||||||
os.environ["UNET_CKPTDIR"] = tmp
|
os.environ["UNET_CKPTDIR"] = tmp
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
saved_ckpts = train_stable_diffusion()
|
saved_ckpts = train_stable_diffusion()
|
||||||
expected_ckpt = f"{tmp}/{num_steps}.safetensors"
|
expected_ckpt = f"{tmp}/{num_steps}.safetensors"
|
||||||
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
|
assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
@ -3,7 +3,7 @@ import ast, pathlib, unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from tinygrad import Tensor, Context
|
from tinygrad import Tensor
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from test.helpers import slow
|
from test.helpers import slow
|
||||||
from extra.models.efficientnet import EfficientNet
|
from extra.models.efficientnet import EfficientNet
|
||||||
|
|
@ -40,7 +40,7 @@ def preprocess(img, new=False):
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def _infer(model: EfficientNet, img):
|
def _infer(model: EfficientNet, img):
|
||||||
with Context(TRAINING=0):
|
with Tensor.train(False):
|
||||||
out = model.forward(Tensor(img)).argmax(axis=-1)
|
out = model.forward(Tensor(img)).argmax(axis=-1)
|
||||||
return out.tolist()
|
return out.tolist()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,10 @@ import numpy as np
|
||||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||||
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import Context
|
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
|
|
||||||
def compare_tiny_torch(model, model_torch, X, Y):
|
def compare_tiny_torch(model, model_torch, X, Y):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
model_torch.train()
|
model_torch.train()
|
||||||
model_state_dict = get_state_dict(model)
|
model_state_dict = get_state_dict(model)
|
||||||
for k,v in model_torch.named_parameters():
|
for k,v in model_torch.named_parameters():
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ class TestRealWorld(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_train_mnist(self):
|
def test_train_mnist(self):
|
||||||
from examples.beautiful_mnist import Model
|
from examples.beautiful_mnist import Model
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
model = Model()
|
model = Model()
|
||||||
optimizer = optim.Adam(get_parameters(model))
|
optimizer = optim.Adam(get_parameters(model))
|
||||||
BS = 32
|
BS = 32
|
||||||
|
|
@ -125,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
|
||||||
def test_forward_cifar(self):
|
def test_forward_cifar(self):
|
||||||
BS = 32
|
BS = 32
|
||||||
# with training batchnorm still though
|
# with training batchnorm still though
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def run(X): return model(X)
|
def run(X): return model(X)
|
||||||
|
|
@ -133,7 +133,7 @@ class TestRealWorld(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_train_cifar(self):
|
def test_train_cifar(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
||||||
BS = 32
|
BS = 32
|
||||||
|
|
@ -151,7 +151,7 @@ class TestRealWorld(unittest.TestCase):
|
||||||
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
@unittest.skipUnless(dtypes.float16 in supported_dtypes, "need dtypes.float16")
|
||||||
def test_train_cifar_hyp(self):
|
def test_train_cifar_hyp(self):
|
||||||
dtypes.default_float = dtypes.float16
|
dtypes.default_float = dtypes.float16
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||||
initial_div_factor = hyp['opt']['initial_div_factor']
|
initial_div_factor = hyp['opt']['initial_div_factor']
|
||||||
|
|
@ -163,7 +163,7 @@ class TestRealWorld(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_bert(self):
|
def test_bert(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
|
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
|
||||||
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
|
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
|
||||||
model = BertForPretraining(**args_tiny)
|
model = BertForPretraining(**args_tiny)
|
||||||
|
|
|
||||||
|
|
@ -1093,14 +1093,14 @@ class TestSchedule(unittest.TestCase):
|
||||||
|
|
||||||
#@unittest.skip("may want to reconsider this")
|
#@unittest.skip("may want to reconsider this")
|
||||||
def test_fold_batchnorm(self):
|
def test_fold_batchnorm(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(1,32,4,4)
|
img = Tensor.empty(1,32,4,4)
|
||||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||||
out = bn(img)
|
out = bn(img)
|
||||||
check_schedule(out, 3, nn.state.get_parameters(bn))
|
check_schedule(out, 3, nn.state.get_parameters(bn))
|
||||||
|
|
||||||
def test_fold_conv_batchnorm_notrain(self):
|
def test_fold_conv_batchnorm_notrain(self):
|
||||||
with Context(TRAINING=0):
|
with Tensor.train(False):
|
||||||
img = Tensor.empty(1,3,8,8)
|
img = Tensor.empty(1,3,8,8)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
bn = nn.BatchNorm2d(32, track_running_stats=True)
|
bn = nn.BatchNorm2d(32, track_running_stats=True)
|
||||||
|
|
@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
|
check_schedule(out, 1, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
|
||||||
|
|
||||||
def test_fold_conv_batchnorm_notrain_no_running_stats(self):
|
def test_fold_conv_batchnorm_notrain_no_running_stats(self):
|
||||||
with Context(TRAINING=0):
|
with Tensor.train(False):
|
||||||
img = Tensor.empty(1,3,8,8)
|
img = Tensor.empty(1,3,8,8)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||||
|
|
@ -1116,7 +1116,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
|
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
|
||||||
|
|
||||||
def test_fold_conv_batchnorm(self):
|
def test_fold_conv_batchnorm(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(1,3,8,8)
|
img = Tensor.empty(1,3,8,8)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||||
|
|
@ -1125,7 +1125,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
|
|
||||||
def test_fold_conv_batchnorm_optim(self, adam=False):
|
def test_fold_conv_batchnorm_optim(self, adam=False):
|
||||||
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
|
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.ones(1,3,4,4)
|
img = Tensor.ones(1,3,4,4)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||||
|
|
@ -1139,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
|
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
|
||||||
|
|
||||||
def test_fold_batchnorm_backward(self):
|
def test_fold_batchnorm_backward(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
x = Tensor.empty((2, 16, 8, 8)).contiguous()
|
x = Tensor.empty((2, 16, 8, 8)).contiguous()
|
||||||
bn = nn.BatchNorm2d(16)
|
bn = nn.BatchNorm2d(16)
|
||||||
fw = bn(x).contiguous_backward().relu().contiguous()
|
fw = bn(x).contiguous_backward().relu().contiguous()
|
||||||
|
|
@ -1484,7 +1484,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(out, 4)
|
check_schedule(out, 4)
|
||||||
|
|
||||||
def test_adam_step_fusion(self):
|
def test_adam_step_fusion(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
x = Tensor.empty(4, 64, 32)
|
x = Tensor.empty(4, 64, 32)
|
||||||
layer = nn.Linear(32, 32*4)
|
layer = nn.Linear(32, 32*4)
|
||||||
_realize_weights(layer)
|
_realize_weights(layer)
|
||||||
|
|
@ -1494,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 13)
|
check_schedule(opt.schedule_step(), 13)
|
||||||
|
|
||||||
def test_adam_conv_fuse(self):
|
def test_adam_conv_fuse(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,4,4)
|
img = Tensor.empty(2,3,4,4)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
_realize_weights(c1)
|
_realize_weights(c1)
|
||||||
|
|
@ -1505,7 +1505,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 13)
|
check_schedule(opt.schedule_step(), 13)
|
||||||
|
|
||||||
def test_adam_2convs_fuse(self):
|
def test_adam_2convs_fuse(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,4,4)
|
img = Tensor.empty(2,3,4,4)
|
||||||
c1 = nn.Conv2d(3,16,3,bias=False)
|
c1 = nn.Conv2d(3,16,3,bias=False)
|
||||||
c2 = nn.Conv2d(16,32,2,bias=False)
|
c2 = nn.Conv2d(16,32,2,bias=False)
|
||||||
|
|
@ -1517,7 +1517,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 15)
|
check_schedule(opt.schedule_step(), 15)
|
||||||
|
|
||||||
def test_sgd_conv_fuse(self):
|
def test_sgd_conv_fuse(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,4,4)
|
img = Tensor.empty(2,3,4,4)
|
||||||
c1 = nn.Conv2d(3,32,3)
|
c1 = nn.Conv2d(3,32,3)
|
||||||
_realize_weights(c1)
|
_realize_weights(c1)
|
||||||
|
|
@ -1527,7 +1527,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 5) # TODO: 3?
|
check_schedule(opt.schedule_step(), 5) # TODO: 3?
|
||||||
|
|
||||||
def test_sgd_2convs_fuse(self):
|
def test_sgd_2convs_fuse(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,4,4)
|
img = Tensor.empty(2,3,4,4)
|
||||||
c1 = nn.Conv2d(3,16,3,bias=False)
|
c1 = nn.Conv2d(3,16,3,bias=False)
|
||||||
c2 = nn.Conv2d(16,32,2,bias=False)
|
c2 = nn.Conv2d(16,32,2,bias=False)
|
||||||
|
|
@ -1538,7 +1538,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 7)
|
check_schedule(opt.schedule_step(), 7)
|
||||||
|
|
||||||
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
|
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,4,4)
|
img = Tensor.empty(2,3,4,4)
|
||||||
c1 = nn.Conv2d(3,16,3,bias=False)
|
c1 = nn.Conv2d(3,16,3,bias=False)
|
||||||
c2 = nn.Conv2d(16,32,2,bias=False)
|
c2 = nn.Conv2d(16,32,2,bias=False)
|
||||||
|
|
@ -1550,7 +1550,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 11)
|
check_schedule(opt.schedule_step(), 11)
|
||||||
|
|
||||||
def test_sgd_4convs_fuse(self):
|
def test_sgd_4convs_fuse(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,16,16)
|
img = Tensor.empty(2,3,16,16)
|
||||||
c1 = nn.Conv2d(3,4,3,bias=False)
|
c1 = nn.Conv2d(3,4,3,bias=False)
|
||||||
c2 = nn.Conv2d(4,8,3,bias=False)
|
c2 = nn.Conv2d(4,8,3,bias=False)
|
||||||
|
|
@ -1563,7 +1563,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
check_schedule(opt.schedule_step(), 15)
|
check_schedule(opt.schedule_step(), 15)
|
||||||
|
|
||||||
def test_sgd_4convs_fuse_conv_bw(self):
|
def test_sgd_4convs_fuse_conv_bw(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
img = Tensor.empty(2,3,16,16)
|
img = Tensor.empty(2,3,16,16)
|
||||||
c1 = nn.Conv2d(3,4,3,bias=False)
|
c1 = nn.Conv2d(3,4,3,bias=False)
|
||||||
c2 = nn.Conv2d(4,8,3,bias=False)
|
c2 = nn.Conv2d(4,8,3,bias=False)
|
||||||
|
|
@ -1664,7 +1664,7 @@ class TestSchedule(unittest.TestCase):
|
||||||
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
|
self.assertEqual(len([x for x in linear.src[0].src[0].backward_slice_with_self if x.op is Ops.REDUCE]), 0)
|
||||||
|
|
||||||
def test_resnet_block(self):
|
def test_resnet_block(self):
|
||||||
with Context(TRAINING=0):
|
with Tensor.train(False):
|
||||||
in_planes, planes = 64, 64
|
in_planes, planes = 64, 64
|
||||||
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
bn1 = nn.BatchNorm2d(planes)
|
bn1 = nn.BatchNorm2d(planes)
|
||||||
|
|
|
||||||
|
|
@ -16,17 +16,41 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
|
||||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||||
return UOp(Ops.LOAD, dtypes.float, (
|
return UOp(Ops.LOAD, dtypes.float, (
|
||||||
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
||||||
|
UOp.const(dtypes.float, 0.0)
|
||||||
))
|
))
|
||||||
|
|
||||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||||
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
||||||
|
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||||
))
|
))
|
||||||
|
|
||||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
|
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
|
||||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||||
def Range(n, nmax): return UOp.range(nmax, n)
|
def Range(n, nmax): return UOp.range(nmax, n)
|
||||||
|
|
||||||
|
class TestHelpers(unittest.TestCase):
|
||||||
|
def test_is_increasing(self):
|
||||||
|
idx1 = Special("idx1", 32)
|
||||||
|
idx2 = Special("idx2", 64)
|
||||||
|
ridx0 = Variable("ridx0", 0, 5)
|
||||||
|
ridx1 = Variable("ridx1", 0, 2)
|
||||||
|
ridx2 = Variable("ridx2", 0, 2)
|
||||||
|
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
|
||||||
|
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
|
||||||
|
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
|
||||||
|
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
|
||||||
|
f3 = (idx2*2)+ridx1+(-1)
|
||||||
|
|
||||||
|
self.assertFalse(f0.is_increasing())
|
||||||
|
self.assertTrue(f1.is_increasing())
|
||||||
|
self.assertTrue(f2.is_increasing())
|
||||||
|
self.assertTrue(f3.is_increasing())
|
||||||
|
|
||||||
|
rng = UOp.range(5, 2)
|
||||||
|
self.assertTrue(rng.is_increasing())
|
||||||
|
self.assertTrue((rng+2).is_increasing())
|
||||||
|
|
||||||
class TestValidIdxSimplification(unittest.TestCase):
|
class TestValidIdxSimplification(unittest.TestCase):
|
||||||
def check(self, load, sidx, svalid, extra=()):
|
def check(self, load, sidx, svalid, extra=()):
|
||||||
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
|
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
|
||||||
|
|
@ -482,16 +506,6 @@ class TestImageSimplification(unittest.TestCase):
|
||||||
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
|
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
|
||||||
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
|
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
|
||||||
|
|
||||||
def test_drop_non_monotonic_window(self):
|
|
||||||
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
|
|
||||||
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
|
|
||||||
gidx0 = Special("gidx0", 1064)
|
|
||||||
r12 = Range(12, 3)
|
|
||||||
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
|
|
||||||
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
|
|
||||||
load = get_load_image_uop((1, 48, 4), valid, idx)
|
|
||||||
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
|
|
||||||
|
|
||||||
class TestDropTrueGate(unittest.TestCase):
|
class TestDropTrueGate(unittest.TestCase):
|
||||||
def test_drop_true_gate_on_index(self):
|
def test_drop_true_gate_on_index(self):
|
||||||
# test that INDEX with a constant True valid gets simplified to drop the valid
|
# test that INDEX with a constant True valid gets simplified to drop the valid
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# tensor tests that pass on NULL backend (no copyout needed)
|
# tensor tests that pass on NULL backend (no copyout needed)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, Device, dtypes, Context
|
from tinygrad import Tensor, Device, dtypes
|
||||||
from tinygrad.uop.ops import Ops, UOp
|
from tinygrad.uop.ops import Ops, UOp
|
||||||
from tinygrad.renderer.ptx import PTXRenderer
|
from tinygrad.renderer.ptx import PTXRenderer
|
||||||
from tinygrad.renderer.nir import NIRRenderer
|
from tinygrad.renderer.nir import NIRRenderer
|
||||||
|
|
@ -15,7 +15,7 @@ m_init = np.random.randn(1,3).astype(np.float32)
|
||||||
class TestTrainMode(unittest.TestCase):
|
class TestTrainMode(unittest.TestCase):
|
||||||
def test_train_mode(self):
|
def test_train_mode(self):
|
||||||
assert not Tensor.training
|
assert not Tensor.training
|
||||||
@Context(TRAINING=1)
|
@Tensor.train()
|
||||||
def f():
|
def f():
|
||||||
assert Tensor.training
|
assert Tensor.training
|
||||||
f()
|
f()
|
||||||
|
|
|
||||||
|
|
@ -317,19 +317,6 @@ class TestTensorUOpScatterReduce(unittest.TestCase):
|
||||||
def test_mean_exclude_self(self):
|
def test_mean_exclude_self(self):
|
||||||
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
|
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
|
||||||
|
|
||||||
class TestTensorUOpMaskedSelect(unittest.TestCase):
|
|
||||||
# only the fixed-size path is pure
|
|
||||||
def _check(self, t, mask, **kw):
|
|
||||||
self.assertIs(t.masked_select(mask, **kw).uop, t.uop.masked_select(mask.uop, **kw))
|
|
||||||
def test_masked_select_1d(self): self._check(_t(6), Tensor([True, False, True, False, True, False]), size=4)
|
|
||||||
def test_masked_select_2d(self):
|
|
||||||
self._check(_t(3, 3), Tensor([[True, False, True], [False, True, False], [False, False, True]]), size=6, fill_value=-1)
|
|
||||||
|
|
||||||
class TestTensorUOpNonzero(unittest.TestCase):
|
|
||||||
def _check(self, t, **kw): self.assertIs(t.nonzero(**kw).uop, t.uop.nonzero(**kw))
|
|
||||||
def test_nonzero_1d(self): self._check(_t(5), size=3)
|
|
||||||
def test_nonzero_2d(self): self._check(_t(2, 3), size=4)
|
|
||||||
|
|
||||||
class TestTensorUOpPool(unittest.TestCase):
|
class TestTensorUOpPool(unittest.TestCase):
|
||||||
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
|
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
|
||||||
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
|
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
|
||||||
|
|
@ -475,10 +462,6 @@ class TestTensorUOpCreation(unittest.TestCase):
|
||||||
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
|
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
|
||||||
def test_invalids(self):
|
def test_invalids(self):
|
||||||
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
|
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
|
||||||
def test_empty_like(self):
|
|
||||||
t = Tensor.empty(2, 3, dtype=dtypes.int8)
|
|
||||||
self.assertIs(_strip_unique(t.empty_like().uop), _strip_unique(t.uop.empty_like()))
|
|
||||||
self.assertIs(_strip_unique(t.empty_like(dtype=dtypes.float, device="NULL").uop), _strip_unique(t.uop.empty_like(dtypes.float, "NULL")))
|
|
||||||
def test_arange(self):
|
def test_arange(self):
|
||||||
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
|
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
|
||||||
def test_arange_empty(self):
|
def test_arange_empty(self):
|
||||||
|
|
|
||||||
|
|
@ -13,26 +13,35 @@ class TestWinograd(unittest.TestCase):
|
||||||
def test_forward_kernels(self):
|
def test_forward_kernels(self):
|
||||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||||
out = Tensor.conv2d(x,w)
|
out = Tensor.conv2d(x,w)
|
||||||
self.assertEqual(len(out.schedule_linear().src), 4)
|
self.assertEqual(len(out.schedule_linear().src), 2)
|
||||||
|
|
||||||
def test_backward_kernels(self):
|
def test_backward_kernels(self):
|
||||||
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
|
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
|
||||||
out = Tensor.conv2d(x,w, padding=1)
|
out = Tensor.conv2d(x,w, padding=1)
|
||||||
out.mean().backward()
|
out.mean().backward()
|
||||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||||
self.assertEqual(len(backward_schedule.src), 4)
|
self.assertEqual(len(backward_schedule.src), 2)
|
||||||
|
|
||||||
|
@unittest.skip("this requires optimizations")
|
||||||
def test_counters(self):
|
def test_counters(self):
|
||||||
IC, OC, H = 64, 64, 28
|
IC, OC, X, Y = 4,4,9,9
|
||||||
x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
|
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
|
with Context(WINO=1):
|
||||||
ops_wino = GlobalCounters.global_ops
|
Tensor.conv2d(x,w).realize()
|
||||||
|
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
|
with Context(WINO=0):
|
||||||
ops_normal = GlobalCounters.global_ops
|
Tensor.conv2d(x,w).realize()
|
||||||
print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
|
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||||
self.assertLess(ops_wino/ops_normal, 0.6)
|
|
||||||
|
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
|
||||||
|
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
|
||||||
|
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
|
||||||
|
|
||||||
|
# TODO: what's optimal on this?
|
||||||
|
self.assertLess(ops_ratio, 4.3)
|
||||||
|
self.assertLess(mem_ratio, 4)
|
||||||
|
|
||||||
def test_dtype(self):
|
def test_dtype(self):
|
||||||
IC, OC, X, Y = 4,4,9,9
|
IC, OC, X, Y = 4,4,9,9
|
||||||
|
|
|
||||||
|
|
@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
|
||||||
# find the FUNCTION nodes
|
# find the FUNCTION nodes
|
||||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
||||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
||||||
# the function bodies (src[0]) should have identical keys
|
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
|
||||||
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
||||||
|
|
||||||
def test_precompile_symbolic_2d(self):
|
def test_precompile_symbolic_2d(self):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.function import function
|
from tinygrad.function import function
|
||||||
from tinygrad import Tensor, GlobalCounters, Device
|
from tinygrad import Tensor, GlobalCounters
|
||||||
from tinygrad.dtype import dtypes, Invalid
|
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, ProgramInfo
|
|
||||||
|
|
||||||
class TestFunction(unittest.TestCase):
|
class TestFunction(unittest.TestCase):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
|
@ -550,36 +549,6 @@ class TestFunctionTuple(unittest.TestCase):
|
||||||
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
|
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
|
||||||
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
|
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
|
||||||
|
|
||||||
def test_custom_kernel_program_invalids_not_captured(self):
|
|
||||||
# llama FP8 kernels are PROGRAM with bare-buffer sinks (no analyzable stores), so the invalids scratch
|
|
||||||
# still must not be captured as an input -- else it is read before the kernel writes it
|
|
||||||
src = "void k(float* restrict data0, float* restrict data1) { for (int i=0;i<4;i++) data0[i]=data1[i]*2.0f; }"
|
|
||||||
lib = Device["CPU"].compiler.compile(src)
|
|
||||||
def prog(C:UOp, A:UOp) -> UOp:
|
|
||||||
sink = UOp.sink(C.base, A.base, arg=KernelInfo(name="k"))
|
|
||||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="CPU"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
|
||||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)),
|
|
||||||
arg=ProgramInfo(name="k", global_size=(1, 1, 1), local_size=(1, 1, 1), globals=(0, 1)))
|
|
||||||
|
|
||||||
@function(precompile=True)
|
|
||||||
def f(a:Tensor):
|
|
||||||
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
|
|
||||||
return Tensor.custom_kernel(c, a, fxn=prog)[0]
|
|
||||||
|
|
||||||
a = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
|
|
||||||
np.testing.assert_allclose(f(a).numpy(), [2., 4., 6., 8.])
|
|
||||||
|
|
||||||
def test_invalid_store_into_realized_buffer_is_captured(self):
|
|
||||||
# only fresh invalids() scratch is skipped; a realized buffer is a real input even if an Invalid store
|
|
||||||
# writes into part of it (its other elements must be preserved), so it is still captured
|
|
||||||
state = Tensor([10., 20., 30., 40.], device="CPU").contiguous().realize()
|
|
||||||
@function(precompile=True, allow_implicit=True)
|
|
||||||
def f(a:Tensor):
|
|
||||||
after = state.uop.after(state.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float32, Invalid, shape=(2,))))
|
|
||||||
return Tensor(after).contiguous() + a
|
|
||||||
out = f(Tensor([1., 1., 1., 1.], device="CPU").contiguous().realize())
|
|
||||||
np.testing.assert_allclose(out.numpy(), [11., 21., 31., 41.])
|
|
||||||
|
|
||||||
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
|
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
|
||||||
devs = ("CPU:0", "CPU:1")
|
devs = ("CPU:0", "CPU:1")
|
||||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||||
|
|
|
||||||
|
|
@ -207,7 +207,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||||
out.numpy()
|
out.numpy()
|
||||||
|
|
||||||
def test_backprop_conv(self):
|
def test_backprop_conv(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
conv = nn.Conv2d(3, 16, 3)
|
conv = nn.Conv2d(3, 16, 3)
|
||||||
for p in get_parameters(conv): p.shard_(devices_2)
|
for p in get_parameters(conv): p.shard_(devices_2)
|
||||||
optim = nn.optim.Adam(get_parameters(conv))
|
optim = nn.optim.Adam(get_parameters(conv))
|
||||||
|
|
@ -511,7 +511,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||||
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
|
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
|
||||||
|
|
||||||
def test_dropout_on_shard(self):
|
def test_dropout_on_shard(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
X = Tensor.ones(256).to(devices_2)
|
X = Tensor.ones(256).to(devices_2)
|
||||||
output = X.dropout(0.5).numpy()
|
output = X.dropout(0.5).numpy()
|
||||||
unique, counts = np.unique(output, return_counts=True)
|
unique, counts = np.unique(output, return_counts=True)
|
||||||
|
|
@ -519,7 +519,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||||
assert 96 < counts[0] < 160, counts[0]
|
assert 96 < counts[0] < 160, counts[0]
|
||||||
|
|
||||||
def test_dropout_on_shard_axis(self):
|
def test_dropout_on_shard_axis(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
X = Tensor.ones(512).shard(devices_2, axis=0)
|
X = Tensor.ones(512).shard(devices_2, axis=0)
|
||||||
output = X.dropout(0.5).numpy()
|
output = X.dropout(0.5).numpy()
|
||||||
unique, counts = np.unique(output, return_counts=True)
|
unique, counts = np.unique(output, return_counts=True)
|
||||||
|
|
@ -664,7 +664,7 @@ class TestBatchNorm(unittest.TestCase):
|
||||||
def setUp(self): pass
|
def setUp(self): pass
|
||||||
|
|
||||||
def test_unsynced_backprop_conv_bn(self):
|
def test_unsynced_backprop_conv_bn(self):
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
from extra.lr_scheduler import OneCycleLR
|
from extra.lr_scheduler import OneCycleLR
|
||||||
|
|
||||||
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
|
convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
|
||||||
|
|
@ -709,7 +709,7 @@ class TestBatchNorm(unittest.TestCase):
|
||||||
bn_ts.append(bni)
|
bn_ts.append(bni)
|
||||||
return bn_ts[0].cat(*bn_ts[1:])
|
return bn_ts[0].cat(*bn_ts[1:])
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
conv = nn.Conv2d(3, 16, 3)
|
conv = nn.Conv2d(3, 16, 3)
|
||||||
bn = BatchNorm(16)
|
bn = BatchNorm(16)
|
||||||
|
|
||||||
|
|
@ -731,7 +731,7 @@ class TestBatchNorm(unittest.TestCase):
|
||||||
from examples.hlb_cifar10 import UnsyncedBatchNorm
|
from examples.hlb_cifar10 import UnsyncedBatchNorm
|
||||||
GPUS = (d1, d2)
|
GPUS = (d1, d2)
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
conv = nn.Conv2d(3, 16, 3)
|
conv = nn.Conv2d(3, 16, 3)
|
||||||
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
|
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
|
||||||
|
|
||||||
|
|
@ -756,7 +756,7 @@ class TestBatchNorm(unittest.TestCase):
|
||||||
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
||||||
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
|
x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)
|
||||||
|
|
||||||
with Context(TRAINING=is_training):
|
with Tensor.train(is_training):
|
||||||
bns = []
|
bns = []
|
||||||
for _ in range(len(devices)):
|
for _ in range(len(devices)):
|
||||||
bn = nn.BatchNorm2d(8)
|
bn = nn.BatchNorm2d(8)
|
||||||
|
|
@ -777,7 +777,7 @@ class TestBatchNorm(unittest.TestCase):
|
||||||
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
||||||
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
|
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
|
||||||
|
|
||||||
with Context(TRAINING=1):
|
with Tensor.train():
|
||||||
synced_bn = BatchNorm2d(8)
|
synced_bn = BatchNorm2d(8)
|
||||||
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
|
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,17 +13,16 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||||
# import all pattern matchers here
|
# import all pattern matchers here
|
||||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
||||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
|
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
ReduceContext, correct_load_store, pm_render, pm_make_images
|
||||||
from tinygrad.codegen.opt.postrange import apply_opts
|
from tinygrad.codegen.opt.postrange import apply_opts
|
||||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||||
from tinygrad.codegen.late.coalese import memory_coalesing
|
|
||||||
|
|
||||||
pm_index_is_shrink = PatternMatcher([
|
pm_index_is_shrink = PatternMatcher([
|
||||||
# rewrite non-image INDEX to SHRINK
|
# rewrite non-image INDEX to SHRINK
|
||||||
|
|
@ -53,8 +52,12 @@ pm_number_params = PatternMatcher([
|
||||||
(UPat(Ops.PARAM, name="x"), do_number_param),
|
(UPat(Ops.PARAM, name="x"), do_number_param),
|
||||||
])
|
])
|
||||||
|
|
||||||
pm_no_weakints = PatternMatcher([
|
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||||
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
|
pm_load_to_alu = PatternMatcher([
|
||||||
|
# NOTE: the PtrDType thing is temporary
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.STACK,Ops.GEP}, name="x"), lambda x:
|
||||||
|
x.replace(src=tuple([maybe_load(u) for u in x.src])) if not isinstance(x.dtype, PtrDType) else None),
|
||||||
|
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||||
])
|
])
|
||||||
|
|
||||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
|
|
@ -83,7 +86,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
sink = apply_opts(sink, ren, beam=ast.arg.beam)
|
sink = apply_opts(sink, ren, beam=ast.arg.beam)
|
||||||
|
|
||||||
# ** expander (expand_rewrite) **
|
# ** expander (expand_rewrite) **
|
||||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||||
|
|
||||||
# expand
|
# expand
|
||||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||||
|
|
@ -101,7 +104,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# **** optimizations are done, now we lower to actual code ****
|
# **** optimizations are done, now we lower to actual code ****
|
||||||
|
|
||||||
# add loads and remove invalids
|
# add loads and remove invalids
|
||||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
sink = graph_rewrite(sink, pm_load_to_alu+pm_remove_invalid, name="** add loads (code)")
|
||||||
|
|
||||||
# create image buffers
|
# create image buffers
|
||||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||||
|
|
@ -118,23 +121,18 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# optional pre matcher
|
# optional pre matcher
|
||||||
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
||||||
|
|
||||||
# floordiv+mod / dtype decomp (early)
|
# decompositions
|
||||||
supported_ops = tuple(ren.code_for_op.keys())
|
supported_ops = tuple(ren.code_for_op.keys())
|
||||||
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
|
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||||
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
|
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||||
|
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
||||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||||
|
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||||
|
|
||||||
# do memory coalesing (late)
|
# GEP/STACK stuff
|
||||||
sink = memory_coalesing(sink, ren)
|
|
||||||
|
|
||||||
# instruction selection decompositions
|
|
||||||
pm_decomp = pm_decomp+\
|
|
||||||
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
|
|
||||||
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
|
||||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
|
|
||||||
|
|
||||||
# this is new style (TODO: this should all be removed)
|
|
||||||
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
|
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
|
||||||
|
|
||||||
|
# this is new style
|
||||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||||
|
|
||||||
|
|
@ -143,7 +141,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
|
|
||||||
# final rules for the renderer (without sym)
|
# final rules for the renderer (without sym)
|
||||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||||
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
|
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
|
||||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||||
|
|
||||||
# this was the linearizer
|
# this was the linearizer
|
||||||
|
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
import itertools
|
|
||||||
from collections import defaultdict
|
|
||||||
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
|
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
|
||||||
from tinygrad.helpers import getenv
|
|
||||||
from tinygrad.renderer import Renderer
|
|
||||||
|
|
||||||
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
|
|
||||||
if getenv("DMC"): return sink
|
|
||||||
|
|
||||||
# collect
|
|
||||||
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
|
|
||||||
for u in sink.toposort():
|
|
||||||
# TODO: this should handle images too, it's just memory coalesing
|
|
||||||
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
|
|
||||||
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
|
|
||||||
assert u.src[0].op is Ops.INDEX
|
|
||||||
buf, idx_u = u.src[0].src
|
|
||||||
if buf.addrspace == AddrSpace.REG: continue
|
|
||||||
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
|
|
||||||
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
|
|
||||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
|
||||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
|
||||||
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
|
|
||||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
|
||||||
else: root_src, arg = idx, 0
|
|
||||||
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
|
||||||
|
|
||||||
# build replacements
|
|
||||||
replacements = {}
|
|
||||||
for (op,buf,base,valid),offsets in memory.items():
|
|
||||||
# allowed lengths (copied in)
|
|
||||||
lengths = []
|
|
||||||
must_divide = True
|
|
||||||
if ctx is not None and ctx.target.device == "DSP":
|
|
||||||
lengths = [128,64,32,16,8,4]
|
|
||||||
must_divide = False
|
|
||||||
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
|
||||||
pass
|
|
||||||
elif buf.addrspace == AddrSpace.REG:
|
|
||||||
pass
|
|
||||||
elif isinstance(buf.dtype, ImageDType):
|
|
||||||
lengths = [4]
|
|
||||||
elif ctx is not None and ctx.supports_float4:
|
|
||||||
# TODO: a better way to get this than ctx
|
|
||||||
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
|
||||||
lengths.append(1) # worst case, it's not folded
|
|
||||||
# do the grouping
|
|
||||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
|
||||||
for full_grp in grouped_offsets:
|
|
||||||
while len(full_grp):
|
|
||||||
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
|
|
||||||
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
|
|
||||||
grp = full_grp[:length]
|
|
||||||
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
|
||||||
if op == Ops.STORE:
|
|
||||||
datas = []
|
|
||||||
for i,g in enumerate(grp):
|
|
||||||
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
|
|
||||||
datas.append(offsets[g][0].src[1])
|
|
||||||
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
|
||||||
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
|
||||||
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
|
||||||
else:
|
|
||||||
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
|
||||||
for i,g in enumerate(grp):
|
|
||||||
for oo in offsets[g]:
|
|
||||||
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
|
||||||
full_grp = full_grp[length:]
|
|
||||||
|
|
||||||
# apply
|
|
||||||
return sink.substitute(replacements, name="memory coalesing")
|
|
||||||
|
|
@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
|
||||||
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
||||||
# can drop valid if idx is out of bound when valid is False
|
# can drop valid if idx is out of bound when valid is False
|
||||||
drop_stmt = []
|
drop_stmt = []
|
||||||
for i,stmt in enumerate(valid.split_uop(Ops.AND)):
|
for stmt in valid.split_uop(Ops.AND):
|
||||||
if (res:=parse_valid(stmt)) is None: continue
|
if (res:=parse_valid(stmt)) is None: continue
|
||||||
X, is_upper_bound, c = res
|
X, is_upper_bound, c = res
|
||||||
|
|
||||||
|
|
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
||||||
drop_stmt.append(stmt)
|
drop_stmt.append(stmt)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
|
# if X <= c, check if it's out of bound when X = c+1
|
||||||
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
|
# if X >= c, check if it's out of bound when X = c-1
|
||||||
if lo <= hi:
|
test_value = c + 1 if is_upper_bound else c - 1
|
||||||
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
|
for i,b in zip(idx.src, (width, height)):
|
||||||
for coord,b in zip(idx.src, (width, height)):
|
if i.is_increasing():
|
||||||
rw = coord.substitute({X:fake}).simplify()
|
rw = i.substitute({X:X.const_like(test_value)})
|
||||||
if rw.vmin >= b or rw.vmax < 0:
|
if rw.vmin >= b or rw.vmax < 0:
|
||||||
drop_stmt.append(stmt)
|
drop_stmt.append(stmt)
|
||||||
break
|
break
|
||||||
|
|
@ -162,8 +162,18 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||||
# determine fold lengths
|
# determine fold lengths
|
||||||
lengths = []
|
lengths = []
|
||||||
must_divide = True
|
must_divide = True
|
||||||
# TODO: this belongs in coalese
|
if ctx is not None and ctx.target.device == "DSP":
|
||||||
if isinstance(buf.dtype, ImageDType): lengths = [4]
|
lengths = [128,64,32,16,8,4]
|
||||||
|
must_divide = False
|
||||||
|
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
||||||
|
pass
|
||||||
|
elif buf.addrspace == AddrSpace.REG:
|
||||||
|
pass
|
||||||
|
elif isinstance(buf.dtype, ImageDType):
|
||||||
|
lengths = [4]
|
||||||
|
elif ctx is not None and ctx.supports_float4:
|
||||||
|
# TODO: a better way to get this than ctx
|
||||||
|
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
||||||
lengths.append(1) # worst case, it's not folded
|
lengths.append(1) # worst case, it's not folded
|
||||||
|
|
||||||
# filter fold lengths that don't divide
|
# filter fold lengths that don't divide
|
||||||
|
|
@ -279,6 +289,7 @@ pm_render = PatternMatcher([
|
||||||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||||
|
(UPat(Ops.PTRCAT, src=(UPat(name='x'),)), lambda x: x),
|
||||||
])
|
])
|
||||||
|
|
||||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||||
|
|
@ -346,21 +357,6 @@ pm_reduce = PatternMatcher([
|
||||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||||
])
|
])
|
||||||
|
|
||||||
# add loads
|
|
||||||
|
|
||||||
def add_load(idx:UOp):
|
|
||||||
if isinstance(idx.dtype, PtrDType): return None
|
|
||||||
assert isinstance(idx.src[0].dtype, PtrDType), f"param is not PtrDType {idx.src[0].dtype}"
|
|
||||||
return idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)
|
|
||||||
|
|
||||||
pm_add_loads = PatternMatcher([
|
|
||||||
# add loads to non ptr index
|
|
||||||
(UPat(Ops.INDEX, name="idx"), add_load),
|
|
||||||
# remove loads from stores
|
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
|
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
|
|
||||||
])
|
|
||||||
|
|
||||||
# make images
|
# make images
|
||||||
|
|
||||||
pm_imageh_store = PatternMatcher([
|
pm_imageh_store = PatternMatcher([
|
||||||
|
|
|
||||||
|
|
@ -101,12 +101,6 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||||
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
||||||
is_masked = k.rngs[axis] in where_gate_rngs
|
is_masked = k.rngs[axis] in where_gate_rngs
|
||||||
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||||
# upcasting a masked global axis moves that range out of the launch grid into each work-item
|
|
||||||
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
|
|
||||||
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
|
|
||||||
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
|
|
||||||
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
|
|
||||||
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
|
|
||||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||||
to_upcast.append(axis)
|
to_upcast.append(axis)
|
||||||
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,8 @@ def _prepare_jit_inputs(args, kwargs):
|
||||||
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
|
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
|
||||||
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
|
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
|
||||||
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||||
if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
|
# TODO: drop the CONST branch once all CONST are deviceless
|
||||||
|
if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
|
||||||
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
||||||
input_uops = get_input_uops()
|
input_uops = get_input_uops()
|
||||||
# collect buffer UOps (including MultiBuffer)
|
# collect buffer UOps (including MultiBuffer)
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,26 @@
|
||||||
import functools, time
|
import functools, itertools, time
|
||||||
from typing import Generic, TypeVar, Callable, cast, overload
|
from typing import Generic, TypeVar, Callable, cast, overload
|
||||||
from tinygrad.helpers import Context, dedup, getenv, DEBUG
|
from tinygrad.helpers import Context, dedup, getenv, DEBUG
|
||||||
from tinygrad.dtype import Invalid
|
|
||||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
|
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn.state import get_state_dict
|
from tinygrad.nn.state import get_state_dict
|
||||||
|
|
||||||
def add_to_ctx(ctx, x:UOp):
|
def add_to_ctx(ctx, x:UOp):
|
||||||
if x.buf_uop in ctx[1]: return None
|
|
||||||
ret = x.param_like(len(ctx[0]))
|
ret = x.param_like(len(ctx[0]))
|
||||||
ctx[0].append(x)
|
ctx[0].append(x)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
pm_transform_unique_const = PatternMatcher([
|
||||||
|
# transform unique consts to LUNIQUE
|
||||||
|
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
|
||||||
|
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
|
||||||
|
])
|
||||||
|
|
||||||
pm_ctx = PatternMatcher([
|
pm_ctx = PatternMatcher([
|
||||||
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
||||||
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
|
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
|
||||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
|
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
|
||||||
])
|
])+pm_transform_unique_const
|
||||||
|
|
||||||
def invalid_outputs(uret:UOp) -> set[UOp]:
|
|
||||||
# invalids() returns fresh write-only scratch: a clone storing CONST(Invalid)
|
|
||||||
# don't capture it as an input; only skip fresh buffers, not realized ones
|
|
||||||
return {u.src[0].buf_uop for u in uret.backward_slice_with_self
|
|
||||||
if u.op is Ops.STORE and u.src[1].base.op is Ops.CONST and u.src[1].base.arg is Invalid
|
|
||||||
and not u.src[0].buf_uop.is_realized}
|
|
||||||
|
|
||||||
ReturnType = TypeVar('ReturnType')
|
ReturnType = TypeVar('ReturnType')
|
||||||
class _function(Generic[ReturnType]):
|
class _function(Generic[ReturnType]):
|
||||||
|
|
@ -66,7 +63,7 @@ class _function(Generic[ReturnType]):
|
||||||
|
|
||||||
# the BUFFERs that are left are the implicit inputs
|
# the BUFFERs that are left are the implicit inputs
|
||||||
num_explicit = len(call_uops)
|
num_explicit = len(call_uops)
|
||||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, invalid_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
|
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
|
||||||
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
|
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
|
||||||
if not self.allow_implicit:
|
if not self.allow_implicit:
|
||||||
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
|
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
|
||||||
|
|
|
||||||
|
|
@ -240,7 +240,6 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
|
||||||
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
|
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
|
||||||
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
|
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
|
||||||
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
|
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
|
||||||
TRAINING = ContextVar("TRAINING", 0)
|
|
||||||
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
|
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
|
||||||
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
|
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
|
||||||
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
|
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import functools, itertools, string
|
import functools, itertools
|
||||||
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
|
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
|
||||||
from tinygrad.mixin.elementwise import ElementwiseMixin
|
from tinygrad.mixin.elementwise import ElementwiseMixin
|
||||||
from tinygrad.mixin.movement import MovementMixin
|
from tinygrad.mixin.movement import MovementMixin
|
||||||
from tinygrad.mixin.reduce import ReduceMixin
|
from tinygrad.mixin.reduce import ReduceMixin
|
||||||
from tinygrad.uop import Ops
|
from tinygrad.uop import Ops
|
||||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||||
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||||
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
|
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
|
||||||
from tinygrad.helpers import resolve_pool_pads, round_up
|
from tinygrad.helpers import resolve_pool_pads, round_up
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -17,59 +17,35 @@ ReductionStr = Literal["mean", "sum", "none"]
|
||||||
|
|
||||||
|
|
||||||
class OpMixin(ElementwiseMixin, ReduceMixin):
|
class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||||
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
|
@staticmethod
|
||||||
|
def const(dtype, b): raise NotImplementedError
|
||||||
|
|
||||||
def item(self) -> PyConst:
|
@classmethod
|
||||||
|
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
|
||||||
|
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
|
||||||
"""
|
"""
|
||||||
Returns the value of this tensor as a standard Python number.
|
Creates a tensor with the given shape, filled with the given value.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
t = Tensor(42)
|
print(Tensor.full((2, 3), 42).numpy())
|
||||||
print(t.item())
|
|
||||||
```
|
```
|
||||||
"""
|
|
||||||
assert self.numel() == 1, "must have one element for item"
|
|
||||||
return self.data()[(0,) * len(self.shape)]
|
|
||||||
|
|
||||||
def __getitem__(self, indices) -> Self:
|
|
||||||
"""
|
|
||||||
Retrieves a sub-tensor using indexing.
|
|
||||||
|
|
||||||
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
t = Tensor.arange(12).reshape(3, 4)
|
print(Tensor.full((2, 3), False).numpy())
|
||||||
print(t.numpy())
|
|
||||||
```
|
|
||||||
|
|
||||||
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t[1, 2].numpy())
|
|
||||||
```
|
|
||||||
|
|
||||||
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t[0:2, ::2].numpy())
|
|
||||||
```
|
|
||||||
|
|
||||||
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
|
||||||
```
|
|
||||||
|
|
||||||
- `None` Indexing: Add a new dimension to the tensor.
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t[:, None].shape)
|
|
||||||
```
|
|
||||||
|
|
||||||
NOTE: Out-of-bounds indexing results in a value of `0`.
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor([1, 2, 3])
|
|
||||||
print(t[Tensor([4, 3, 2])].numpy())
|
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return self._getitem(indices)
|
# TODO: enable this check
|
||||||
|
# if not buffer: assert device is None, "buffer=False does not support device specification"
|
||||||
|
from tinygrad.uop.ops import UOp
|
||||||
|
new_shape = argfix(shape)
|
||||||
|
dt = to_dtype(dtype) if dtype is not None else None
|
||||||
|
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
|
||||||
|
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||||
|
return val.clone(device=device) if buffer else val
|
||||||
|
|
||||||
|
def __getitem__(self, indices) -> Self: return self._getitem(indices)
|
||||||
|
|
||||||
def _getitem(self, indices, v=None) -> Self:
|
def _getitem(self, indices, v=None) -> Self:
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
|
|
@ -162,6 +138,40 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||||
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
|
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
|
||||||
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
|
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def zeros(cls, *shape, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with zeros.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(Tensor.zeros(2, 3).numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return cls.full(argfix(*shape), 0.0, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ones(cls, *shape, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with ones.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(Tensor.ones(2, 3).numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return cls.full(argfix(*shape), 1.0, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
|
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
|
||||||
"""
|
"""
|
||||||
|
|
@ -357,7 +367,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||||
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
||||||
raise NotImplementedError(f"{mode=} is not supported")
|
raise NotImplementedError(f"{mode=} is not supported")
|
||||||
|
|
||||||
def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
|
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
|
||||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||||
x, y = (self, y) if not reverse else (y, self)
|
x, y = (self, y) if not reverse else (y, self)
|
||||||
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
|
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
|
||||||
|
|
@ -413,47 +423,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||||
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
|
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
|
||||||
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
|
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
|
|
||||||
"""
|
|
||||||
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
|
||||||
|
|
||||||
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
x = Tensor([[1, 2], [3, 4]])
|
|
||||||
y = Tensor([[5, 6], [7, 8]])
|
|
||||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
|
||||||
# expand ellipsis to letters, determine output
|
|
||||||
if "..." in formula:
|
|
||||||
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
|
||||||
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
|
||||||
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
|
||||||
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
|
||||||
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
|
||||||
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
|
||||||
inputs = lhs.split(",")
|
|
||||||
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
|
||||||
# trace: take diagonal when letter repeats in single input
|
|
||||||
for i, (s, x) in enumerate(zip(inputs, xs)):
|
|
||||||
for c in set(s):
|
|
||||||
while s.count(c) > 1:
|
|
||||||
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
|
||||||
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
|
||||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
|
||||||
s = s[:k] + s[k+1:]
|
|
||||||
inputs[i], xs[i] = s, x
|
|
||||||
# check sizes and build sorted alphabet
|
|
||||||
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
|
||||||
alpha = sorted(sz)
|
|
||||||
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
|
||||||
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
|
||||||
for s, x in zip(inputs, xs)]
|
|
||||||
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
|
||||||
|
|
||||||
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
|
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
|
||||||
"""
|
"""
|
||||||
Computes the gradient of the targets with respect to self.
|
Computes the gradient of the targets with respect to self.
|
||||||
|
|
@ -1178,68 +1147,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||||
# select from values for each True element in mask else select from self
|
# select from values for each True element in mask else select from self
|
||||||
return mask.where(values, self)
|
return mask.where(values, self)
|
||||||
|
|
||||||
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
|
|
||||||
"""
|
|
||||||
Selects elements from `self` based on the boolean `mask`.
|
|
||||||
|
|
||||||
With `size=None` (default), output length equals the number of `True` values (not jittable).
|
|
||||||
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
|
||||||
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
|
||||||
print(t.numpy())
|
|
||||||
print(mask.numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t.masked_select(mask).numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
|
||||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
|
||||||
mask_cumsum = mask.cumsum()
|
|
||||||
if size is None:
|
|
||||||
counts = type(self).zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
|
|
||||||
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
|
||||||
counts = type(self).zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
|
||||||
return (type(self).arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
|
||||||
|
|
||||||
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Self:
|
|
||||||
"""
|
|
||||||
Returns the indices of the elements that are non-zero.
|
|
||||||
|
|
||||||
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
|
|
||||||
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor([1, 0, 2, 0, 3])
|
|
||||||
print(t.numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t.nonzero().numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor([[1, 0], [0, 2]])
|
|
||||||
print(t.numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t.nonzero().numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(t.nonzero(size=3, fill_value=-1).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if self.ndim == 0:
|
|
||||||
return type(self).zeros(size if size is not None else int(self.ne(0).item()), 0, dtype=dtypes.int32, device=self.device)
|
|
||||||
mask = self.ne(0).flatten()
|
|
||||||
indices = type(self).stack(*[type(self).arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
|
|
||||||
for i, s in enumerate(self.shape)], dim=-1)
|
|
||||||
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
|
|
||||||
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
|
|
||||||
|
|
||||||
# ***** functional nn ops *****
|
# ***** functional nn ops *****
|
||||||
|
|
||||||
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
|
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
|
||||||
|
|
|
||||||
|
|
@ -1,101 +1,13 @@
|
||||||
from typing import TYPE_CHECKING, Callable, Self
|
from typing import Self
|
||||||
from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
|
from tinygrad.dtype import ConstType, DType
|
||||||
from tinygrad.helpers import argfix
|
|
||||||
from tinygrad.mixin.dtype import DTypeMixin
|
|
||||||
from tinygrad.mixin.movement import MovementMixin
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
class CreationMixin:
|
||||||
from tinygrad.uop.ops import sint, UOp
|
def const_like(self, b: ConstType) -> Self: raise NotImplementedError
|
||||||
|
def cast(self, dtype: DType) -> Self: raise NotImplementedError
|
||||||
|
|
||||||
class CreationMixin(DTypeMixin, MovementMixin):
|
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
|
||||||
@staticmethod
|
"""Creates a tensor with the same shape as `self`, filled with the given value."""
|
||||||
def const(dtype, b): raise NotImplementedError
|
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
|
||||||
|
|
||||||
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
|
|
||||||
|
|
||||||
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
|
|
||||||
from tinygrad.uop.ops import UOp
|
|
||||||
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
|
|
||||||
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
|
|
||||||
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
|
|
||||||
|
|
||||||
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
|
|
||||||
"""
|
|
||||||
Creates an empty tensor with the same shape as `self`.
|
|
||||||
If `dtype` is not specified, the dtype of `self` is used.
|
|
||||||
"""
|
|
||||||
return self._wrap_uop(self._uop.empty_like(dtype, device))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with Invalid.
|
|
||||||
|
|
||||||
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
|
|
||||||
|
|
||||||
Eventually Tensor.empty will be replaced by this.
|
|
||||||
"""
|
|
||||||
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
|
|
||||||
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with the given value.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.full((2, 3), 42).numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.full((2, 3), False).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
# TODO: enable this check
|
|
||||||
# if not buffer: assert device is None, "buffer=False does not support device specification"
|
|
||||||
from tinygrad.uop.ops import UOp
|
|
||||||
new_shape = argfix(shape)
|
|
||||||
dt = to_dtype(dtype) if dtype is not None else None
|
|
||||||
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
|
|
||||||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
|
||||||
return val.clone(device=device) if buffer else val
|
|
||||||
|
|
||||||
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the same shape as `self`, filled with the given value.
|
|
||||||
If `dtype` is not specified, the dtype of `self` is used.
|
|
||||||
|
|
||||||
You can pass in the `device` keyword argument to control device of the tensor.
|
|
||||||
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor.ones(2, 3)
|
|
||||||
print(Tensor.full_like(t, 42).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if isinstance(self.device, tuple):
|
|
||||||
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
|
||||||
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
|
|
||||||
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def zeros(cls, *shape, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with zeros.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.zeros(2, 3).numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return cls.full(argfix(*shape), 0.0, **kwargs)
|
|
||||||
|
|
||||||
def zeros_like(self, **kwargs) -> Self:
|
def zeros_like(self, **kwargs) -> Self:
|
||||||
"""
|
"""
|
||||||
|
|
@ -110,23 +22,6 @@ class CreationMixin(DTypeMixin, MovementMixin):
|
||||||
"""
|
"""
|
||||||
return self.full_like(0, **kwargs)
|
return self.full_like(0, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ones(cls, *shape, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with ones.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.ones(2, 3).numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return cls.full(argfix(*shape), 1.0, **kwargs)
|
|
||||||
|
|
||||||
def ones_like(self, **kwargs) -> Self:
|
def ones_like(self, **kwargs) -> Self:
|
||||||
"""
|
"""
|
||||||
Creates a tensor with the same shape as `self`, filled with ones.
|
Creates a tensor with the same shape as `self`, filled with ones.
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,13 @@
|
||||||
from typing import TYPE_CHECKING, Self
|
from typing import Self
|
||||||
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
|
from tinygrad.dtype import DType, dtypes
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tinygrad.uop.ops import UOp
|
|
||||||
|
|
||||||
class DTypeMixin:
|
class DTypeMixin:
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> DType: raise NotImplementedError
|
def dtype(self) -> DType: raise NotImplementedError
|
||||||
@property
|
|
||||||
def _uop(self) -> 'UOp': raise NotImplementedError
|
|
||||||
def _wrap_uop(self, u:'UOp') -> Self: raise NotImplementedError
|
|
||||||
|
|
||||||
def cast(self, dtype:DTypeLike) -> Self:
|
def cast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||||
"""
|
|
||||||
Casts `self` to the given `dtype`.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
def bitcast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||||
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
|
||||||
print(t.dtype, t.numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = t.cast(dtypes.int32)
|
|
||||||
print(t.dtype, t.numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = t.cast(dtypes.uint8)
|
|
||||||
print(t.dtype, t.numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return self if self.dtype == (dt:=to_dtype(dtype)) else self._wrap_uop(self._uop.cast(dt))
|
|
||||||
|
|
||||||
def bitcast(self, dtype:DTypeLike) -> Self: raise NotImplementedError
|
|
||||||
|
|
||||||
def element_size(self) -> int:
|
def element_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,23 @@ from typing import TYPE_CHECKING, Literal, Self
|
||||||
from tinygrad.uop import Ops
|
from tinygrad.uop import Ops
|
||||||
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
|
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
|
||||||
from tinygrad.helpers import argfix, polyN
|
from tinygrad.helpers import argfix, polyN
|
||||||
|
from tinygrad.mixin.dtype import DTypeMixin
|
||||||
from tinygrad.mixin.creation import CreationMixin
|
from tinygrad.mixin.creation import CreationMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseMixin(CreationMixin):
|
class ElementwiseMixin(DTypeMixin, CreationMixin):
|
||||||
# required to implement
|
# required to implement
|
||||||
def alu(self, op: Ops, *src: Self) -> Self:
|
def alu(self, op: Ops, *src: Self) -> Self:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _uop(self) -> 'UOp': raise NotImplementedError
|
||||||
|
|
||||||
|
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
|
||||||
|
|
||||||
# great functions you get!
|
# great functions you get!
|
||||||
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
|
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
|
||||||
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))
|
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))
|
||||||
|
|
@ -45,11 +51,7 @@ class ElementwiseMixin(CreationMixin):
|
||||||
"""
|
"""
|
||||||
return self.cast(dtypes.bool).ne(True)
|
return self.cast(dtypes.bool).ne(True)
|
||||||
|
|
||||||
def contiguous(self, **kwargs) -> Self:
|
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError
|
||||||
"""
|
|
||||||
Returns a contiguous tensor.
|
|
||||||
"""
|
|
||||||
return self._wrap_uop(self._uop.contiguous(**kwargs))
|
|
||||||
|
|
||||||
def contiguous_backward(self) -> Self:
|
def contiguous_backward(self) -> Self:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import math, dataclasses
|
import math, dataclasses, itertools
|
||||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
||||||
from tinygrad.helpers import argsort
|
from tinygrad.helpers import argsort
|
||||||
from tinygrad.dtype import sum_acc_dtype
|
from tinygrad.dtype import sum_acc_dtype
|
||||||
|
|
||||||
|
|
@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
||||||
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||||
grad_args = ctx.src
|
grad_args = ctx.src
|
||||||
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
|
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
|
||||||
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||||
grads = compute_gradient(fxn, root_grad, set(params.values()))
|
grads = compute_gradient(fxn, root_grad, set(params.values()))
|
||||||
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
|
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
|
||||||
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
|
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
|
||||||
|
|
@ -42,6 +42,9 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
||||||
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
|
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
|
||||||
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
|
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
|
||||||
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
|
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
|
||||||
|
# TODO: is this okay here?
|
||||||
|
from tinygrad.function import pm_transform_unique_const
|
||||||
|
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
|
||||||
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
||||||
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
|
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
|
||||||
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
|
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
from typing import Self
|
||||||
from typing import Self, cast
|
from tinygrad.dtype import DType, dtypes
|
||||||
from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype
|
from tinygrad.helpers import ceildiv, prod
|
||||||
from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING
|
|
||||||
from tinygrad.mixin import OpMixin
|
from tinygrad.mixin import OpMixin
|
||||||
from tinygrad.device import canonicalize_device
|
|
||||||
|
|
||||||
|
|
||||||
class RandMixin(OpMixin):
|
class RandMixin(OpMixin):
|
||||||
|
|
@ -41,286 +39,3 @@ class RandMixin(OpMixin):
|
||||||
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
|
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
|
||||||
out = cls._bits_to_rand(bits, shape, dtype)
|
out = cls._bits_to_rand(bits, shape, dtype)
|
||||||
return out.contiguous() if contiguous else out
|
return out.contiguous() if contiguous else out
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _next_counter(device:str, num:int):
|
|
||||||
raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
t = Tensor.rand(2, 3)
|
|
||||||
print(t.numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
dt = to_dtype(dtype or dtypes.default_float)
|
|
||||||
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
|
|
||||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
|
||||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
|
||||||
device = cast(str, canonicalize_device(device))
|
|
||||||
key, counter = cls._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
|
|
||||||
return cls._rand(key, counter, shape, dt, contiguous=contiguous)
|
|
||||||
|
|
||||||
def rand_like(self, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor.ones(2, 3)
|
|
||||||
print(Tensor.rand_like(t).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if isinstance(self.device, tuple):
|
|
||||||
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
|
||||||
dtype = kwargs.pop("dtype", self.dtype)
|
|
||||||
return self._multi_like(lambda shape, dev: type(self).rand(*shape, dtype=dtype, device=dev, **kwargs))
|
|
||||||
return type(self).rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
|
||||||
|
|
||||||
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor.ones(2, 3)
|
|
||||||
print(Tensor.randn_like(t).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
|
|
||||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
|
||||||
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(to_dtype(dtype or self.dtype))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
|
||||||
If `dtype` is not specified, the default type is used.
|
|
||||||
|
|
||||||
You can pass in the `device` keyword argument to control device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.randn(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return cls.empty(*shape, **kwargs).randn_like(dtype=dtype) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
|
||||||
Requires `low < high`. If `dtype` is not specified, the default type is used.
|
|
||||||
|
|
||||||
You can pass in the `device` keyword argument to control device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
|
||||||
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
|
||||||
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
|
|
||||||
return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
|
||||||
Requires `std >= 0`.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
|
|
||||||
return std * cls.randn(*shape, **kwargs) + mean
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
|
||||||
Requires `low < high`.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
|
||||||
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
|
|
||||||
return ((high-low) * cls.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def scaled_uniform(cls, *shape, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Creates a tensor with the given shape, filled with random values from a uniform distribution
|
|
||||||
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.scaled_uniform(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return cls.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def glorot_uniform(cls, *shape, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.glorot_uniform(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
|
|
||||||
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.kaiming_uniform(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
|
||||||
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
|
|
||||||
|
|
||||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
|
||||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.kaiming_normal(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
|
||||||
return cls.normal(*shape, mean=0.0, std=std, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
|
|
||||||
"""
|
|
||||||
Returns a tensor with a random permutation of integers from `0` to `n-1`.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
print(Tensor.randperm(6).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
return cls.rand(n, device=device, **kwargs).argsort().cast(dtype)
|
|
||||||
|
|
||||||
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
|
|
||||||
"""
|
|
||||||
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
t = Tensor([1, 2, 3, 4])
|
|
||||||
print(t.multinomial(20, replacement=True).numpy())
|
|
||||||
```
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
t = Tensor([1, 2, 3, 4])
|
|
||||||
print(t.multinomial(3, replacement=False).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
|
||||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
|
||||||
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
|
|
||||||
if replacement or num_samples == 1:
|
|
||||||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
|
||||||
unif_samples = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
|
|
||||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
|
||||||
else:
|
|
||||||
# Efraimidis-Spirakis
|
|
||||||
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
|
||||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
|
||||||
|
|
||||||
def dropout(self, p=0.5) -> Self:
|
|
||||||
"""
|
|
||||||
Applies dropout to `self`.
|
|
||||||
|
|
||||||
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
|
|
||||||
|
|
||||||
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
Tensor.manual_seed(42)
|
|
||||||
t = Tensor.randn(2, 2)
|
|
||||||
with Context(TRAINING=1):
|
|
||||||
print(t.dropout().numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
|
||||||
if not TRAINING or p == 0: return self
|
|
||||||
if p == 1: return self.const_like(0)
|
|
||||||
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
|
||||||
|
|
||||||
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
|
|
||||||
is_causal:bool=False, enable_gqa:bool=False) -> Self:
|
|
||||||
"""
|
|
||||||
Computes scaled dot-product attention.
|
|
||||||
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
|
||||||
|
|
||||||
- Paper: https://arxiv.org/abs/1706.03762v7
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
q = Tensor.randn(2, 4, 8)
|
|
||||||
k = Tensor.randn(2, 4, 8)
|
|
||||||
v = Tensor.randn(2, 4, 8)
|
|
||||||
print(q.scaled_dot_product_attention(k, v).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
|
||||||
if enable_gqa:
|
|
||||||
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
|
|
||||||
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
|
|
||||||
|
|
||||||
q = self
|
|
||||||
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
|
|
||||||
# handle attention mask
|
|
||||||
if is_causal:
|
|
||||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
|
||||||
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
|
|
||||||
if attn_mask is not None:
|
|
||||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
|
||||||
qk = qk + attn_mask
|
|
||||||
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Self, Sequence
|
import string
|
||||||
|
from typing import Self, Sequence, cast
|
||||||
from tinygrad.uop import Ops
|
from tinygrad.uop import Ops
|
||||||
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
|
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
|
||||||
from tinygrad.helpers import make_tuple
|
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
|
||||||
from tinygrad.mixin.dtype import DTypeMixin
|
from tinygrad.mixin.dtype import DTypeMixin
|
||||||
from tinygrad.mixin.movement import MovementMixin
|
from tinygrad.mixin.movement import MovementMixin
|
||||||
|
|
||||||
|
|
@ -135,3 +136,44 @@ class ReduceMixin(DTypeMixin, MovementMixin):
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return self.bool().prod(axis, keepdim)
|
return self.bool().prod(axis, keepdim)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
|
||||||
|
"""
|
||||||
|
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
||||||
|
|
||||||
|
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
x = Tensor([[1, 2], [3, 4]])
|
||||||
|
y = Tensor([[5, 6], [7, 8]])
|
||||||
|
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||||
|
# expand ellipsis to letters, determine output
|
||||||
|
if "..." in formula:
|
||||||
|
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||||
|
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||||
|
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||||
|
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||||
|
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||||
|
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||||
|
inputs = lhs.split(",")
|
||||||
|
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||||
|
# trace: take diagonal when letter repeats in single input
|
||||||
|
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||||
|
for c in set(s):
|
||||||
|
while s.count(c) > 1:
|
||||||
|
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||||
|
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||||
|
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||||
|
s = s[:k] + s[k+1:]
|
||||||
|
inputs[i], xs[i] = s, x
|
||||||
|
# check sizes and build sorted alphabet
|
||||||
|
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||||
|
alpha = sorted(sz)
|
||||||
|
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||||
|
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||||
|
for s, x in zip(inputs, xs)]
|
||||||
|
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ base_rewrite = PatternMatcher([
|
||||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
|
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
|
||||||
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
|
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
|
||||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
|
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
|
||||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
|
|
||||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
|
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
|
||||||
|
|
||||||
# GPU stuff
|
# GPU stuff
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import sys
|
||||||
sys.setrecursionlimit(10000)
|
sys.setrecursionlimit(10000)
|
||||||
|
|
||||||
def add_ranges_to_store(ctx, x):
|
def add_ranges_to_store(ctx, x):
|
||||||
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None
|
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None
|
||||||
assert x.src[0].shape == x.src[1].shape, "bad store shape"
|
assert x.src[0].shape == x.src[1].shape, "bad store shape"
|
||||||
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
|
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
|
||||||
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
|
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
|
||||||
|
|
@ -543,6 +543,9 @@ to_define_global = PatternMatcher([
|
||||||
# remove device from local BUFFERIZE
|
# remove device from local BUFFERIZE
|
||||||
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||||
|
|
||||||
|
# remove UNIQUE/DEVICE to dedup CONST
|
||||||
|
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||||
|
|
||||||
# renumber the ranges starting with 0 so that kernel deduping works
|
# renumber the ranges starting with 0 so that kernel deduping works
|
||||||
(UPat(Ops.RANGE, name="r"), renumber_range),
|
(UPat(Ops.RANGE, name="r"), renumber_range),
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
|
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
|
||||||
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
from contextlib import ContextDecorator
|
||||||
|
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||||
if TYPE_CHECKING: import numpy
|
if TYPE_CHECKING: import numpy
|
||||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype
|
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
|
||||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
|
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||||
from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING
|
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||||
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
||||||
from tinygrad.mixin.rand import RandMixin
|
from tinygrad.mixin.rand import RandMixin
|
||||||
from tinygrad.schedule import create_linear_with_vars
|
from tinygrad.schedule import create_linear_with_vars
|
||||||
|
|
@ -58,25 +59,19 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
||||||
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# TODO: deprecate this, always use TRAINING
|
class Tensor(RandMixin):
|
||||||
class TensorMeta(type):
|
|
||||||
@property
|
|
||||||
def training(cls) -> bool: return bool(TRAINING.value)
|
|
||||||
@training.setter
|
|
||||||
def training(cls, mode:bool): TRAINING.value = int(mode)
|
|
||||||
|
|
||||||
class Tensor(RandMixin, metaclass=TensorMeta):
|
|
||||||
"""
|
"""
|
||||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||||
|
|
||||||
```python exec="true" session="tensor"
|
```python exec="true" session="tensor"
|
||||||
from tinygrad import Tensor, dtypes, nn, Context
|
from tinygrad import Tensor, dtypes, nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
np.set_printoptions(precision=4)
|
np.set_printoptions(precision=4)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
__slots__ = "uop", "is_param", "grad"
|
__slots__ = "uop", "is_param", "grad"
|
||||||
|
training: ClassVar[bool] = False
|
||||||
|
|
||||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
|
device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
|
||||||
|
|
@ -130,9 +125,9 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
@suppress_finalizing
|
@suppress_finalizing
|
||||||
def __del__(self): all_tensors.pop(weakref.ref(self), None)
|
def __del__(self): all_tensors.pop(weakref.ref(self), None)
|
||||||
|
|
||||||
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
|
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor:
|
||||||
srcs = (self,)+x
|
srcs = (self,)+x
|
||||||
new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
|
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs)
|
||||||
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
|
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
|
||||||
# directly create the Tensor
|
# directly create the Tensor
|
||||||
ret = Tensor.__new__(Tensor)
|
ret = Tensor.__new__(Tensor)
|
||||||
|
|
@ -141,18 +136,34 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
all_tensors[weakref.ref(ret)] = None
|
all_tensors[weakref.ref(ret)] = None
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# alu, _uop, _wrap_uop and const are used by the mixins
|
# alu and const_like are used by the mixins
|
||||||
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
|
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
|
||||||
@property
|
@property
|
||||||
def _uop(self) -> UOp: return self.uop
|
def _uop(self) -> UOp: return self.uop
|
||||||
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
|
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
|
||||||
|
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
|
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
|
||||||
|
@staticmethod
|
||||||
|
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with Invalid.
|
||||||
|
|
||||||
|
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
|
||||||
|
|
||||||
|
Eventually Tensor.empty will be replaced by this.
|
||||||
|
"""
|
||||||
|
return Tensor(UOp.invalids(argfix(*shape), dtype, device))
|
||||||
|
|
||||||
def is_param_(self, is_param:bool=True) -> Tensor:
|
def is_param_(self, is_param:bool=True) -> Tensor:
|
||||||
self.is_param = is_param
|
self.is_param = is_param
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
class train(ContextDecorator):
|
||||||
|
def __init__(self, mode:bool = True): self.mode = mode
|
||||||
|
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
ld = self.uop
|
ld = self.uop
|
||||||
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
|
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]}>"
|
||||||
|
|
@ -264,7 +275,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
x = self.cast(self.dtype.base).contiguous()
|
x = self.cast(self.dtype.base).contiguous()
|
||||||
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
|
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
|
||||||
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
|
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
|
||||||
|
|
||||||
def _data(self) -> memoryview: return self._buffer().as_memoryview()
|
def _data(self) -> memoryview: return self._buffer().as_memoryview()
|
||||||
|
|
||||||
def data(self) -> memoryview:
|
def data(self) -> memoryview:
|
||||||
|
|
@ -280,7 +290,19 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||||
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
||||||
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
|
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
|
||||||
return self._data().cast(self.dtype.base.fmt, self.shape)
|
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
|
||||||
|
|
||||||
|
def item(self) -> PyConst:
|
||||||
|
"""
|
||||||
|
Returns the value of this tensor as a standard Python number.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor(42)
|
||||||
|
print(t.item())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
assert self.numel() == 1, "must have one element for item"
|
||||||
|
return self.data()[(0,) * len(self.shape)]
|
||||||
|
|
||||||
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
|
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
|
||||||
def tolist(self) -> PyConst|list[Any]:
|
def tolist(self) -> PyConst|list[Any]:
|
||||||
|
|
@ -442,6 +464,13 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
"""
|
"""
|
||||||
return Tensor(UOp.empty(argfix(*shape), dtype, device))
|
return Tensor(UOp.empty(argfix(*shape), dtype, device))
|
||||||
|
|
||||||
|
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates an empty tensor with the same shape as `self`.
|
||||||
|
If `dtype` is not specified, the dtype of `self` is used.
|
||||||
|
"""
|
||||||
|
return Tensor(self.uop.empty_like(dtype, device))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
@ -504,6 +533,261 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
|
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
|
||||||
return Tensor._device_seeds[device], low.cat(high)
|
return Tensor._device_seeds[device], low.cat(high)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
t = Tensor.rand(2, 3)
|
||||||
|
print(t.numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
dt = to_dtype(dtype or dtypes.default_float)
|
||||||
|
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
|
||||||
|
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||||
|
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||||
|
device = cast(str, canonicalize_device(device))
|
||||||
|
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
|
||||||
|
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
|
||||||
|
|
||||||
|
# ***** creation helper functions *****
|
||||||
|
|
||||||
|
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
|
||||||
|
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
|
||||||
|
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
|
||||||
|
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
|
||||||
|
return Tensor(stacked.multi(self.uop.axis))
|
||||||
|
|
||||||
|
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the same shape as `self`, filled with the given value.
|
||||||
|
If `dtype` is not specified, the dtype of `self` is used.
|
||||||
|
|
||||||
|
You can pass in the `device` keyword argument to control device of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor.ones(2, 3)
|
||||||
|
print(Tensor.full_like(t, 42).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if isinstance(self.device, tuple):
|
||||||
|
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||||
|
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
|
||||||
|
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
|
||||||
|
|
||||||
|
def rand_like(self, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor.ones(2, 3)
|
||||||
|
print(Tensor.rand_like(t).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if isinstance(self.device, tuple):
|
||||||
|
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||||
|
dtype = kwargs.pop("dtype", self.dtype)
|
||||||
|
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
|
||||||
|
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
||||||
|
|
||||||
|
# ***** random functions *****
|
||||||
|
|
||||||
|
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor.ones(2, 3)
|
||||||
|
print(Tensor.randn_like(t).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
|
||||||
|
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||||
|
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
||||||
|
If `dtype` is not specified, the default type is used.
|
||||||
|
|
||||||
|
You can pass in the `device` keyword argument to control device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.randn(2, 3).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
||||||
|
Requires `low < high`. If `dtype` is not specified, the default type is used.
|
||||||
|
|
||||||
|
You can pass in the `device` keyword argument to control device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
||||||
|
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
||||||
|
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
|
||||||
|
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
||||||
|
Requires `std >= 0`.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
|
||||||
|
return std * Tensor.randn(*shape, **kwargs) + mean
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
||||||
|
Requires `low < high`.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||||
|
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
|
||||||
|
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scaled_uniform(*shape, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Creates a tensor with the given shape, filled with random values from a uniform distribution
|
||||||
|
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.scaled_uniform(2, 3).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def glorot_uniform(*shape, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.glorot_uniform(2, 3).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
|
||||||
|
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.kaiming_uniform(2, 3).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||||
|
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
|
||||||
|
|
||||||
|
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||||
|
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.kaiming_normal(2, 3).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||||
|
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a tensor with a random permutation of integers from `0` to `n-1`.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
print(Tensor.randperm(6).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
|
||||||
|
|
||||||
|
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
t = Tensor([1, 2, 3, 4])
|
||||||
|
print(t.multinomial(20, replacement=True).numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
t = Tensor([1, 2, 3, 4])
|
||||||
|
print(t.multinomial(3, replacement=False).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
||||||
|
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||||
|
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
|
||||||
|
if replacement or num_samples == 1:
|
||||||
|
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||||
|
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
|
||||||
|
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||||
|
else:
|
||||||
|
# Efraimidis–Spirakis
|
||||||
|
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
||||||
|
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||||
|
|
||||||
# ***** toposort and backward pass *****
|
# ***** toposort and backward pass *****
|
||||||
|
|
||||||
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
||||||
|
|
@ -530,9 +814,49 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
|
|
||||||
# ***** movement ops *****
|
# ***** movement ops *****
|
||||||
|
|
||||||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
|
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
||||||
|
|
||||||
|
def __getitem__(self, indices) -> Tensor:
|
||||||
|
"""
|
||||||
|
Retrieves a sub-tensor using indexing.
|
||||||
|
|
||||||
|
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor.arange(12).reshape(3, 4)
|
||||||
|
print(t.numpy())
|
||||||
|
```
|
||||||
|
|
||||||
|
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t[1, 2].numpy())
|
||||||
|
```
|
||||||
|
|
||||||
|
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t[0:2, ::2].numpy())
|
||||||
|
```
|
||||||
|
|
||||||
|
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
||||||
|
```
|
||||||
|
|
||||||
|
- `None` Indexing: Add a new dimension to the tensor.
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t[:, None].shape)
|
||||||
|
```
|
||||||
|
|
||||||
|
NOTE: Out-of-bounds indexing results in a value of `0`.
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor([1, 2, 3])
|
||||||
|
print(t[Tensor([4, 3, 2])].numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return super().__getitem__(indices)
|
||||||
|
|
||||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||||
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
|
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
|
||||||
|
|
@ -563,6 +887,68 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
def __delitem__(self, indices) -> None:
|
def __delitem__(self, indices) -> None:
|
||||||
raise TypeError("Tensor does not support deleting items")
|
raise TypeError("Tensor does not support deleting items")
|
||||||
|
|
||||||
|
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
|
||||||
|
"""
|
||||||
|
Selects elements from `self` based on the boolean `mask`.
|
||||||
|
|
||||||
|
With `size=None` (default), output length equals the number of `True` values (not jittable).
|
||||||
|
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
||||||
|
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
||||||
|
print(t.numpy())
|
||||||
|
print(mask.numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t.masked_select(mask).numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
||||||
|
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||||
|
mask_cumsum = mask.cumsum()
|
||||||
|
if size is None:
|
||||||
|
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
|
||||||
|
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
||||||
|
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||||
|
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||||
|
|
||||||
|
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns the indices of the elements that are non-zero.
|
||||||
|
|
||||||
|
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
|
||||||
|
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor([1, 0, 2, 0, 3])
|
||||||
|
print(t.numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t.nonzero().numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor([[1, 0], [0, 2]])
|
||||||
|
print(t.numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t.nonzero().numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
print(t.nonzero(size=3, fill_value=-1).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if self.ndim == 0:
|
||||||
|
return Tensor.zeros(size if size is not None else int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
|
||||||
|
mask = (self != 0).flatten()
|
||||||
|
indices = Tensor.stack(*[Tensor.arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
|
||||||
|
for i, s in enumerate(self.shape)], dim=-1)
|
||||||
|
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
|
||||||
|
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
|
||||||
|
|
||||||
# ***** reduce ops *****
|
# ***** reduce ops *****
|
||||||
|
|
||||||
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
|
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
|
||||||
|
|
@ -668,11 +1054,11 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
|
|
||||||
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
||||||
|
|
||||||
# compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
|
# compute 6x6 winograd tiles: GgGt, BtdB
|
||||||
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
||||||
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
||||||
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
||||||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
|
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||||
|
|
||||||
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||||
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
||||||
|
|
@ -719,6 +1105,14 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
if IMAGE: return self.image_dot(w, dtype)
|
if IMAGE: return self.image_dot(w, dtype)
|
||||||
return super().dot(w, dtype)
|
return super().dot(w, dtype)
|
||||||
|
|
||||||
|
# ***** unary ops *****
|
||||||
|
|
||||||
|
def contiguous(self, *args, **kwargs) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns a contiguous tensor.
|
||||||
|
"""
|
||||||
|
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
|
||||||
|
|
||||||
# ***** broadcasted elementwise ops *****
|
# ***** broadcasted elementwise ops *****
|
||||||
|
|
||||||
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
||||||
|
|
@ -778,8 +1172,80 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
||||||
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
|
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
|
||||||
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
|
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
|
||||||
|
|
||||||
|
# ***** functional nn ops *****
|
||||||
|
|
||||||
|
def dropout(self, p=0.5) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies dropout to `self`.
|
||||||
|
|
||||||
|
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
||||||
|
|
||||||
|
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
Tensor.manual_seed(42)
|
||||||
|
t = Tensor.randn(2, 2)
|
||||||
|
with Tensor.train():
|
||||||
|
print(t.dropout().numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
||||||
|
if not Tensor.training or p == 0: return self
|
||||||
|
if p == 1: return self.const_like(0)
|
||||||
|
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
||||||
|
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
|
||||||
|
"""
|
||||||
|
Computes scaled dot-product attention.
|
||||||
|
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
||||||
|
|
||||||
|
- Paper: https://arxiv.org/abs/1706.03762v7
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
q = Tensor.randn(2, 4, 8)
|
||||||
|
k = Tensor.randn(2, 4, 8)
|
||||||
|
v = Tensor.randn(2, 4, 8)
|
||||||
|
print(q.scaled_dot_product_attention(k, v).numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
|
if enable_gqa:
|
||||||
|
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
|
||||||
|
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
|
||||||
|
|
||||||
|
q = self
|
||||||
|
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
|
||||||
|
# handle attention mask
|
||||||
|
if is_causal:
|
||||||
|
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||||
|
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||||
|
qk = qk + attn_mask
|
||||||
|
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
|
||||||
|
|
||||||
# ***** cast ops *****
|
# ***** cast ops *****
|
||||||
|
|
||||||
|
def cast(self, dtype:DTypeLike) -> Tensor:
|
||||||
|
"""
|
||||||
|
Casts `self` to the given `dtype`.
|
||||||
|
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
||||||
|
print(t.dtype, t.numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = t.cast(dtypes.int32)
|
||||||
|
print(t.dtype, t.numpy())
|
||||||
|
```
|
||||||
|
```python exec="true" source="above" session="tensor" result="python"
|
||||||
|
t = t.cast(dtypes.uint8)
|
||||||
|
print(t.dtype, t.numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return self if self.dtype == (dt:=to_dtype(dtype)) else self._apply_uop(UOp.cast, dtype=dt)
|
||||||
|
|
||||||
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Bitcasts `self` to the given `dtype` of the same itemsize.
|
Bitcasts `self` to the given `dtype` of the same itemsize.
|
||||||
|
|
|
||||||
|
|
@ -454,8 +454,7 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
|
||||||
|
|
||||||
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
|
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
|
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
||||||
# these are rewrites that make things simpler
|
|
||||||
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
|
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
|
||||||
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
|
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
|
||||||
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
|
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
|
||||||
|
|
@ -464,11 +463,6 @@ def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
|
||||||
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
||||||
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||||
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
||||||
return PatternMatcher(pat)
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
|
||||||
pat: list[tuple[UPat, Callable]] = []
|
|
||||||
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
||||||
lambda x,y: (x | y).logical_not())]
|
lambda x,y: (x | y).logical_not())]
|
||||||
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
||||||
|
|
|
||||||
|
|
@ -560,6 +560,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
|
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
|
||||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
|
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
|
||||||
|
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
|
||||||
|
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
|
||||||
|
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
|
||||||
|
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||||
|
@staticmethod
|
||||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
||||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -791,7 +797,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||||
return self.src[0].addrspace
|
return self.src[0].addrspace
|
||||||
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
||||||
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise:
|
if self.op in {Ops.STACK, Ops.PTRCAT, Ops.WMMA} or self.op in GroupOp.Elementwise:
|
||||||
ad = [x.addrspace for x in self.src if x.addrspace is not None]
|
ad = [x.addrspace for x in self.src if x.addrspace is not None]
|
||||||
if not len(ad) or not all_same(ad): return None
|
if not len(ad) or not all_same(ad): return None
|
||||||
return ad[0]
|
return ad[0]
|
||||||
|
|
@ -919,6 +925,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
|
|
||||||
# *** uop symbolic stuff ***
|
# *** uop symbolic stuff ***
|
||||||
|
|
||||||
|
def is_increasing(self:UOp) -> bool:
|
||||||
|
# is f a monotonically increasing function regards its input
|
||||||
|
if self.op in GroupOp.Irreducible: return True
|
||||||
|
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
|
||||||
|
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
|
||||||
|
return False # False if not sure
|
||||||
def const_factor(self) -> int:
|
def const_factor(self) -> int:
|
||||||
"""largest known int that divides self"""
|
"""largest known int that divides self"""
|
||||||
# TODO: for negatives it's not the largest
|
# TODO: for negatives it's not the largest
|
||||||
|
|
@ -1052,9 +1064,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
||||||
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
|
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
|
||||||
return ret
|
return ret
|
||||||
def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):
|
def placeholder_like(self, slot:int):
|
||||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||||
return UOp.placeholder(self.max_shard_shape, self.dtype, slot, addrspace)
|
return UOp.placeholder(self.max_shard_shape, self.dtype, slot)
|
||||||
|
|
||||||
# set is store+end+after
|
# set is store+end+after
|
||||||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||||
|
|
@ -1184,9 +1196,8 @@ python_alu: dict[Ops, Callable] = {
|
||||||
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
|
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
|
||||||
|
|
||||||
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
||||||
if any(isinstance(x, tuple) for x in operands):
|
if dtype.count > 1:
|
||||||
count = max(len(x) for x in operands if isinstance(x, tuple))
|
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
|
||||||
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(count)])
|
|
||||||
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
|
if dtype==dtypes.weakint and op in GroupOp.Binary and Invalid in operands: return Invalid
|
||||||
alu = python_alu[op](*operands)
|
alu = python_alu[op](*operands)
|
||||||
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
|
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
|
||||||
|
|
@ -1667,8 +1678,6 @@ pm_lower_index_dtype = PatternMatcher([
|
||||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||||
# remove hanging casts
|
# remove hanging casts
|
||||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("slen", dtypes.ints).cast(),), name="shrink"),
|
|
||||||
lambda shrink,buf,idx,slen: shrink.replace(src=(buf,idx,slen))),
|
|
||||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
||||||
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
|
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
|
||||||
# remove hanging casts for images
|
# remove hanging casts for images
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes, Invalid
|
||||||
from tinygrad.uop import Ops, GroupOp
|
from tinygrad.uop import Ops, GroupOp
|
||||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
|
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
|
||||||
from tinygrad.helpers import strip_parens
|
from tinygrad.helpers import strip_parens
|
||||||
|
|
@ -77,6 +77,8 @@ def render_marg(ctx,x:UOp):
|
||||||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
||||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
|
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
|
||||||
pm_pyrender_extra = PatternMatcher([
|
pm_pyrender_extra = PatternMatcher([
|
||||||
|
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
|
||||||
|
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
|
||||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||||
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
||||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,9 @@ spec_tensor = PatternMatcher([
|
||||||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||||
|
|
||||||
|
# CONST with a UNIQUE and DEVICE
|
||||||
|
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
|
||||||
|
|
||||||
# BUFFER
|
# BUFFER
|
||||||
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||||
|
|
@ -149,7 +152,8 @@ spec_tensor = PatternMatcher([
|
||||||
|
|
||||||
# movement ops
|
# movement ops
|
||||||
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
|
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
|
||||||
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"), lambda x: x.src[1].shape == x.src[2].shape),
|
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"),
|
||||||
|
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
|
||||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
|
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
|
||||||
|
|
||||||
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering
|
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering
|
||||||
|
|
|
||||||
|
|
@ -121,8 +121,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
# TODO: combine this with "# rules for threefry" below
|
# TODO: combine this with "# rules for threefry" below
|
||||||
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
|
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
|
||||||
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
|
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
|
||||||
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
|
|
||||||
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
|
|
||||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
||||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||||
# ** constant folding **
|
# ** constant folding **
|
||||||
|
|
@ -162,7 +160,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||||
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||||
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
|
||||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
||||||
# ** simple where folding **
|
# ** simple where folding **
|
||||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||||
|
|
@ -170,9 +167,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||||
# STACK on INDEX CONST (TODO: remove all the GEP crap)
|
|
||||||
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
|
|
||||||
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||||
for u in (toposort:=x.toposort()):
|
for u in (toposort:=x.toposort()):
|
||||||
# always exclude DEVICE/CONST/UNIQUE
|
# always exclude DEVICE/CONST/UNIQUE
|
||||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||||
|
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue