Merge remote-tracking branch 'upstream/master' into new_x86_backend

This commit is contained in:
ttomsa 2026-03-19 20:54:58 +00:00
commit 292e1745b2
60 changed files with 1800 additions and 471 deletions

View file

@ -590,10 +590,14 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: openpilot compile3 0.10.0 driving_policy
run: BENCHMARK_LOG=openpilot_0_10_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.10.0/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.0 dmonitoring
run: BENCHMARK_LOG=openpilot_0_10_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.10.0/selfdrive/modeld/models/dmonitoring_model.onnx
# Benchmarks driving_vision 0.11.0 with IMAGE=1; the step that follows covers IMAGE=2.
# The env var must match the step name and BENCHMARK_LOG, otherwise this run just repeats the IMAGE=2 one.
- name: IMAGE=1 openpilot compile3 0.11.0 driving_vision
run: BENCHMARK_LOG=image_1_openpilot_0_11_0_vision PYTHONPATH="." DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.11.0 driving_vision
run: BENCHMARK_LOG=openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.11.0 driving_policy
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.11.0 dmonitoring
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: DEBUG=2 IMAGE=1 openpilot compile3 0.10.1 driving_vision
@ -699,6 +703,14 @@ jobs:
- name: Run 10 MLPerf Bert training steps (1 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
# Smoke-test the remote path on this runner: kill any serve.py that is already running, start
# extra/remote/serve.py on port 6482 in the background, wait one second, then run test/test_tiny.py
# against it via REMOTE twice (the second time with AMD_AQL=1) and finally stop the server again.
# `|| true` keeps the step going when pkill finds no matching process.
- name: Remote
run: |
pkill -f 'extra/remote/serve.py' || true
PYTHONPATH=. python3 extra/remote/serve.py 6482 &
sleep 1
DEBUG=2 PYTHONPATH=. REMOTE=127.0.0.1:6482 AM_RESET=1 AMD=1 AMD_IFACE=PCI python3 test/test_tiny.py
DEBUG=2 PYTHONPATH=. REMOTE=127.0.0.1:6482 AM_RESET=1 AMD=1 AMD_AQL=1 AMD_IFACE=PCI python3 test/test_tiny.py
pkill -f 'extra/remote/serve.py' || true
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@ -755,5 +767,12 @@ jobs:
- name: Run 10 MLPerf Bert training steps (1 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
# Smoke-test the remote path on this runner: kill any serve.py that is already running, start
# extra/remote/serve.py on port 6483 in the background, wait one second, run test/test_tiny.py
# against it via REMOTE with NV=1, then stop the server again.
# `|| true` keeps the step going when pkill finds no matching process.
- name: Remote
run: |
pkill -f 'extra/remote/serve.py' || true
PYTHONPATH=. python3 extra/remote/serve.py 6483 &
sleep 1
DEBUG=2 PYTHONPATH=. REMOTE=127.0.0.1:6483 NV=1 python3 test/test_tiny.py
pkill -f 'extra/remote/serve.py' || true
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View file

@ -708,8 +708,7 @@ jobs:
fail-fast: false
matrix:
backend: [amd, amdllvm]
arch: [rdna3, rdna4]
#arch: [rdna3, rdna4, cdna4]
arch: [rdna3, rdna4, cdna4]
name: Linux (${{ matrix.backend }} ${{ matrix.arch }})
runs-on: ubuntu-22.04
@ -735,7 +734,7 @@ jobs:
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['AMD'], Device.DEFAULT"
DEBUG=5 FORWARD_ONLY=1 python3 test/test_tiny.py TestTiny.test_plus
- name: Run pytest (amd)
run: python -m pytest -n=auto test/backend/test_ops.py test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_linearizer.py test/backend/test_randomness.py test/backend/test_jit.py test/backend/test_graph.py test/backend/test_multitensor.py test/device/test_hcq.py test/external/external_test_am.py --durations=20
run: python -m pytest -n=auto test/backend/test_ops.py test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_linearizer.py test/backend/test_randomness.py test/backend/test_jit.py test/backend/test_graph.py test/backend/test_multitensor.py test/device/test_hcq.py test/external/external_test_am.py test/backend/test_asm_gemm.py::TestAsmGEMM --durations=20
- name: Run TRANSCENDENTAL math
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/backend/test_ops.py::TestOps::test_sin test/backend/test_ops.py::TestOps::test_cos test/backend/test_ops.py::TestOps::test_tan test/backend/test_ops.py::TestOps::test_exp test/backend/test_ops.py::TestOps::test_log --durations=20
- name: Run process replay tests

View file

@ -3,8 +3,18 @@ if __name__ == "__main__":
os.environ["DEFAULT_FLOAT"] = "bfloat16"
os.environ["OPTIM_DTYPE"] = "bfloat16"
os.environ["DEV"] = "NULL"
# CDNA
os.environ["EMULATE"] = "AMD_CDNA4"
os.environ["DEVICE_IN_FUNCTION_BUG"] = "1"
os.environ["ALL2ALL"] = "1"
os.environ["USE_ATOMICS"] = "1"
if "HK_FLASH_ATTENTION" not in os.environ:
os.environ["HK_FLASH_ATTENTION"] = "1"
if "ASM_GEMM" not in os.environ:
os.environ["ASM_GEMM"] = "1"
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters
from tinygrad.uop.ops import Ops, UOp
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
def rmsnorm(x_in:Tensor, eps:float):
@ -32,8 +42,8 @@ class FlatTransformer:
self.w3 = self.lin_per_layer(dim, hidden_dim)
self.norm_eps = norm_eps
self.attention_norm = Tensor.ones(n_layers, dim)
self.ffn_norm = Tensor.ones(n_layers, dim)
self.attention_norm = Tensor.ones(n_layers, dim).contiguous()
self.ffn_norm = Tensor.ones(n_layers, dim).contiguous()
# output
self.norm = nn.RMSNorm(dim, norm_eps)
@ -50,18 +60,14 @@ class FlatTransformer:
x = rmsnorm(x, self.norm_eps) * attention_norm
xqkv = x @ wqkv.T
# reshapes
xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim)
xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
bsz, seqlen, _ = xqkv.shape
# interleaved layout: each kv group has [n_rep q heads, 1 k head, 1 v head] for clean MP sharding
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
xq = xqkv[:, :, :, :self.n_rep].reshape(bsz, seqlen, self.n_heads, self.head_dim)
xk = xqkv[:, :, :, self.n_rep].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
bsz, seqlen, _, _ = xq.shape
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)
@ -80,6 +86,24 @@ class FlatTransformer:
h = x + self.attention(x, freqs_cis, attention_norm, wqkv, wo)
return h + self.feed_forward(h, ffn_norm, w1, w2, w3)
def shard(self, device:tuple[str, ...], mp:bool=False):
  """Shard the model's tensors in place across `device`.

  mp=False: every tensor returned by get_parameters(self) is sharded with axis=None.
  mp=True: each tensor listed below is sharded on its own axis and realized right after.
  """
  from tinygrad.nn.state import get_parameters
  if not mp:
    for param in get_parameters(self): param.shard_(device, axis=None)
    return
  # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
  layout = (
    (self.wqkv, 1),                   # (n_layers, out, dim) shard out
    (self.wo, 2),                     # (n_layers, dim, in) shard in
    (self.w1, 1),                     # (n_layers, hidden, dim) shard out
    (self.w2, 2),                     # (n_layers, dim, hidden) shard in
    (self.w3, 1),                     # (n_layers, hidden, dim) shard out
    (self.attention_norm, None),
    (self.ffn_norm, None),
    (self.norm.weight, None),
    (self.tok_embeddings.weight, 0),
    (self.output.weight, 0),
    (self.freqs_cis, None),
  )
  # same order as the table: shard, then realize, one tensor at a time
  for tensor, shard_axis in layout: tensor.shard_(device, axis=shard_axis).realize()
def __call__(self, tokens:Tensor):
h = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
@ -101,25 +125,51 @@ if __name__ == "__main__":
model = FlatTransformer(**model_params, max_context=SEQLEN)
state = nn.state.get_state_dict(model)
print("tensor count:", len(state))
# shard the model
from tinygrad import Device
if (DP := getenv("DP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
if (MP := getenv("MP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)
# preallocate all the grad buffers and zero them out
grads = {x:Tensor.zeros_like(x).contiguous() for x in state.values() if x.requires_grad is None}
# print model size
sz = 0
for k,v in state.items():
if v.requires_grad is None: v.requires_grad_(True)
print(f"{colored(k, 'green' if v.requires_grad else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device}")
print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
sz += v.nbytes()
print(f"total sz: {sz/1e9:.2f} GB")
with Timing("realize weights: "): Tensor.realize(*state.values())
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int).realize()
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
if MP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)))
# TODO: this shouldn't be needed, but it prevents a copy of the grads. CAT can help
# Turns a new gradient expression into a list of store ops that accumulate it into old_grad.
# The walk is recursive on new_grad.op:
#   ADD  -> handle both sources separately and concatenate the resulting store lists
#   PAD  -> shrink old_grad to the region computed below and recurse into the PAD's source
#   else -> a single store of (old_grad + new_grad) into old_grad
def apply_grad(old_grad:UOp, new_grad:UOp) -> list[UOp]:
if new_grad.op == Ops.ADD:
return apply_grad(old_grad, new_grad.src[0])+apply_grad(old_grad, new_grad.src[1])
elif new_grad.op == Ops.PAD:
# one (start, end) pair per axis: start is p[0], end is the source extent plus p[0]
# NOTE(review): presumably new_grad.marg holds per-axis (before, after) pad amounts -- confirm
grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(new_grad.src[0].shape, new_grad.marg)])
return apply_grad(old_grad.shrink(grad_shrink), new_grad.src[0])
else:
return [old_grad.store(old_grad + new_grad)]
@TinyJit
def jit_step(tokens:Tensor):
GlobalCounters.reset()
print(colored("*** step", "red"))
with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
with Timing("python backward: "): loss.backward()
with Timing("run step: "): loss.realize(*[x.grad for x in state.values() if x.requires_grad])
with Timing("python backward: "):
for t,g in zip(grads, loss.gradient(*grads)):
grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
with Timing("run step: "): loss.realize(*grads.values())
jit_step(tokens)
jit_step(tokens)
jit_step(tokens)
print(f"mem used: {GlobalCounters.mem_used/1e9:.2f} GB")
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))

View file

@ -0,0 +1,117 @@
import os
os.environ["WQKV"] = "1"
import unittest
import numpy as np
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters
from examples.mlperf.models.llama import Transformer
from examples.mlperf.models.flat_llama import FlatTransformer
def copy_weights(flat:FlatTransformer, ref:Transformer):
  """Copy the weights of the per-layer model `ref` into the flat model `flat`.

  ref's state dict is realized first. Each per-layer weight is read back with .numpy(), stacked with
  np.stack over the first flat.n_layers layers and assigned to the matching flat tensor. The norm,
  tok_embeddings and output weights are copied one-to-one.
  """
  Tensor.realize(*nn.state.get_state_dict(ref).values())
  blocks = [ref.layers[idx] for idx in range(flat.n_layers)]

  def stacked(pick) -> Tensor:
    # pick(block) returns the module whose .weight is collected from every layer
    return Tensor(np.stack([pick(blk).weight.numpy() for blk in blocks]))

  flat.wqkv.assign(stacked(lambda blk: blk.attention.wqkv))
  flat.wo.assign(stacked(lambda blk: blk.attention.wo))
  flat.w1.assign(stacked(lambda blk: blk.feed_forward.w1))
  flat.w2.assign(stacked(lambda blk: blk.feed_forward.w2))
  flat.w3.assign(stacked(lambda blk: blk.feed_forward.w3))
  flat.attention_norm.assign(stacked(lambda blk: blk.attention_norm))
  flat.ffn_norm.assign(stacked(lambda blk: blk.ffn_norm))
  for dst, src in ((flat.norm, ref.norm), (flat.tok_embeddings, ref.tok_embeddings), (flat.output, ref.output)):
    dst.weight.assign(Tensor(src.weight.numpy()))
class TestFlatLlama(unittest.TestCase):
  """Compares FlatTransformer against the per-layer Transformer after copying ref's weights into it."""
  # one small config shared by every test
  PARAMS = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)

  def _seeded_pair(self):
    # seed, build ref then flat with the same params, and mirror ref's weights into flat
    Tensor.manual_seed(42)
    ref, flat = Transformer(**self.PARAMS), FlatTransformer(**self.PARAMS)
    copy_weights(flat, ref)
    return ref, flat

  def test_forward_match(self):
    ref, flat = self._seeded_pair()
    Tensor.realize(*nn.state.get_state_dict(flat).values())
    tokens = Tensor([[1, 50, 100, 999, 2]])
    ref_logits = ref(tokens).realize()
    flat_logits = flat(tokens).realize()
    self.assertEqual(ref_logits.shape, flat_logits.shape)
    diff = (ref_logits - flat_logits).abs().max().item()
    self.assertLess(diff, 1e-5, f"forward mismatch: max abs diff {diff}")

  def test_backward_match(self):
    ref, flat = self._seeded_pair()
    for model in (ref, flat):
      for p in get_parameters(model): p.requires_grad_(True)
    Tensor.realize(*nn.state.get_state_dict(flat).values())
    tokens = Tensor([[1, 50, 100, 999, 2, 10]])

    def loss_and_grads(model):
      # next-token loss on the shifted sequence, then every gradient that exists, as numpy
      loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
      loss.backward()
      return loss, {k: v.grad.numpy() for k, v in nn.state.get_state_dict(model).items() if v.grad is not None}

    ref_loss, ref_grads = loss_and_grads(ref)
    flat_loss, flat_grads = loss_and_grads(flat)
    # check loss matches
    self.assertAlmostEqual(ref_loss.item(), flat_loss.item(), places=4)
    # check output weight grad matches
    diff = abs(ref_grads["output.weight"] - flat_grads["output.weight"]).max()
    self.assertLess(diff, 1e-4, f"output.weight grad mismatch: max abs diff {diff}")
    # check per-layer weight grads match: flat tensor name -> module path inside one ref layer
    per_layer = {"wqkv": "attention.wqkv", "wo": "attention.wo", "w1": "feed_forward.w1", "w2": "feed_forward.w2", "w3": "feed_forward.w3"}
    for i in range(self.PARAMS["n_layers"]):
      for flat_key, ref_module in per_layer.items():
        diff = abs(ref_grads[f"layers.{i}.{ref_module}.weight"] - flat_grads[flat_key][i]).max()
        self.assertLess(diff, 1e-4, f"layer {i} {flat_key} grad mismatch: max abs diff {diff}")

  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
  def test_forward_match_mp(self):
    from tinygrad import Device
    ref, flat = self._seeded_pair()
    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
    Tensor.realize(*nn.state.get_state_dict(flat).values())
    flat.shard(devices, mp=True)
    tokens = Tensor([[1, 50, 100, 999, 2]], device=devices[0])
    ref_logits = ref(tokens.to(devices[0])).numpy()
    flat_logits = flat(tokens.shard(devices)).numpy()
    self.assertEqual(ref_logits.shape, flat_logits.shape)
    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)

  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
  def test_forward_match_dp(self):
    from tinygrad import Device
    ref, flat = self._seeded_pair()
    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
    Tensor.realize(*nn.state.get_state_dict(flat).values())
    flat.shard(devices)
    tokens = Tensor([[1, 50, 100, 999, 2], [2, 100, 50, 1, 999]], device=devices[0])
    ref_logits = ref(tokens.to(devices[0])).numpy()
    flat_logits = flat(tokens.shard(devices, axis=0)).numpy()
    self.assertEqual(ref_logits.shape, flat_logits.shape)
    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
unittest.main()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.6 MiB

Before After
Before After

Binary file not shown.

Before

Width:  |  Height:  |  Size: 454 KiB

After

Width:  |  Height:  |  Size: 369 KiB

Before After
Before After

View file

@ -2702,13 +2702,14 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
if is_multi:
if n_sharded:
out = Tensor(Tensor.empty(batch, M, N//len(a.device), dtype=a.dtype, device=a.device).uop.multi(2), device=a.device)
out = Tensor(Tensor.invalid(batch, M, N//len(a.device), dtype=a.dtype, device=a.device).uop.multi(2), device=a.device)
elif m_sharded:
out = Tensor(Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device).uop.multi(1), device=a.device)
out = Tensor(Tensor.invalid(batch, M, N, dtype=a.dtype, device=a.device).uop.multi(1), device=a.device)
else:
out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
out = Tensor(Tensor.invalid(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0),
device=a.device)
else:
out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
out = Tensor.invalid(batch, M, N, dtype=a.dtype, device=a.device)
renderer = Device[a.device[0] if is_multi else a.device].renderer
dname, arch = renderer.device, getattr(renderer, "arch", "")

46
extra/remote/bench.py Normal file
View file

@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os, sys, time
from tinygrad.runtime.support.system import RemotePCIDevice
LAT_N_RUNS = 500               # number of timed PING round-trips
THROUGHPUT_N_RUNS = 8          # number of timed transfers per size
SIZES = [4, 1 << 10, 8 << 20]  # transfer sizes in bytes

def size_unit(sz:int) -> tuple[str, int]:
  """Return (suffix, divisor) used to print a byte count: B, K (1<<10), M (1<<20) or G (1<<30).

  The previous inline lookup indexed a three-entry table with a count that reaches 3 for sz >= 1<<30,
  which raised IndexError; the G entry covers that case.
  """
  for sfx, div in (('G', 1 << 30), ('M', 1 << 20), ('K', 1 << 10)):
    if sz >= div: return sfx, div
  return 'B', 1

if __name__ == "__main__":
  # the remote address comes from argv[1], else the REMOTE env var, else the default below
  os.environ["REMOTE"] = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("REMOTE", "127.0.0.1:6667")

  # choose any amd/nv gpu.
  devs = RemotePCIDevice.remote_list(0x1002, ((0, (0,)),), 0) or RemotePCIDevice.remote_list(0x10de, ((0, (0,)),), 0x03)
  if not devs: raise RuntimeError("no GPU found on remote")
  pci = RemotePCIDevice("BN", devs[0])
  print(f"connected to {os.environ['REMOTE']}, device: {devs[0]}\n")

  # ping (minimal server round-trip, no device I/O)
  from tinygrad.runtime.support.system import RemoteCmd
  sock = pci.sock
  for _ in range(10): RemotePCIDevice._rpc(sock, 0, RemoteCmd.PING)  # untimed runs before measuring
  st = time.perf_counter()
  for _ in range(LAT_N_RUNS): RemotePCIDevice._rpc(sock, 0, RemoteCmd.PING)
  ping_lat = (time.perf_counter() - st) / LAT_N_RUNS
  print(f"PING latency: {ping_lat*1e6:.1f} us ({1/ping_lat:,.0f} ops/sec)\n")

  # throughput: one sysmem allocation sized for the largest transfer is reused for every size
  sysmem, _ = pci.alloc_sysmem(max(SIZES))
  print(f"{'size':>10s} {'write MB/s':>10s} {'read MB/s':>10s}")
  for sz in SIZES:
    data = b'\x01' * sz

    for _ in range(5): sysmem[0:sz] = data  # untimed runs before measuring
    st = time.perf_counter()
    for _ in range(THROUGHPUT_N_RUNS): sysmem[0:sz] = data
    pci.read_config(0, 4) # flush, since writes are posted
    w = (time.perf_counter() - st) / THROUGHPUT_N_RUNS

    for _ in range(5): sysmem[0:sz]  # untimed runs before measuring
    st = time.perf_counter()
    for _ in range(THROUGHPUT_N_RUNS): sysmem[0:sz]
    r = (time.perf_counter() - st) / THROUGHPUT_N_RUNS

    sfx, div = size_unit(sz)
    print(f"{sz/div:>9.4g}{sfx} {sz/w/1e6:>10.1f} {sz/r/1e6:>10.1f}")

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import socket, struct, sys
from tinygrad.runtime.support.system import PCIDevice, RemoteCmd, System
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, OSX
def resp(resp0=0, resp1=0, status=0):
  """Pack a response header in '<BQQ' layout: the status byte first, then resp0 and resp1."""
  header = struct.Struct('<BQQ')
  return header.pack(status, resp0, resp1)
def resp_err(msg):
  """Build an error response: a '<BQQ' header (status 1, encoded message length, 0) followed by the encoded message."""
  payload = msg.encode()
  return struct.pack('<BQQ', 1, len(payload), 0) + payload
@ -12,6 +12,9 @@ mapped_bars: dict[tuple[int, int], object] = {}
sysmem_allocs: list[tuple] = []
def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
if cmd == RemoteCmd.PING:
return conn.sendall(resp())
if cmd == RemoteCmd.PROBE:
payload = conn.recv(arg1, socket.MSG_WAITALL) if arg1 > 0 else b""
filter_devices: dict[int, list[int]] = {}
@ -47,15 +50,19 @@ def handle(conn, cmd, dev_id, bar, arg0, arg1, arg2):
pci_dev.reset()
conn.sendall(resp())
elif cmd == RemoteCmd.MMIO_READ:
conn.sendmsg([resp(arg1), mapped_bars[(dev_id, bar)][arg0:arg0+arg1]])
bar_view = mapped_bars[(dev_id, bar)]
if arg0 % 4 == 0 and arg1 == 4: conn.sendmsg([resp(arg1), struct.pack(f'<{arg1 // 4}I', bar_view.view(arg0, arg1, fmt='I')[0])])
else: conn.sendmsg([resp(arg1), bar_view[arg0:arg0+arg1]])
elif cmd == RemoteCmd.MMIO_WRITE:
mapped_bars[(dev_id, bar)][arg0:arg0+arg1] = conn.recv(arg1, socket.MSG_WAITALL)
data = conn.recv(arg1, socket.MSG_WAITALL)
bar_view = mapped_bars[(dev_id, bar)]
if arg0 % 4 == 0 and arg1 == 4: bar_view.view(arg0, arg1, fmt='I')[0] = struct.unpack(f'<{arg1 // 4}I', data)[0]
else: bar_view[arg0:arg0+arg1] = data
elif cmd == RemoteCmd.MAP_SYSMEM:
memview, paddrs = pci_dev.alloc_sysmem(arg0)
hdl = len(sysmem_allocs)
memview, paddrs = pci_dev.alloc_sysmem(arg0, contiguous=bool(arg1))
sysmem_allocs.append((memview, paddrs))
paddrs_bytes = struct.pack(f'<{len(paddrs)}Q', *paddrs)
conn.sendall(resp(len(paddrs_bytes), hdl) + paddrs_bytes)
conn.sendall(resp(len(paddrs_bytes), len(sysmem_allocs) - 1) + paddrs_bytes)
elif cmd == RemoteCmd.SYSMEM_READ:
conn.sendmsg([resp(arg1), sysmem_allocs[bar][0][arg0:arg0+arg1]])
elif cmd == RemoteCmd.SYSMEM_WRITE:
@ -77,6 +84,8 @@ def serve(conn:socket.socket):
conn.sendall(resp_err(str(e)))
if __name__ == "__main__":
if not OSX: System.reserve_hugepages(128) # for sysmem allocations
port = int(sys.argv[1]) if len(sys.argv) > 1 else 6667
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

View file

@ -10,11 +10,11 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None, dtype:DTypeLike|None=None) -> Tensor:
dtype = dtype or ref.dtype
if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=dtype, device=ref.device)
if not isinstance(ref.device, tuple): return Tensor.invalid(*shape, dtype=dtype, device=ref.device)
shard_axis = ref.uop.axis if axis is None else axis
shape = tuple(s // len(ref.device) if i == shard_axis else s for i, s in enumerate(shape))
axis = ref.uop.axis if axis is None else axis
return Tensor(Tensor.empty(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
return Tensor(Tensor.invalid(*shape, dtype=dtype, device=ref.device).uop.multi(axis), dtype=dtype, device=ref.device)
# Convenience wrapper: calls _sharded_empty with ref's own shape, passing ref and axis through unchanged.
def _sharded_empty_like(ref:Tensor, axis:int|None=None) -> Tensor:
return _sharded_empty(ref.shape, ref, axis)

View file

@ -126,7 +126,7 @@ static int map_bar(uint32_t bar, response_t *resp) {
return 0;
}
static int map_sysmem_fd(uint64_t size, response_t *resp, int *out_fd) {
static int map_sysmem_fd(uint64_t size, int contiguous, response_t *resp, int *out_fd) {
if (g_sysmem_count >= MAX_SYSMEM) return -1;
int idx = g_sysmem_count;
int fd = -1;
@ -208,7 +208,7 @@ static void handle_client(int fd) {
case CMD_MAP_SYSMEM_FD: {
int shm_fd = -1;
resp.status = map_sysmem_fd(req.arg0, &resp, &shm_fd) ? 1 : 0;
resp.status = map_sysmem_fd(req.arg0, (int)req.arg1, &resp, &shm_fd) ? 1 : 0;
send_response(fd, &resp, shm_fd);
continue;
}

View file

@ -408,6 +408,23 @@ class TestVOP3P(unittest.TestCase):
self.assertEqual(lo, 0x0005, f"lo: expected 0x0005, got 0x{lo:04x}")
self.assertEqual(hi, 0x4003, f"hi: expected 0x4003, got 0x{hi:04x}")
def test_v_pk_add_u16_literal_constant(self):
  """V_PK_ADD_U16 with a literal constant (value > 64, requires VOP3P_LIT encoding).
  Regression test: VOP3P literal constants were not passed to rsrc_dyn, so literal src read as 0.
  """
  # both u16 halves start at 0x1C00 (f16 for 2^-8); adding the 0x2000 literal should give 0x3C00 (f16 1.0)
  packed_in, bias, want = 0x1C001C00, 0x2000, 0x3C00
  program = [
    s_mov_b32(s[0], packed_in),
    v_mov_b32_e32(v[0], s[0]),
    v_pk_add_u16(v[1], bias, v[0], opsel_hi=2, opsel_hi2=1),  # add the bias to both halves
  ]
  packed_out = run_program(program, n_lanes=1).vgpr[0][1]
  halves = {"lo": packed_out & 0xffff, "hi": (packed_out >> 16) & 0xffff}
  # lo is checked before hi
  for half, got in halves.items():
    self.assertEqual(got, want, f"{half}: expected 0x3C00, got 0x{got:04x}")
class TestWMMAF16(unittest.TestCase):
"""Tests for WMMA F16 output variant (V_WMMA_F16_16X16X16_F16).

View file

@ -92,7 +92,8 @@ class TestSQTTMapBase(unittest.TestCase):
if "ALT" not in e.name.display_name: execs += 1
elif "WAVE" in e.device:
# sopk/immediates don't get ALU/MEM EXEC
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL"}: insts += 1
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL",
"WAVEEND"}: insts += 1
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
self.assertEqual(execs, insts)

View file

@ -81,16 +81,37 @@ class TestGemm(unittest.TestCase):
@needs_second_gpu
def test_gemm_k_sharded_3d(self): verify_asm_gemm_k_sharded_3d(1, 64, 32, 2*64, gpus=2)
# uses the Asm GEMM on CDNA4 only for speed reasons
class TestGemmLarge(unittest.TestCase):
# uses the smallest size for the cdna assembly gemm
class TestAsmGEMM(unittest.TestCase):
  """Small-shape checks of the assembly GEMM; every test is skipped unless is_cdna4() is true."""
  def setUp(self):
    if is_cdna4(): return
    self.skipTest("assembly gemm is only for cdna4")

  def test_tiny(self): verify_asm_gemm(1, 256, 256, 64)

  def test_verify_with_numpy(self):
    import numpy as np
    M, N, K = 256, 256, 64
    rng = np.random.default_rng(0)

    def half_matrix(rows, cols):
      # float32 samples shifted by -0.5, then cast to half
      return (rng.random((rows, cols), dtype=np.float32) - 0.5).astype(np.half)

    # a is drawn before b so the generator is consumed in the same order as before
    a_np = half_matrix(M, K)
    b_np = half_matrix(K, N)
    expected = a_np @ b_np
    a, b = Tensor(a_np), Tensor(b_np)
    out = asm_gemm(a, b)
    out.realize()
    # no validation on the NULL device
    if a.device.startswith("NULL"): return None
    np.testing.assert_allclose(out.numpy(), expected, atol=2e-3, rtol=5e-2)
# test the Asm GEMM with Llama shapes, only run on the real machine for speed
class TestGemmLlama(unittest.TestCase):
def setUp(self):
if not is_cdna4() or getenv("MOCKGPU"):
self.skipTest("very slow on non mi350x")
@Context(ASM_GEMM=1)
def test_empty(self): (Tensor.empty(N:=getenv("N", 4096), N, dtype=dtypes.half)@Tensor.empty(N, N, dtype=dtypes.half)).realize()
def test_tiny(self): verify_asm_gemm(1, 256, 256, 64)
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096)

View file

@ -283,5 +283,30 @@ class TestCustomKernel(unittest.TestCase):
self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule")
self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}")
def test_anonymous_buffers_in_function(self):
  """Test that custom kernels with anonymous output buffers work inside @function."""
  from tinygrad import function

  a = Tensor.full((4, 4), 3.).contiguous()
  b = Tensor.full((4, 4), 2.).contiguous()
  Tensor.realize(a, b)

  def custom_add_with_tmp(o1:UOp, o2:UOp, A:UOp, B:UOp) -> UOp:
    # one range over the flattened buffers; o1 receives A+B and o2 receives A+B+2
    o1, o2, A, B = (buf.flatten() for buf in (o1, o2, A, B))
    i = UOp.range(o1.size, 0)
    stores = UOp.group(o1[i].store(A[i]+B[i]), o2[i].store(A[i]+B[i]+2))
    return stores.end(i).sink(arg=KernelInfo(name=f"add_with_tmp_{o1.size}")).simplify()

  @function(precompile=True)
  def run(x:Tensor, w:Tensor) -> Tensor:
    # both outputs are created with Tensor.invalid inside the function
    out = Tensor.invalid(*x.shape, dtype=x.dtype)
    tmp = Tensor.invalid(*x.shape, dtype=x.dtype)
    outputs = Tensor.custom_kernel(out, tmp, x, w, fxn=custom_add_with_tmp)
    out, tmp = outputs[:2]
    return out+tmp

  result = run(a, b).flatten().tolist()
  # out = 3+2, tmp = 3+2+2, and the function returns their sum
  expected = (3+2)*2+2
  assert all(x == expected for x in result), f"expected all {expected}, got {result}"
if __name__ == '__main__':
unittest.main()

View file

@ -110,7 +110,7 @@ class TestMultiTensor(unittest.TestCase):
def test_tensor_from_multi(self):
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
Y = Tensor(X.uop)
self.assertEqual(Y.device, Device.DEFAULT)
self.assertEqual(Y.device, devices_2)
np.testing.assert_equal(X.numpy(), Y.numpy())
with self.assertRaises(AssertionError):
@ -645,6 +645,7 @@ class TestMultiTensor(unittest.TestCase):
out = t0.flip(0) + 1
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)))
@unittest.skip("flaky")
def test_reshape_on_axis(self):
t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)
@ -1129,6 +1130,51 @@ class TestTensorOps(unittest.TestCase):
def test_bitcast(self):
helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiBufferView(unittest.TestCase):
  """Views taken from a tensor sharded on axis 1 must schedule with zero CompiledRunner items and match an unsharded reference."""
  @needs_second_gpu
  def setUp(self): pass

  @staticmethod
  def _ref_and_sharded(devices, *shape):
    # the same arange tensor built twice: unsharded, and sharded over `devices` on axis 1; both realized
    numel = 1
    for dim in shape: numel *= dim
    ref = Tensor.arange(numel).reshape(*shape).contiguous().realize()
    sharded = Tensor.arange(numel).reshape(*shape).contiguous().shard(devices, axis=1).realize()
    return ref, sharded

  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
    """Apply view_fn to both, verify zero compiled kernels and matching values."""
    b_ref = view_fn(a_ref)
    b_multi = view_fn(a_multi).contiguous()
    sched = b_multi.schedule()
    n_compiled = sum(isinstance(si.prg, CompiledRunner) for si in sched)
    self.assertEqual(n_compiled, 0, f"expected zero compiled kernels, got {n_compiled}")
    run_schedule(sched)
    np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())

  @unittest.skip("flaky on LLVM")
  def test_shrink_non_shard_axis(self):
    ref, a = self._ref_and_sharded(devices_2, 8, 4, 10)
    self._check(ref, a, lambda t: t[3])

  def test_shrink_2d(self):
    ref, a = self._ref_and_sharded(devices_2, 6, 4)
    self._check(ref, a, lambda t: t.shrink(((1, 4), None)))

  def test_reshape_then_shrink(self):
    ref, a = self._ref_and_sharded(devices_2, 8, 6)
    self._check(ref, a, lambda t: t.reshape(4, 2, 6)[1])

  def test_chained_shrink(self):
    ref, a = self._ref_and_sharded(devices_2, 10, 8)
    self._check(ref, a, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))

  def test_4_devices(self):
    ref, a = self._ref_and_sharded(devices_4, 8, 12)
    sched = a[5].contiguous().schedule()
    n_compiled = sum(isinstance(si.prg, CompiledRunner) for si in sched)
    self.assertEqual(n_compiled, 0)
    run_schedule(sched)
    # the view is rebuilt here for the value comparison rather than reusing the scheduled tensor
    np.testing.assert_equal(a[5].contiguous().numpy(), ref[5].numpy())
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):
@needs_second_gpu

View file

@ -62,7 +62,7 @@ def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(40, 4
return (numpy_func is None or (kstest(x1, y) >= alpha and kstest(x2, y) >= alpha)) and \
(torch_func is None or (kstest(x1, z) >= alpha and kstest(x2, z) >= alpha))
def normal_test(func, shape=(20, 23), alpha=0.05): return equal_distribution(func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha)
def normal_test(func, shape=(20, 45), alpha=0.05): return equal_distribution(func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha)
class TestRandomness(unittest.TestCase):
def test_rand(self):
@ -131,29 +131,32 @@ class TestRandomness(unittest.TestCase):
"""
key0 = 1337
key1 = int.from_bytes(hashlib.sha256(int(0).to_bytes(4)).digest(), "big") & 0xffffffff
values = jax.extend.random.threefry_2x32((np.uint32(key1), np.uint32(key0)), np.arange(20, dtype=np.uint32))
# derive new key for the counter offset (c_low=0, c_high=0 for first call)
new_key_values = jax.extend.random.threefry_2x32((np.uint32(key1), np.uint32(key0)), np.array([0, 0], dtype=np.uint32))
new_key = (np.uint32(new_key_values[0]), np.uint32(new_key_values[1]))
values = jax.extend.random.threefry_2x32(new_key, np.arange(20, dtype=np.uint32))
values = (values >> (32 - 23)) | np.array(1, dtype=np.float32).view(np.uint32)
values = values.view(np.float32) - 1
values = values.view(np.float32) - 1
print(f"[{', '.join(f'{v}' for v in values)}]")
"""
jr = np.array([0.9073467254638672, 0.8235964775085449, 0.6872662305831909, 0.9920015335083008, 0.4941047430038452,
0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)
jr = np.array([0.45735931396484375, 0.6311527490615845, 0.15571284294128418, 0.8149417638778687, 0.7862188816070557,
0.8008807897567749, 0.568588376045227, 0.9852620363235474, 0.42314577102661133, 0.9811755418777466,
0.38059568405151367, 0.09186363220214844, 0.9497315883636475, 0.5826880931854248, 0.3796330690383911,
0.5610522031784058, 0.16122901439666748, 0.3732343912124634, 0.9795231819152832, 0.3280656337738037], dtype=np.float32)
r = Tensor.rand(20).numpy()
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
# next 20, np.arange(20, 40, dtype=np.uint32)
jr = np.array([0.7444133758544922, 0.7713677883148193, 0.8233780860900879, 0.43871235847473145, 0.517757773399353,
0.6437174081802368, 0.967403769493103, 0.26167726516723633, 0.6825339794158936, 0.14966607093811035,
0.28920769691467285, 0.017063498497009277, 0.2627382278442383, 0.9525482654571533, 0.9351049661636353,
0.43904995918273926, 0.043945908546447754, 0.6616791486740112, 0.6667773723602295, 0.5228077173233032], dtype=np.float32)
# next 20 (c_low=20, c_high=0)
jr = np.array([0.09199333190917969, 0.9130761623382568, 0.7048608064651489, 0.22254979610443115, 0.0014830827713012695,
0.37023448944091797, 0.7790107727050781, 0.7484984397888184, 0.7524604797363281, 0.19875383377075195,
0.48537540435791016, 0.10002851486206055, 0.5369305610656738, 0.3294715881347656, 0.5246957540512085,
0.7659651041030884, 0.7949080467224121, 0.34988296031951904, 0.9798505306243896, 0.2599533796310425], dtype=np.float32)
r = Tensor.rand(20).numpy()
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
# next 10, np.arange(40, 50, dtype=np.uint32)
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
0.36863112449645996, 0.3102607727050781, 0.06608951091766357, 0.35329878330230713, 0.26518797874450684], dtype=np.float32)
# next 10 (c_low=40, c_high=0)
jr = np.array([0.3198714256286621, 0.7984923124313354, 0.320881724357605, 0.4716068506240845, 0.7323365211486816,
0.9663800001144409, 0.13873648643493652, 0.16062307357788086, 0.49300849437713623, 0.10077548027038574], dtype=np.float32)
r = Tensor.rand(10).numpy()
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
@ -324,7 +327,7 @@ class TestRandomness(unittest.TestCase):
lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
def test_kaiming_uniform(self):
for shape in [(32, 16, 3, 3), (20, 44), (3, 15, 35)]:
for shape in [(32, 16, 3, 3), (20, 44), (5, 15, 35)]:
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
def test_kaiming_normal(self):

View file

@ -292,6 +292,13 @@ class TestSetitem(unittest.TestCase):
np.testing.assert_allclose(a.numpy(), [4, 6, 8, 10])
np.testing.assert_allclose(b.numpy(), [0, 2, 4, 6])
def test_setitem_multiple_disjoint_on_invalid(self):
  # two non-overlapping slice writes into a realized Tensor.invalid buffer must both be readable afterwards
  buf = Tensor.invalid(10, dtype="int").realize()
  writes = [(slice(2, 5), 2), (slice(6, 7), 3)]
  for region, fill in writes: buf[region] = fill
  buf.realize()
  for region, fill in writes:
    self.assertListEqual(buf[region].tolist(), [fill] * (region.stop - region.start))
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):

View file

@ -256,7 +256,7 @@ class TestTinygrad(unittest.TestCase):
def test_randperm(self):
Tensor.manual_seed(0)
a = Tensor.randperm(10).realize()
np.testing.assert_equal(a.numpy(), [5, 2, 8, 1, 3, 7, 9, 6, 0, 4])
np.testing.assert_equal(a.numpy(), [8, 9, 4, 3, 6, 1, 7, 5, 2, 0])
b = Tensor.randperm(1000).realize()
np.testing.assert_equal(set(b.numpy()), set(range(1000)))
@ -493,6 +493,17 @@ class TestTinygrad(unittest.TestCase):
dev = a.to(Device.DEFAULT)
np.testing.assert_allclose(a.numpy(), dev.numpy())
def test_copy_from_numpy_dtype(self):
  """A float32 numpy array passed with dtype=bfloat16 must keep its values, both as-is and after an add."""
  data = np.array([1.0, 2, 3], dtype=np.float32)
  t = Tensor(data, dtype=dtypes.bfloat16)
  # TODO: fix dtype in tinygrad space
  # until then either the requested bfloat16 or the source float32 is accepted; assertIn (unlike a bare
  # assert wrapped in try/except AssertionError) is still checked when Python runs with -O
  self.assertIn(t.dtype, (dtypes.bfloat16, dtypes.float32))
  np.testing.assert_equal(t.tolist(), data)
  np.testing.assert_equal((t+1).tolist(), data+1)
# Regression test for https://github.com/tinygrad/tinygrad/issues/1751
def test_copy_from_numpy_unaligned(self):
# 2**15 is the minimum for repro

View file

@ -77,15 +77,15 @@ class TestUOps(unittest.TestCase):
def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
a = dtypes.as_const(a, dts[0])
a = dts[0].const(a)
self._equal(f([a], op, dts), fxn(a))
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
a = dtypes.as_const(a, dts[0])
b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
a = dts[0].const(a)
b = dts[1].const(abs(b) if no_b_neg else b)
self._equal(f([a,b], op, dts), fxn(a,b))
def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
@ -93,9 +93,9 @@ class TestUOps(unittest.TestCase):
for a in [-2.0, 0, 1]:
for b in [-3.0, 3.0]:
for c in [-4.0, 4.0]:
a = dtypes.as_const(a, dts[0])
b = dtypes.as_const(b, dts[1])
c = dtypes.as_const(c, dts[2])
a = dts[0].const(a)
b = dts[1].const(b)
c = dts[2].const(c)
self._equal(f([a,b,c], op, dts), fxn(a,b,c))
class TestFloatUOps(TestUOps):
@ -117,7 +117,7 @@ class TestFloatUOps(TestUOps):
def test_cmpne_nan(self): # NaN != x for any x (IEEE 754)
for a, b in [(math.nan, 1.0), (1.0, math.nan), (math.nan, math.nan)]:
self.assertTrue(_test_single_value(
[dtypes.as_const(a, dtypes.float32), dtypes.as_const(b, dtypes.float32)],
[dtypes.float32.const(a), dtypes.float32.const(b)],
Ops.CMPNE, (dtypes.float32, dtypes.float32)))
# MOD isn't tested on floats

View file

@ -176,6 +176,34 @@ class TestAMPageTable(unittest.TestCase):
mm0.map_range(helper_va(0x1000000), 2 << 20, paddrs=[(0x10000, 2 << 20)], aspace=AddrSpace.PHYS)
mm0.unmap_range(helper_va(0x1000000), 2 << 20)
def test_inspect_mode(self):
  """Walk the page table with inspect=True: all mapped pages are seen, and none stay valid after unmapping."""
  dev = self.d[0]
  mm0 = dev.mm
  span = 0x4000000

  # A few disjoint ranges inside the inspected region; each is backed by the same numeric physical offset.
  regions = [(0x10000, 0x3000), (0x20000, 0x2000), (0x1000000, 2 << 20)]
  for base, size in regions:
    mm0.map_range(helper_va(base), size, paddrs=[(base, size)], aspace=AddrSpace.PHYS)

  # Collect the physical address of every 4K page reachable through a valid entry.
  walker = PageTableTraverseContext(dev, mm0.root_page_table, helper_va(0x0), inspect=True)
  seen = set()
  for _off, pt, pte_idx, n_ptes, pte_covers in walker.next(span):
    for slot in range(pte_idx, pte_idx + n_ptes):
      entry = helper_read_entry_components(pt.entries[slot])
      if not entry['valid']: continue
      seen.update(entry['paddr'] + page for page in range(0, pte_covers, 0x1000))

  wanted = set()
  for base, size in regions: wanted.update(range(base, base + size, 0x1000))
  assert seen == wanted

  for base, size in regions: mm0.unmap_range(helper_va(base), size)

  # A second walk after unmapping must not come across any valid entry.
  walker = PageTableTraverseContext(dev, mm0.root_page_table, helper_va(0x0), inspect=True)
  for _off, pt, pte_idx, n_ptes, pte_covers in walker.next(span):
    for slot in range(pte_idx, pte_idx + n_ptes): assert not pt.valid(slot)
def test_frag_size(self):
mm0 = self.d[0].mm

View file

@ -3,7 +3,8 @@ import numpy as np
from tinygrad import dtypes, Tensor
from tinygrad.uop.ops import Ops
from tinygrad.device import is_dtype_supported
from tinygrad.nn.onnx import OnnxRunner, OnnxDataType
from typing import Any
from tinygrad.nn.onnx import OnnxRunner, OnnxPBParser, OnnxDataType
from hypothesis import given, strategies as st
# copied from test_const_folding.py
@ -136,5 +137,40 @@ class TestOnnxRunnerDtypes(unittest.TestCase):
from_disk=False)
self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, expected_dtype)
# from openpilot selfdrive/modeld/get_model_metadata.py
class MetadataOnnxPBParser(OnnxPBParser):
  """OnnxPBParser variant whose ModelProto parse keeps only the graph (field 7) and metadata_props (field 14)."""
  def _parse_ModelProto(self) -> dict:
    model: dict[str, Any] = {"graph": {"input": [], "output": []}, "metadata_props": []}
    for fid, wire_type in self._parse_message(self.reader.len):
      if fid == 7:
        model["graph"] = self._parse_GraphProto()
      elif fid == 14:
        model["metadata_props"].append(self._parse_StringStringEntryProto())
      else:
        # every other ModelProto field is skipped unparsed
        self.reader.skip_field(wire_type)
    return model
class TestOnnxMetadata(unittest.TestCase):
  def test_metadata_props(self):
    """metadata_props written by onnx must come back from the parser, in order, as dicts with "key" and "value"."""
    props = [("model_checkpoint", "v1.0"), ("output_slices", "dGVzdA==")]

    # smallest possible model: a single Identity node from 'input' to 'output'
    identity = onnx.helper.make_node('Identity', ['input'], ['output'])
    graph_in = onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, (1, 3))
    graph_out = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, (1, 3))
    graph = onnx.helper.make_graph(nodes=[identity], name='test', inputs=[graph_in], outputs=[graph_out])
    model = onnx.helper.make_model(graph)
    for k, v in props:
      model.metadata_props.append(onnx.StringStringEntryProto(key=k, value=v))

    with tempfile.TemporaryDirectory() as tmpdir:
      model_path = pathlib.Path(tmpdir) / "model.onnx"
      onnx.save(model, model_path)
      parsed = MetadataOnnxPBParser(model_path).parse()
      entries = parsed["metadata_props"]
      self.assertEqual(len(entries), 2)
      for entry, (k, v) in zip(entries, props):
        self.assertEqual(entry["key"], k)
        self.assertEqual(entry["value"], v)
if __name__ == '__main__':
unittest.main()

View file

@ -17,7 +17,7 @@ regCOMPUTE_USER_DATA_0 = 0x1be0 + amd_gpu.GC_BASE__INST0_SEG0
regCOMPUTE_NUM_THREAD_X = 0x1ba7 + amd_gpu.GC_BASE__INST0_SEG0
regGRBM_GFX_INDEX = 0x2200 + amd_gpu.GC_BASE__INST0_SEG1
regSQ_THREAD_TRACE_BUF0_BASE = 0x39e8 + amd_gpu.GC_BASE__INST0_SEG1
regSQ_THREAD_TRACE_BUF0_SIZE = {"rdna3": 0x39e9, "rdna4": 0x39e6}[MOCKGPU_ARCH] + amd_gpu.GC_BASE__INST0_SEG1
regSQ_THREAD_TRACE_BUF0_SIZE = {"rdna3": 0x39e9, "rdna4": 0x39e6, "cdna4": 0x39e9}[MOCKGPU_ARCH] + amd_gpu.GC_BASE__INST0_SEG1
regSQ_THREAD_TRACE_WPTR = 0x39ef + amd_gpu.GC_BASE__INST0_SEG1
regSQ_THREAD_TRACE_STATUS = 0x39f4 + amd_gpu.GC_BASE__INST0_SEG1
regCP_PERFMON_CNTL = 0x3808 + amd_gpu.GC_BASE__INST0_SEG1
@ -200,11 +200,13 @@ class PM4Executor(AMDQueue):
if st <= prg_addr < st+sz: prg_sz = sz - (prg_addr - st)
# Get scratch size from COMPUTE_TMPRING_SIZE register
# For gfx11: WAVESIZE = ceildiv(64 * size_per_thread, 256), so size_per_thread ≈ WAVESIZE * 256 / 64 = WAVESIZE * 4
# WAVESIZE = ceildiv(lanes * size_per_thread, mem_alignment_size)
# GFX11+: mem_alignment_size=256, so size_per_thread = WAVESIZE * 256 / 64 = WAVESIZE * 4
# GFX9: mem_alignment_size=1024, so size_per_thread = WAVESIZE * 1024 / 64 = WAVESIZE * 16
try: tmpring_size = self.gpu.regs[regCOMPUTE_TMPRING_SIZE]
except KeyError: tmpring_size = 0
wavesize = (tmpring_size >> 12) & 0x3FFF # WAVESIZE field is bits 12:25 for gfx11
scratch_size = wavesize * 4 # This gives the scratch size per thread (lane)
wavesize = (tmpring_size >> 12) & 0x3FFF # WAVESIZE field is bits 12:25
scratch_size = wavesize * (16 if self.gpu.arch == "cdna" else 4) # per-thread scratch size in bytes
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
# Pass valid memory ranges, rsrc2, scratch_size, arch, and user data registers to Python emulator
@ -402,7 +404,7 @@ p2p_links_count 5
cpu_core_id_base 0
simd_id_base 2147488032
max_waves_per_simd 16
lds_size_in_kb 128
lds_size_in_kb 160
gds_size_in_kb 0
num_gws 64
wave_front_size 64

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@
from typing import Any, Callable
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, UOp
from tinygrad.uop.decompositions import f2f
# Type alias for vars dict: stores UOps for variables and tuples for lambda definitions
VarVal = UOp | tuple[str, list[str], str]
@ -40,9 +41,13 @@ def _bitreverse(v: UOp, bits: int) -> UOp:
def _extract_bits(val: UOp, hi: int, lo: int) -> UOp:
dt = dtypes.uint64 if val.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
result = ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1)
# Downcast to uint32 when extracting <=32 bits from a 64-bit value, so .f32 bitcast works correctly
if dt == dtypes.uint64 and (hi - lo + 1) <= 32: result = result.cast(dtypes.uint32)
width = hi - lo + 1
# Cast to dt first to ensure shift operands have matching types
val_cast = val.cast(dt) if val.dtype != dt else val
result = ((val_cast >> _const(dt, lo)) if lo > 0 else val_cast) & _const(dt, (1 << width) - 1)
# Downcast to match extracted bit width so brace-concat { hi, lo } computes correct output dtype
target_dt = _BITS_DT.get(width) or (dtypes.uint32 if width <= 32 else dtypes.uint64 if width <= 64 else dt)
if result.dtype != target_dt: result = result.cast(target_dt)
return result
def _set_bit(old, pos, val):
@ -60,6 +65,48 @@ def _floor(x):
return ((x < _const(x.dtype, 0)) & x.ne(t)).where(t - _const(x.dtype, 1), t)
def _f16_extract(v): return (v & _u32(0xFFFF)).cast(dtypes.uint16).bitcast(dtypes.half) if v.dtype == dtypes.uint32 else v
# ═════ FP8 (E4M3) and BF8 (E5M2) conversion helpers ═════
# f32→fp8/bf8 uses f2f decomposition directly. fp8/bf8→f32 wraps f2f with subnormal handling
# (f2f flushes denormals to zero, but AMD V_CVT_F32_FP8/BF8 preserves subnormals).
def _small_float_to_f32(v: UOp, exp_bits: int, mant_bits: int, src_dt) -> UOp:
  """Decode the low byte of v as a 1+exp_bits+mant_bits float: subnormals by hand, everything else through f2f."""
  raw = (v.cast(dtypes.uint32) & _u32(0xFF)).cast(dtypes.uint8)
  wide = raw.cast(dtypes.uint32)
  sign = (wide >> _u32(7)) << _u32(31)
  exp = (wide >> _u32(mant_bits)) & _u32((1 << exp_bits) - 1)
  mant = wide & _u32((1 << mant_bits) - 1)
  is_sub = exp.eq(_u32(0)) & mant.ne(_u32(0))
  # subnormal value: (-1)^sign * 2^(1-bias) * (mant / 2^mant_bits) = (-1)^sign * mant * 2^(1-bias-mant_bits)
  bias = (1 << (exp_bits - 1)) - 1
  scale = 2.0 ** (1 - bias - mant_bits)
  sub_bits = (mant.cast(dtypes.float32) * _const(dtypes.float32, scale)).bitcast(dtypes.uint32) | sign
  normal = f2f(raw, src_dt, dtypes.float32)
  return is_sub.where(sub_bits.bitcast(dtypes.float32), normal)
def _fp8_to_f32(v: UOp) -> UOp:
  # E4M3: bias 7, so a subnormal is mant * 2^(-9)
  return _small_float_to_f32(v, 4, 3, dtypes.fp8e4m3)
def _bf8_to_f32(v: UOp) -> UOp:
  # E5M2: bias 15, so a subnormal is mant * 2^(-16)
  return _small_float_to_f32(v, 5, 2, dtypes.fp8e5m2)
def _f32_to_fp8(v: UOp) -> UOp:
  # reinterpret v as float32 unless it already is one, then hand its raw 32 bits to f2f for the E4M3 conversion
  as_f32 = v if v.dtype == dtypes.float32 else v.bitcast(dtypes.float32)
  return f2f(as_f32.bitcast(dtypes.uint32), dtypes.float32, dtypes.fp8e4m3)
def _f32_to_bf8(v: UOp) -> UOp:
  # reinterpret v as float32 unless it already is one, then hand its raw 32 bits to f2f for the E5M2 conversion
  as_f32 = v if v.dtype == dtypes.float32 else v.bitcast(dtypes.float32)
  return f2f(as_f32.bitcast(dtypes.uint32), dtypes.float32, dtypes.fp8e5m2)
def _f32_to_bf16(v: UOp) -> UOp:
  """Convert f32 to bf16 with round-to-nearest-even. BF16 is the upper 16 bits of F32 with rounding.

  v is reinterpreted as float32 if it has another dtype. The result is the 16-bit pattern as uint16.
  NaN inputs produce a quiet NaN with the input's sign and upper payload bits.
  """
  bits = (v.bitcast(dtypes.float32) if v.dtype != dtypes.float32 else v).bitcast(dtypes.uint32)
  # Round-to-nearest-even: add 0x7FFF plus the LSB of the kept part, so an exact tie only carries when that LSB is 1.
  round_bit = (bits >> _u32(16)) & _u32(1)  # bit 16 (LSB of kept part)
  rounding = _u32(0x7FFF) + round_bit
  rounded = bits + rounding
  # NaN must bypass the rounding add: the carry can turn 0x7F800001 into +inf or wrap 0xFFFFFFFF around to 0.
  # NOTE(review): the quiet bit is forced here; confirm the exact NaN payload against hardware output.
  is_nan = _u32(0x7F800000) < (bits & _u32(0x7FFFFFFF))
  quiet_nan = (bits >> _u32(16)) | _u32(0x0040)
  return is_nan.where(quiet_nan, rounded >> _u32(16)).cast(dtypes.uint16)
def _f32_to_bf16_sr(v: UOp, stoch: UOp) -> UOp:
  """Convert f32 to bf16 with stochastic rounding."""
  as_f32 = v if v.dtype == dtypes.float32 else v.bitcast(dtypes.float32)
  bits = as_f32.bitcast(dtypes.uint32)
  # only the low 16 bits of stoch are used; adding them carries into the kept half in proportion to the dropped bits
  noise = stoch & _u32(0xFFFF)
  return ((bits + noise) >> _u32(16)).cast(dtypes.uint16)
def _check_nan(v: UOp, quiet: bool) -> UOp:
if v.op == Ops.CAST and v.dtype == dtypes.float64: v = v.src[0]
bits, exp_m, mant_m, qb, _ = _float_info(v)
@ -119,7 +166,10 @@ def _abs(val: UOp) -> UOp:
bt, ft = {10: (dtypes.uint16, dtypes.half), 23: (dtypes.uint32, dtypes.float32), 52: (dtypes.uint64, dtypes.float64)}[shift]
return (val.bitcast(bt) & _const(bt, sign_mask)).bitcast(ft)
def _f_to_u(f, dt): return UOp(Ops.TRUNC, f.dtype, ((f < _const(f.dtype, 0.0)).where(_const(f.dtype, 0.0), f),)).cast(dt)
def _f_to_u(f, dt):
  """Convert f to dt: negatives become 0, the value goes through Ops.TRUNC, and results >= 2**bits(dt) become dt.max."""
  fdt = f.dtype
  zero = _const(fdt, 0.0)
  non_negative = (f < zero).where(zero, f)
  whole = UOp(Ops.TRUNC, fdt, (non_negative,))
  # saturate instead of casting a value that does not fit in dt
  too_big = whole >= _const(fdt, 2**(dt.itemsize*8))
  return too_big.where(_const(dt, dt.max), whole.cast(dt))
def _cvt_quiet(val: UOp) -> UOp:
bits, _, _, qb, _ = _float_info(val)
@ -289,6 +339,9 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32),
'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
'CalcScratchAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
# FP8/BF8/BF16 conversion functions
'fp8_to_f32': _fp8_to_f32, 'bf8_to_f32': _bf8_to_f32, 'f32_to_fp8': _f32_to_fp8, 'f32_to_bf8': _f32_to_bf8,
'f32_to_bf16': _f32_to_bf16, 'f32_to_bf16_SR': _f32_to_bf16_sr, 'f32_to_bf16_sr': _f32_to_bf16_sr,
}
for is_max, name in [(False, 'min'), (True, 'max')]:
for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
@ -313,7 +366,8 @@ for is_max, name in [(False, 'min'), (True, 'max')]:
DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64,
'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16,
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32}
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32,
'fp8': dtypes.uint8, 'bf8': dtypes.uint8, 'b3': dtypes.uint8, 'b2': dtypes.uint8}
_BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64}
_NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f')
def _strip_suffix(num: str) -> tuple[str, str]:
@ -425,7 +479,13 @@ class Parser:
case '+' | '-':
if op == '-' and left.op == Ops.CONST and right.op == Ops.CONST: return _const(left.dtype, left.arg - right.arg)
return (left + right) if op == '+' else (left - right)
case '*' | '/': return (left * right) if op == '*' else (left / right)
case '*' | '/':
# Integer promotion: promote 16-bit integers to 32-bit before multiply to avoid overflow
# (e.g. SOPP branch offset: SIMM16.i16 * 16'4 can exceed int16 range)
if op == '*' and left.dtype.itemsize == 2 and left.dtype in (dtypes.int16, dtypes.short, dtypes.uint16, dtypes.ushort):
pdt = dtypes.int if left.dtype in (dtypes.int16, dtypes.short) else dtypes.uint
left, right = left.cast(pdt), right.cast(pdt)
return (left * right) if op == '*' else (left / right)
case '**': return UOp(Ops.EXP2, left.dtype, (right.cast(left.dtype),)) if left.op == Ops.CONST and left.arg == 2.0 else left
_PREC = [('||',), ('&&',), ('|',), ('^',), ('&',), ('==', '!=', '<>'), ('>=', '<=', '>', '<'), ('>>', '<<'), ('+', '-'), ('*', '/'), ('**',)]
@ -502,7 +562,8 @@ class Parser:
self.eat('RBRACKET')
vgpr = self.vars.get('_vgpr')
if vgpr is None: return _u32(0)
return vgpr.index(_to_u32(reg) * _u32(32) + _to_u32(lane), ptr=True).load()
ws = self.vars.get('_wave_size', 32)
return vgpr.index(_to_u32(reg) * _u32(ws) + _to_u32(lane), ptr=True).load()
if self.try_eat('LPAREN'):
args = self._parse_args()
self.eat('RPAREN')
@ -514,8 +575,8 @@ class Parser:
if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32)
if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64)
if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64)
if name == 'WAVE32': return _const(dtypes.bool, True)
if name == 'WAVE64': return _const(dtypes.bool, False)
if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1)
if self.try_eat('LBRACE'):
idx = self.eat('NUM').val
@ -527,7 +588,8 @@ class Parser:
self.eat('RBRACKET')
vgpr = self.vars.get('_vgpr')
if vgpr is None: return _u32(0)
return vgpr.index(_to_u32(reg) * _u32(32) + _u32(int(idx)), ptr=True).load()
ws = self.vars.get('_wave_size', 32)
return vgpr.index(_to_u32(reg) * _u32(ws) + _u32(int(idx)), ptr=True).load()
elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
if elem is None:
# Extract bit idx from base variable (like var[idx])
@ -660,13 +722,14 @@ class Parser:
return None
def _sized_literal(self, bits: int) -> UOp:
if self.at('IDENT') and self.peek().val in ('U', 'I', 'F', 'B'):
if self.at('IDENT') and self.peek().val in ('U', 'I', 'F', 'B', 'BF'):
type_char = self.eat('IDENT').val
self.eat('LPAREN')
inner = self.parse()
self.eat('RPAREN')
dt = {('U',32): dtypes.uint32, ('U',64): dtypes.uint64, ('I',32): dtypes.int, ('I',64): dtypes.int64,
('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64,
('BF',16): dtypes.bfloat16,
('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
if type_char == 'F' and inner.dtype in (dtypes.uint32, dtypes.uint64, dtypes.ulong, dtypes.int, dtypes.int64):
if inner.dtype.itemsize != dt.itemsize: inner = inner.cast(dtypes.uint32 if dt.itemsize == 4 else dtypes.uint64)
@ -769,6 +832,18 @@ class Parser:
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
elif dt in (dtypes.uint16, dtypes.int16):
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
else:
# Handle unaligned 32-bit loads: combine two consecutive dwords and shift.
# To avoid OOB at buffer boundaries for aligned loads, clamp idx_hi to idx (safe).
# Use int64 for the WHERE to avoid 32-bit int overflow in C pointer arithmetic (addr can be >8GB).
byte_off = (addr & _const(adt, 3)).cast(dtypes.uint32)
is_unaligned = byte_off.ne(_u32(0))
idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64)
idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64)
safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native)
hi = mem.index(safe_idx_hi, *gate)
combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
return val
def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
@ -980,17 +1055,34 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
i += 1
continue
# VGPR assignment: VGPR[lane][reg] = value
# VGPR assignment: VGPR[lane][reg] = value or VGPR[lane][reg][hi:lo].type = { ... }
if first == 'vgpr' and toks[1].type == 'LBRACKET':
j, lane_toks = _match_bracket(toks, 1)
if j < len(toks) and toks[j].type == 'LBRACKET':
j, reg_toks = _match_bracket(toks, j)
# Check for bit-slice: VGPR[lane][reg][hi:lo].type = value (read-modify-write)
if j < len(toks) and toks[j].type == 'LBRACKET':
j, slice_toks = _match_bracket(toks, j)
slice_str = _tok_str(slice_toks)
hi_str, lo_str = slice_str.split(':')
hi_val, lo_val = int(eval(hi_str.strip())), int(eval(lo_str.strip()))
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln = parse_tokens(lane_toks, env, funcs)
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
ws = env.get('_wave_size', 32)
vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln)
if assigns is not None:
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, hi_val, lo_val)))
i += 1
continue
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln = parse_tokens(lane_toks, env, funcs)
rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
if assigns is not None:
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
ws = env.get('_wave_size', 32)
assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(ws) + _to_u32(ln), val)))
i += 1
continue
@ -1143,12 +1235,16 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
def is_const(c, v): return c.op == Ops.CONST and c.arg is v
cond = parse_cond(line, 'if')
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
branch_assigns: list[tuple[UOp, list]] = [] # (cond, assigns_list) for side-effect merging
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
else_side_effects: list = []
env_snap = dict(env)
static_true = is_const(cond, True) # track if any condition is statically true
i += 1
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not is_const(cond, False) else None)
if_side: list = [] if assigns is not None and not is_const(cond, False) else []
i, branch, ret = parse_block(lines, i, env, funcs, if_side if assigns is not None and not is_const(cond, False) else None)
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
if assigns is not None and not is_const(cond, False): branch_assigns.append((cond, if_side))
env.clear()
env.update(env_snap)
while i < len(lines):
@ -1159,16 +1255,21 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
c = parse_cond(lines[i], 'elsif')
take = not static_true and not is_const(c, False)
i += 1
i, branch, ret = parse_block(lines, i, env, funcs, assigns if take else None)
br_side: list = [] if assigns is not None and take else []
i, branch, ret = parse_block(lines, i, env, funcs, br_side if assigns is not None and take else None)
if take:
conditions.append((c, ret if ret is not None else branch))
if is_const(c, True): static_true = True
if assigns is not None: branch_assigns.append((c, br_side))
env.clear()
env.update(env_snap)
elif lf == 'else':
i += 1
i, branch, ret = parse_block(lines, i, env, funcs, assigns if not static_true else None)
if not static_true: else_branch = (ret, branch)
el_side: list = [] if assigns is not None and not static_true else []
i, branch, ret = parse_block(lines, i, env, funcs, el_side if assigns is not None and not static_true else None)
if not static_true:
else_branch = (ret, branch)
if assigns is not None: else_side_effects = el_side
env.clear()
env.update(env_snap)
elif lf == 'endif':
@ -1188,6 +1289,10 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
block_assigns.update(ba)
env.update(ba)
# For static true, forward side effects unconditionally
if assigns is not None:
for bc, bse in branch_assigns:
if is_const(bc, True): assigns.extend(bse)
else:
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
@ -1199,6 +1304,21 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = env[var] = res
# Merge side effects from branches with conditions
if assigns is not None:
def _cond_side_effect(cnd, dest, val):
if isinstance(val, tuple) and len(val) == 4: # VGPR bit-slice: (idx, rhs, hi, lo) -> add condition
return (dest, (val[0], val[1], val[2], val[3], cnd))
if isinstance(val, tuple) and len(val) == 2: # VGPR/MEM write: (addr, rhs) -> condition rhs
return (dest, (val[0], cnd.where(val[1], val[1])))
return (dest, val)
# Build combined condition: each branch fires when its cond is true AND no earlier cond was true
remaining = UOp.const(dtypes.bool, True)
for bc, bse in branch_assigns:
effective = remaining & bc if remaining.op != Ops.CONST else bc
for dest, val in bse: assigns.append(_cond_side_effect(effective, dest, val))
remaining = remaining & bc.logical_not() if remaining.op != Ops.CONST else bc.logical_not()
for dest, val in else_side_effects: assigns.append(_cond_side_effect(remaining, dest, val))
continue
# Regular assignment: var = value

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python
import unittest, os, subprocess
from unittest.mock import patch
from tinygrad import Tensor
from tinygrad.device import Device, Compiler, enumerate_devices_str
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context, WIN, CI, OSX
@ -76,6 +77,18 @@ class TestDevice(unittest.TestCase):
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
assert inst is Device["CPU"].compiler # cached
@unittest.skipIf(Device.DEFAULT != "CPU", "only run on CPU")
def test_compiler_autodetect_fallback(self):
  """With ClangJITRenderer.__init__ patched to raise, the CPU device's renderer must carry a CPULLVMCompiler."""
  from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
  # the LLVM compiler must itself be constructible, otherwise the test is skipped
  try:
    CPULLVMCompiler()
  except Exception as e:
    self.skipTest(f"skipping: LLVM not available: {e}")
  cpu = Device["CPU"]
  cpu.cached_pair.clear()
  broken_clang = patch("tinygrad.renderer.cstyle.ClangJITRenderer.__init__", side_effect=RuntimeError("broken"))
  with broken_clang:
    self.assertIsInstance(cpu.renderer.compiler, CPULLVMCompiler)
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()

View file

@ -1,9 +1,10 @@
import unittest, math
import z3
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops
from tinygrad.codegen.gpudims import get_grouped_dims, add_gpudims
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.uop.validate import uops_to_z3
from tinygrad.dtype import dtypes
from tinygrad.renderer import Renderer
from tinygrad.helpers import flatten, dedup
class TestGroupedDims(unittest.TestCase):
@ -93,6 +94,14 @@ class TestGroupedDims(unittest.TestCase):
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
def test_global_prod_max(self):
  """add_gpudims with a renderer that sets global_prod_max must emit more than one lidx and more than one gidx SPECIAL."""
  g, l = UOp.range(256, 0, AxisType.GLOBAL), UOp.range(256, 1, AxisType.LOCAL)
  out_param = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
  sink = out_param.index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
  class R(Renderer):
    global_max = (256, 256, 256)
    local_max = (128, 128, 128)
    global_prod_max = (128, 128, 128)
  specials = [u for u in add_gpudims(R(), sink).toposort() if u.op is Ops.SPECIAL]
  n_local = sum(1 for s in specials if "lidx" in s.arg)
  n_global = sum(1 for s in specials if "gidx" in s.arg)
  self.assertGreater(n_local, 1)
  self.assertGreater(n_global, 1)
def test_max_sizes_none(self):
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
self._check_grouped_dims("gidx", (100,), None, False, [100])

View file

@ -1189,6 +1189,61 @@ class TestBufferView(unittest.TestCase):
b = a.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
run_schedule(check_schedule(b, 0))
def test_shrink_non_shard_axis_is_buffer_view_multi(self):
  # indexing a non-shard axis of a realized sharded tensor should be BUFFER_VIEW on each device, not copy kernels
  # this is the flat_llama pattern: weight[layer_idx] where weight is (n_layers, out, dim) sharded on axis=1
  shard_devs = ("NULL:1", "NULL:2")
  weight = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(shard_devs, axis=1).realize()
  layer = weight[3].contiguous()
  # 0 is passed to check_schedule as the expected count (presumably kernels)
  run_schedule(check_schedule(layer, 0))
def test_shrink_2d_non_shard_axis_multi(self):
  # rows 1..3 of a (6, 4) tensor sharded on axis=1; shrink only touches the non-shard axis
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(shard_devs, axis=1).realize()
  rows = base.shrink(((1, 4), None)).contiguous()
  run_schedule(check_schedule(rows, 0))
def test_shrink_shard_axis_0_multi(self):
  # shrinking a middle dim is not contiguous per shard, so this needs copy kernels
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(shard_devs, axis=0).realize()
  middle = base.shrink((None, (2, 5), None)).contiguous()
  run_schedule(check_schedule(middle, 2))
def test_reshape_then_shrink_multi(self):
  # reshape (8, 6) -> (4, 2, 6) and index the new leading axis of a tensor sharded on axis=1
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(shard_devs, axis=1).realize()
  picked = base.reshape(4, 2, 6)[1].contiguous()
  run_schedule(check_schedule(picked, 0))
def test_permute_then_shrink_multi(self):
  # permute makes per-shard view non-contiguous, needs copy kernels
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(shard_devs, axis=1).realize()
  permuted = base.permute(1, 0, 2)
  window = permuted.shrink(((0, 6), (1, 3), None)).contiguous()
  run_schedule(check_schedule(window, 2))
def test_multi_buffer_view_4_devices(self):
  # same row-index pattern as the two-device tests, but sharded across four NULL devices
  shard_devs = tuple(f"NULL:{i}" for i in range(4))
  base = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(shard_devs, axis=1).realize()
  row = base[5].contiguous()
  run_schedule(check_schedule(row, 0))
def test_chained_shrink_multi(self):
  # two successive shrinks on the non-shard axis of a tensor sharded on axis=1
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(shard_devs, axis=1).realize()
  first = base.shrink(((2, 8), None))
  second = first.shrink(((1, 4), None)).contiguous()
  run_schedule(check_schedule(second, 0))
# negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
def test_expand_multi_not_buffer_view(self):
  # expand of the size-1 middle axis; check_schedule is given 2 instead of 0 here
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(4*2).reshape(4, 1, 2).contiguous().shard(shard_devs, axis=2).realize()
  expanded = base.expand(4, 3, 2).contiguous()
  run_schedule(check_schedule(expanded, 2))
def test_pad_multi_not_buffer_view(self):
  # pad one row before and after on the non-shard axis; check_schedule is given 2 instead of 0 here
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(shard_devs, axis=1).realize()
  padded = base.pad(((1, 1), (0, 0))).contiguous()
  run_schedule(check_schedule(padded, 2))
def test_flip_multi_not_buffer_view(self):
  # flip along the non-shard axis; check_schedule is given 2 instead of 0 here
  shard_devs = ("NULL:1", "NULL:2")
  base = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(shard_devs, axis=1).realize()
  flipped = base.flip(0).contiguous()
  run_schedule(check_schedule(flipped, 2))
class TestInvalidTensor(unittest.TestCase):
def test_full_invalid_is_zero_kernels(self):
from tinygrad.dtype import Invalid

View file

@ -162,5 +162,17 @@ class TestTensorUnique(unittest.TestCase):
Tensor.realize(b,c)
self.assertIs(b.uop.buffer, c.uop.buffer)
class TestRand(unittest.TestCase):
  def test_rand_large_tensor(self):
    # large tensor rand (num > uint32.max) should not crash in frontend
    Tensor.manual_seed(0)
    # schedule the same 2**17 x 2**17 rand three times in a row
    for _ in range(3):
      Tensor.rand(2**17, 2**17).schedule()
class TestTensorDevice(unittest.TestCase):
  def test_create_from_single_device_tuple(self):
    # a tensor created with a one-element device tuple is added to a tensor created without a device argument
    on_tuple = Tensor([1.0], device=(Device.DEFAULT,))
    on_default = Tensor([2.0])
    (on_tuple + on_default).realize()
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,6 @@
import unittest, struct
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp
# format types: https://docs.python.org/3/library/struct.html
@ -78,5 +79,9 @@ class TestTensorData(unittest.TestCase):
assert dat.shape == (2,2)
# NOTE: python can't deref float16
def test_data_uop_device(self):
  # a Tensor built from a UOp const must report the device string given to UOp.const
  dev_name = "DEVICE"
  const_uop = UOp.const(dtypes.float, 1.0, dev_name)
  wrapped = Tensor(const_uop)
  self.assertEqual(wrapped.device, "DEVICE")
if __name__ == '__main__':
unittest.main()

View file

@ -81,7 +81,11 @@ def add_gpudims(ctx:Renderer, s:UOp):
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
else:
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
local_idxs = get_grouped_dims("lidx", local_shape, ctx.local_max)
hw_local = [_dim_max(u.src[0]) for u in local_idxs if u.op is Ops.SPECIAL]
global_max = ctx.global_max if ctx.global_prod_max is None else \
tuple(min(gm, pm//l) for gm,pm,l in zip(ctx.global_max or ctx.global_prod_max, ctx.global_prod_max, hw_local+[1]*3))
idxs = get_grouped_dims("gidx", global_shape, global_max, reverse=True) + local_idxs
# apply to multiple ranges
subs = {}

View file

@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, CPU_X86, NV_PTX, CUDA_PTX, NV_NAK
from tinygrad.helpers import EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey
from tinygrad.helpers import EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey, size_to_str
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@ -212,7 +212,9 @@ class Allocator(Generic[DeviceType]):
# overridden in LRUAllocator
def alloc(self, size:int, options:BufferSpec|None=None):
assert size > 0, f"alloc size must be positive, getting {size}"
return self._alloc(size, options if options is not None else self.default_buffer_spec)
try: return self._alloc(size, options if options is not None else self.default_buffer_spec)
except (RuntimeError, MemoryError) as e: raise MemoryError(f"Allocation of {size_to_str(size)} failed on {self.dev.device}. "
f"Used: {size_to_str(GlobalCounters.mem_used_per_device[self.dev.device])}") from e
def free(self, opaque, size:int, options:BufferSpec|None=None):
self._free(opaque, options if options is not None else self.default_buffer_spec)
@ -304,7 +306,7 @@ class Compiled:
# remove disabled compilers
for en, rc in self.comp_sets.values():
if en is not None and en.value == 0 and rc in comps: comps.remove(rc)
if en is not None and en.value == 0 and en.key in os.environ and rc in comps: comps.remove(rc)
return select_first_inited(list(forced_comps) if len(forced_comps)>0 else comps, f"No compiler for {self.device} is available", self.cached_pair)

View file

@ -87,6 +87,15 @@ class DType(metaclass=DTypeMetaClass):
def max(self):
if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
return float("inf") if dtypes.is_float(self) else True
def const(self, val: tuple[ConstType, ...]|ConstType):
  """
  Coerce `val` to the Python constant representation for this dtype.

  A tuple is converted element-wise and must have exactly `self.count` entries.
  `InvalidType` values are returned unchanged.
  """
  if isinstance(val, tuple):
    assert len(val) == self.count, f"mismatch {val} {self}"
    return tuple(self.const(elem) for elem in val)
  if isinstance(val, InvalidType): return val
  # NOTE: float('nan') != float('nan'), so we canonicalize here
  if isinstance(val, float) and math.isnan(val): val = math.nan
  # wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
  if dtypes.is_float(self): return ConstFloat(float(val))
  if dtypes.is_bool(self): return bool(val)
  # int is the default
  return int(val)
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
@ -165,16 +174,6 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
if isinstance(val, InvalidType): return val
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(val, float) and math.isnan(val): val = math.nan
# int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
@staticmethod
def finfo(dtype:DType) -> tuple[int, int]:
"""(exponent, mantissa)"""
if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")

View file

@ -66,6 +66,13 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
def _make_buffer_view(src:UOp) -> UOp|None:
  """
  Build BUFFER_VIEW.reshape(src.shape) when the movement ops on src collapse to a contiguous range.

  Returns None when `src.contiguous_view_offset()` returns None.
  """
  view_offset = src.contiguous_view_offset()
  if view_offset is None: return None
  base_buf = src.base
  if base_buf.op is Ops.BUFFER_VIEW:
    # base is already a view: add its offset (arg[1]) and point at the buffer underneath it
    view_offset, base_buf = view_offset + base_buf.arg[1], base_buf.src[0]
  new_view = UOp(Ops.BUFFER_VIEW, src.dtype, (base_buf,), (src.size, view_offset))
  return new_view.reshape(src.shape)
def contiguous_mops_to_view(c:UOp, src:UOp):
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
buf = src.base
@ -76,18 +83,22 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
if not all_int(c.shape): return None
# check if view is supported
if not isinstance(c.device, str): return None
from tinygrad.device import Device
if not hasattr(Device[c.device].allocator, "_offset"): return None
if isinstance(c.device, str):
if not hasattr(Device[c.device].allocator, "_offset"): return None
elif not all(hasattr(Device[d].allocator, "_offset") for d in c.device): return None
# see if this can be a view
if (offset := src.contiguous_view_offset()) is None: return None
# merge BUFFER_VIEWs
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
# for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create BUFFER_VIEW on the resolved result
if not isinstance(c.device, str):
from tinygrad.schedule.multi import multi_pm
resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
if resolved.op is not Ops.MULTI: return None
if (view := _make_buffer_view(resolved.src[0])) is None: return None
return view.multi(resolved.arg).contiguous(tag=c.tag)
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
if (view := _make_buffer_view(src)) is None: return None
return view.contiguous(tag=c.tag)
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None

View file

@ -1,4 +1,4 @@
import functools
import functools, itertools
from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
@ -6,17 +6,21 @@ from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
def add_to_ctx(ctx, x:UOp):
ret = x.param_like(len(ctx))
ctx.append(x)
ret = x.param_like(len(ctx[0]))
ctx[0].append(x)
return ret
pm_transform_unique_const = PatternMatcher([
# transform unique consts to LUNIQUE
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
])
pm_ctx = PatternMatcher([
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
# strip UNIQUE from unique consts — they don't need buffer identity inside function bodies
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"), lambda ctx,x: x.replace(src=(x.src[1],))),
])
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
])+pm_transform_unique_const
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
@ -56,7 +60,7 @@ class _function(Generic[ReturnType]):
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, call_uops, bottom_up=True, name="get_implicit_inputs")
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
if not self.allow_implicit:
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
@ -74,9 +78,9 @@ class _function(Generic[ReturnType]):
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
precompile_backward=self.precompile_backward)
if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
else:
return cast(ReturnType, Tensor(fret.gettuple(0), device=fret.device))
return cast(ReturnType, Tensor(fret.gettuple(0)))
# overload signatures support both @function and @function(precompile=True) syntax
@overload

View file

@ -1,6 +1,6 @@
from typing import cast
import math, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
import math, dataclasses, itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
from tinygrad.helpers import argsort
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
@ -37,6 +37,9 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
# TODO: is this okay here?
from tinygrad.function import pm_transform_unique_const
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))

View file

@ -36,6 +36,7 @@ def colored(st, color:str|None, background=False): # replace the termcolor libra
return f"\u001b[{10*background+60*(color.upper() == color)+30+colors.index(color.lower())}m{st}\u001b[0m" if color is not None else st
def colorize_float(x: float):
  """Format `x` as a `7.2f` ratio with an `x` suffix, colored green below 0.75, red above 1.15, yellow otherwise."""
  if x < 0.75: shade = 'green'
  elif x > 1.15: shade = 'red'
  else: shade = 'yellow'
  return colored(f"{x:7.2f}x", shade)
def time_to_str(t:float, w=8) -> str:
  """Render a duration `t` in seconds as s (t > 10), ms (t > 0.01) or us, right-aligned to width `w` with two decimals."""
  for scale, suffix in ((1, "s "), (1e3, "ms")):
    if t > 10/scale: return f"{t * scale:{w}.2f}{suffix}"
  # nothing larger matched, fall back to microseconds
  return f"{t * 1e6:{w}.2f}us"
def size_to_str(s:int) -> str:
  """Render a byte count using the largest of GB/MB/KB (powers of 1024) that `s` reaches, else as plain bytes."""
  for unit_bytes, unit in ((1<<30, "GB"), (1<<20, "MB"), (1<<10, "KB")):
    if s >= unit_bytes: return f"{s / unit_bytes:.2f} {unit}"
  # below 1 KB
  return f"{s} B"
def ansistrip(s:str):
  """Remove every match of the escape pattern (ESC `[` followed by `K` or a lazy run ending in `m`) from `s`."""
  escape_pattern = re.compile('\x1b\\[(K|.*?m)')
  return escape_pattern.sub('', s)
def ansilen(s:str):
  """Length of `s` after `ansistrip` has been applied to it."""
  visible = ansistrip(s)
  return len(visible)
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)

View file

@ -25,6 +25,13 @@ class ElementwiseMixin(DTypeMixin):
return self.ne(True)
def neg(self) -> Self:
  """
  Negates the tensor element-wise.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
  ```
  """
  # bool dtype goes through logical_not, everything else is multiplied by -1
  if self.dtype.scalar() == dtypes.bool: return self.logical_not()
  return self * (-1)
def _check_dtype(self) -> None:
@ -141,6 +148,9 @@ class ElementwiseMixin(DTypeMixin):
def __neg__(self) -> Self:
return self.neg()
def __invert__(self) -> Self:
  """Operator form `~x`; delegates to `bitwise_not`."""
  inverted = self.bitwise_not()
  return inverted
def __add__(self, x: Self | ConstType) -> Self:
return self.add(x)

View file

@ -174,6 +174,9 @@ class MovementMixin:
def shrink_to(self, shape, *args) -> Self:
  """Shrink each axis to `(0, ns)` for the requested size `ns`; a `None` entry is passed through as `None`."""
  bounds = tuple((0, ns) if ns is not None else None for ns in argfix(shape, *args))
  return self.shrink(bounds)
def pad_to(self, shape, *args) -> Self:
  """Apply a PAD op that adds `ns - s` after each axis (and nothing before); a `None` entry adds no padding on that axis."""
  # strict=True: the requested shape must have one entry per existing axis
  padding = tuple((0, 0) if ns is None else (0, ns-s) for s,ns in zip(self.shape, argfix(shape, *args), strict=True))
  return self._mop(Ops.PAD, padding)
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)

View file

@ -153,19 +153,21 @@ class OnnxPBParser:
def _parse_ModelProto(self) -> dict:
"""Entry point for parsing the ONNX model."""
graph: dict|None = None
opset_imports: list[OpSetId] = []
obj: dict[str, Any] = {"opset_import": []}
for fid, wire_type in self._parse_message(self.reader.len):
match fid:
case 7: graph = self._parse_GraphProto()
case 8: opset_imports.append(self._parse_OperatorSetIdProto())
case 4: obj["domain"] = self.reader.read_string()
case 5: obj["model_version"] = self.reader.read_int64()
case 7: obj["graph"] = self._parse_GraphProto()
case 8: obj["opset_import"].append(self._parse_OperatorSetIdProto())
case _: self.reader.skip_field(wire_type)
assert graph is not None
# update opset version
versions = {opset.domain: opset.version for opset in opset_imports}
graph["node"] = [OnnxNode(n.op, OpSetId(n.opset_id.domain, versions.get(n.opset_id.domain, 1)), n.inputs, n.outputs, n.opts)
for n in graph["node"]]
return graph
opset_imports = {Domain.from_onnx(x.get('domain')):x.get('version', 1) for x in obj["opset_import"]}
for n in obj["graph"]["node"]:
n_ = n["parsed_node"]
n["parsed_node"] = OnnxNode(n_.op, OpSetId(n_.opset_id.domain, opset_imports.get(n_.opset_id.domain, 1)), n_.inputs, n_.outputs, n_.opts)
return obj
def _parse_GraphProto(self) -> dict:
obj: dict[str, Any] = {"node": [], "initializer": [], "input": [], "output": []}
@ -179,23 +181,26 @@ class OnnxPBParser:
case _: self.reader.skip_field(wire_type)
return obj
def _parse_NodeProto(self) -> OnnxNode:
inputs: list[str] = []
outputs: list[str] = []
attributes: list[tuple[str, Any]] = []
domain: str|None = None
op_type = ""
def _parse_NodeProto(self) -> dict:
obj: dict[str, Any] = {"input": [], "output": [], "attribute": [], "domain": None}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: inputs.append(self.reader.read_string())
case 2: outputs.append(self.reader.read_string())
case 4: op_type = self.reader.read_string()
case 5: attributes.append(self._parse_AttributeProto())
case 7: domain = self.reader.read_string()
case 1: obj["input"].append(self.reader.read_string())
case 2: obj["output"].append(self.reader.read_string())
case 3: obj["name"] = self.reader.read_string()
case 4: obj["op_type"] = self.reader.read_string()
case 5: obj["attribute"].append(self._parse_AttributeProto())
case 6: obj["doc_string"] = self.reader.read_string()
case 7: obj["domain"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return OnnxNode(op_type, OpSetId(Domain.from_onnx(domain), 1), tuple(inputs), tuple(outputs), dict(attributes))
def _parse_TensorProto(self) -> tuple[str, Tensor]:
# parse node
attributes = {attr_dict["name"]: attr_dict[AttributeType(attr_dict["type"]).to_field_name()] for attr_dict in obj["attribute"]}
opset_id = OpSetId(Domain.from_onnx(obj.get('domain')), 1) # default version, to be updated later in _parse_ModelProto
obj["parsed_node"] = OnnxNode(obj["op_type"], opset_id, tuple(obj["input"]), tuple(obj["output"]), attributes)
return obj
def _parse_TensorProto(self) -> dict:
obj: dict[str, Any] = {"dims": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
@ -215,16 +220,18 @@ class OnnxPBParser:
# load external data
if self.load_external_data and obj.get("data_location", 0) == 1:
if "external_data" not in obj: raise ValueError("no external_data")
ext = dict(obj["external_data"])
if "location" not in ext: raise ValueError("no location in external_data")
offset = int(ext.get("offset", "0"))
length = int(ext["length"]) if "length" in ext else None
location, length, offset = None, None, 0
for kv in obj["external_data"]:
if kv["key"] == "location": location = kv["value"]
elif kv["key"] == "offset": offset = int(kv["value"])
elif kv["key"] == "length": length = int(kv["value"])
if location is None: raise ValueError("no location in external_data")
if self.file_path is None:
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
self.file_path = pathlib.Path(self.tensor.device[5:])
else: raise ValueError("onnx external_data needs the origin file path, try passing onnx file path to onnx_load")
ext_path = self.file_path.parent.joinpath(ext["location"])
ext_path = self.file_path.parent.joinpath(location)
if not ext_path.exists(): raise FileNotFoundError(f"external location not exists: {ext_path}")
ext_tensor = Tensor(ext_path)
@ -234,20 +241,23 @@ class OnnxPBParser:
# parse tensor
to_dtype = dtype_fallback(true_dtype := OnnxDataType(obj['data_type']).to_dtype(), "buffer parse")
shape = tuple(obj['dims'])
data_fields = [f for f in ('float_data','int32_data','int64_data','double_data','uint64_data','raw_data') if f in obj]
data = obj[get_single_element(data_fields)]
name = obj.get("name", "")
if not isinstance(data, Tensor): return name, Tensor(data, dtype=to_dtype).reshape(shape)
assert data.dtype == dtypes.uint8, data
present_fields = [field for field in ['float_data', 'int32_data', 'int64_data', 'double_data', 'uint64_data', 'raw_data'] if field in obj]
assert len(present_fields) == 1, f"only 1 data field is allowed from {obj=}"
data = obj[present_fields[0]]
if not isinstance(data, Tensor):
obj["parsed_tensor"] = Tensor(data, dtype=to_dtype).reshape(shape)
return obj
assert isinstance(data, Tensor) and data.dtype == dtypes.uint8, data
data = data.bitcast(true_dtype).reshape(shape)
data = data.to(Device.DEFAULT) if true_dtype is to_dtype else data.to("cpu").cast(to_dtype).to(Device.DEFAULT)
# const folding
if shape == ():
if data.dtype == dtypes.float16 and sys.version_info < (3, 12): data = data.cast(dtypes.float32)
data = Tensor(data.item(), dtype=to_dtype).reshape(shape)
return name, data
obj["parsed_tensor"] = data
return obj
def _parse_AttributeProto(self) -> tuple[str, Any]:
def _parse_AttributeProto(self) -> dict:
obj: dict[str, Any] = {"floats": [], "ints": [], "strings": []}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
@ -255,7 +265,7 @@ class OnnxPBParser:
case 2: obj["f"] = self.reader.read_float()
case 3: obj["i"] = self.reader.read_int64()
case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
case 5: obj["t"] = self._parse_TensorProto()[1]
case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto())
case 7: obj["floats"].append(self.reader.read_float())
case 8: obj["ints"].append(self.reader.read_int64())
@ -263,22 +273,26 @@ class OnnxPBParser:
case 20: obj["type"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
obj["floats"], obj["ints"], obj["strings"] = tuple(obj["floats"]), tuple(obj["ints"]), tuple(obj["strings"])
return obj["name"], obj[AttributeType(obj["type"]).to_field_name()]
return obj
def _parse_ValueInfoProto(self) -> tuple[str, OnnxValue|None]:
name, type_obj = "", None
def _parse_ValueInfoProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: name = self.reader.read_string()
case 2: type_obj = self._parse_TypeProto()
case 1: obj["name"] = self.reader.read_string()
case 2: obj["type"] = self._parse_TypeProto()
case _: self.reader.skip_field(wire_type)
if type_obj is None: return name, None
# parse type
if "type" not in obj: return {**obj, "parsed_type": None}
type_obj = obj["type"]
if is_optional := "optional_type" in type_obj: type_obj = type_obj["optional_type"]["elem_type"]
if is_sequence := "sequence_type" in type_obj: type_obj = type_obj["sequence_type"]["elem_type"]
assert "tensor_type" in type_obj, type_obj
shape_dims = type_obj['tensor_type'].get('shape', {}).get('dim', [])
return name, OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
obj['parsed_type'] = OnnxValue(tuple(d.get('dim_param') or d.get('dim_value') for d in shape_dims),
OnnxDataType(type_obj['tensor_type']['elem_type']).to_dtype(), is_optional, is_sequence)
return obj
def _parse_TypeProto(self) -> dict:
obj: dict[str, Any] = {}
@ -324,24 +338,23 @@ class OnnxPBParser:
case _: self.reader.skip_field(wire_type)
return obj
def _parse_StringStringEntryProto(self) -> tuple[str, str]:
key, value = "", ""
def _parse_StringStringEntryProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: key = self.reader.read_string()
case 2: value = self.reader.read_string()
case 1: obj["key"] = self.reader.read_string()
case 2: obj["value"] = self.reader.read_string()
case _: self.reader.skip_field(wire_type)
return key, value
return obj
def _parse_OperatorSetIdProto(self) -> OpSetId:
domain: str|None = None
version = 1
def _parse_OperatorSetIdProto(self) -> dict:
obj: dict[str, Any] = {}
for fid, wire_type in self._parse_message(self._decode_end_pos()):
match fid:
case 1: domain = self.reader.read_string()
case 2: version = self.reader.read_int64()
case 1: obj["domain"] = self.reader.read_string()
case 2: obj["version"] = self.reader.read_int64()
case _: self.reader.skip_field(wire_type)
return OpSetId(Domain.from_onnx(domain), version)
return obj
# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
@ -367,15 +380,16 @@ class OnnxRunner:
model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor.
"""
def __init__(self, model_path: Tensor | str | pathlib.Path):
self._init_from_graph(OnnxPBParser(model_path, load_external_data=True).parse())
model = OnnxPBParser(model_path, load_external_data=True).parse()
self._init_from_graph(model["graph"])
def _init_from_graph(self, graph: dict, is_subgraph: bool = False):
self.is_training = any(n.opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
self.graph_name = graph["name"] if is_subgraph else ""
self.graph_values: dict[str, Any] = {"": None, **dict(graph["initializer"])}
self.graph_inputs = {name: typ for name, typ in graph["input"] if name not in self.graph_values}
self.graph_outputs = tuple(name for name, _ in graph["output"])
self.graph_nodes = tuple(graph["node"])
self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
self.graph_outputs = tuple(o["name"] for o in graph["output"])
self.graph_nodes = tuple(n["parsed_node"] for n in graph["node"])
# track names from initializers and Constant nodes for fast path optimizations
self.const_names: set[str] = set(self.graph_values.keys()) | {o for n in self.graph_nodes if n.op == "Constant" for o in n.outputs}

View file

@ -142,6 +142,7 @@ class Renderer:
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
global_prod_max: tuple[int, ...]|None = None
shared_max: int = 32768
tensor_cores: list[TensorCore] = []
pre_matcher: PatternMatcher|None = None

View file

@ -468,6 +468,7 @@ class AMDHIPRenderer(CStyleLanguage):
shared_max = 65536
# NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
global_max = (2147483647, 65535, 65535)
global_prod_max = (0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
@staticmethod
def get_tensor_cores(arch):

View file

@ -216,6 +216,7 @@ class AMDLLVMRenderer(LLVMRenderer):
has_local = True
shared_max = AMDHIPRenderer.shared_max
global_max = AMDHIPRenderer.global_max
global_prod_max = AMDHIPRenderer.global_prod_max
abi = "amdgpu_kernel"
code_for_op = {**LLVMRenderer.code_for_op, **{op: lambda: None for op in llvm_intrinsics}}
string_rewrite = PatternMatcher([

View file

@ -1015,11 +1015,11 @@ class AMDDevice(HCQCompiled):
gart = self.iface.alloc(0x100, uncached=True, cpu_access=True)
if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL:
aql_desc = hsa.amd_queue_t(queue_properties=hsa.AMD_QUEUE_PROPERTIES_IS_PTR64 | hsa.AMD_QUEUE_PROPERTIES_ENABLE_PROFILING,
self.aql_gart = gart
self.aql_desc = hsa.amd_queue_t(queue_properties=hsa.AMD_QUEUE_PROPERTIES_IS_PTR64 | hsa.AMD_QUEUE_PROPERTIES_ENABLE_PROFILING,
read_dispatch_id_field_base_byte_offset=getattr(hsa.amd_queue_t, 'read_dispatch_id').offset,
max_cu_id=(self.cu_cnt * self.xccs) - 1, max_wave_id=self.waves_per_cu - 1)
gart.cpu_view().view(fmt='B')[:ctypes.sizeof(aql_desc)] = bytes(aql_desc)
self.aql_desc = hsa.amd_queue_t.from_address(gart.cpu_view().addr)
self.aql_gart.cpu_view().view(fmt='B')[:ctypes.sizeof(self.aql_desc)] = bytes(self.aql_desc)
cwsr_buffer_size = round_up((ctx_save_restore_size + debug_memory_size) * self.xccs, mmap.PAGESIZE)
cwsr_buffer = self.iface.alloc(cwsr_buffer_size) if ctx_save_restore_size else None
@ -1067,6 +1067,7 @@ class AMDDevice(HCQCompiled):
int.from_bytes(rsrc1_t(BASE_ADDRESS_HI=hi32(self.scratch.va_addr), SWIZZLE_ENABLE=1), 'little'),
lo32(size_per_xcc), int.from_bytes(bytes(rsrc3_t(**rsrc)), 'little')]
self.aql_desc.compute_tmpring_size = self.tmpring_size
self.aql_gart.cpu_view()[:ctypes.sizeof(self.aql_desc)] = bytes(self.aql_desc)
def invalidate_caches(self):
self.hw_compute_queue_t().memory_barrier().signal(self.timeline_signal, self.next_timeline()).submit(self)

View file

@ -7,7 +7,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, H
from tinygrad.runtime.support.hcq import MMIOInterface, FileIOInterface, MOCKGPU, hcq_filter_visible_devices, hcq_profile
from tinygrad.uop.ops import sint
from tinygrad.device import Compiled, BufferSpec, CompilerSet
from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, to_mv, hi32, lo32, NV_CC, NV_PTX, NV_NAK, NV_NVCC, PROFILE
from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, hi32, lo32, NV_CC, NV_PTX, NV_NAK, NV_NVCC, PROFILE
from tinygrad.helpers import ContextVar, VIZ, ProfileEvent
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import CUDARenderer
@ -45,7 +45,7 @@ def nv_iowr(fd:FileIOInterface, nr, args, cmd=None):
class QMD:
fields: dict[str, dict[str, tuple[int, int]]] = {}
def __init__(self, dev:NVDevice, addr:int|None=None, **kwargs):
def __init__(self, dev:NVDevice, view:MMIOInterface|None=None, **kwargs):
self.ver, self.sz = (5, 0x60) if dev.iface.compute_class >= nv_gpu.BLACKWELL_COMPUTE_A else (3, 0x40)
# Init fields from module
@ -53,7 +53,7 @@ class QMD:
QMD.fields[pref] = {**{name[len(pref)+1:]: dt for name,dt in nv_gpu.__dict__.items() if name.startswith(pref) and isinstance(dt, tuple)},
**{name[len(pref)+1:]+f"_{i}": dt(i) for name,dt in nv_gpu.__dict__.items() for i in range(8) if name.startswith(pref) and callable(dt)}}
self.mv, self.pref = (memoryview(bytearray(self.sz * 4)) if addr is None else to_mv(addr, self.sz * 4)), pref
self.mv, self.pref = (memoryview(bytearray(self.sz * 4)) if view is None else view), pref
if kwargs: self.write(**kwargs)
def _rw_bits(self, hi:int, lo:int, value:int|None=None):
@ -140,7 +140,7 @@ class NVComputeQueue(NVCommandQueue):
qmd_buf.cpu_view().view(size=prg.qmd.mv.nbytes, fmt='B')[:] = prg.qmd.mv
assert qmd_buf.va_addr < (1 << 40), f"large qmd addr {qmd_buf.va_addr:x}"
qmd = QMD(dev=prg.dev, addr=qmd_buf.cpu_view().addr) # Save qmd for later update
qmd = QMD(dev=prg.dev, view=qmd_buf.cpu_view()) # Save qmd for later update
self.bind_sints_to_mem(*global_size, mem=qmd_buf.cpu_view(), fmt='I', offset=qmd.field_offset('cta_raster_width' if qmd.ver<4 else 'grid_width'))
self.bind_sints_to_mem(*(local_size[:2]), mem=qmd_buf.cpu_view(), fmt='H', offset=qmd.field_offset('cta_thread_dimension0'))
@ -825,7 +825,7 @@ class NVDevice(HCQCompiled[NVSignal]):
if params.bytesAvailable == 0: return None
start, end = self.pma_rptr, self.pma_rptr + params.bytesAvailable
pma_data = self.pma_buf.cpu_view()[start:min(end, self.pma_buf.size)] + self.pma_buf.cpu_view()[:max(0, end - self.pma_buf.size)]
pma_data = bytes(self.pma_buf.cpu_view()[start:min(end, self.pma_buf.size)]) + bytes(self.pma_buf.cpu_view()[:max(0, end - self.pma_buf.size)])
self.pma_rptr = end % self.pma_buf.size
self.iface.rm_control(self.profiler, nv_gpu.NVB0CC_CTRL_CMD_PMA_STREAM_UPDATE_GET_PUT,

View file

@ -112,7 +112,7 @@ class PythonProgram:
elif uop is Ops.VECTORIZE: values[i] = src_values
elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.CAST:
values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \

View file

@ -113,8 +113,8 @@ class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.a
class VirtMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; aspace:AddrSpace; uncached:bool=False; snooped:bool=False # noqa: E702
class PageTableTraverseContext:
def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, boot=False):
self.dev, self.vaddr, self.create_pts, self.free_pts, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, boot
def __init__(self, dev, pt, vaddr, create_pts=False, free_pts=False, inspect=False, boot=False):
self.dev, self.vaddr, self.create_pts, self.free_pts, self.inspect, self.boot = dev, vaddr - dev.mm.va_base, create_pts, free_pts, inspect, boot
self.pt_stack:list[tuple[Any, int, int]] = [(pt, self._pt_pte_idx(pt, self.vaddr), self._pt_pte_size(pt))]
def _pt_pte_cnt(self, lv): return self.dev.mm.pte_cnt[lv]
@ -151,13 +151,17 @@ class PageTableTraverseContext:
def next(self, size:int, paddr:int|None=None, off:int=0):
while size > 0:
pt, pte_idx, pte_covers = self.pt_stack[-1]
# create_pts goes down until the page covers the request.
# free_pts goes down to the table, it assumes all entries are valid on the range (and validates that)
# inspect just visits any valid ranges and yields them.
if self.create_pts:
assert paddr is not None, "paddr must be provided when allocating new page tables"
while pte_covers > size or not pt.supports_huge_page(paddr+off) or self.vaddr&(pte_covers-1) != 0: pt, pte_idx, pte_covers = self.level_down()
else:
while not pt.is_page(pte_idx): pt, pte_idx, pte_covers = self.level_down()
while not pt.is_page(pte_idx) and (self.free_pts or pt.valid(pte_idx)): pt, pte_idx, pte_covers = self.level_down()
entries = min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx)
entries = max(min(size // pte_covers, self._pt_pte_cnt(pt.lv) - pte_idx), 1 if self.inspect else 0)
assert entries > 0, f"Invalid entries {size=:#x}, {pte_covers=:#x}"
yield off, pt, pte_idx, entries, pte_covers
@ -197,11 +201,14 @@ class MemoryManager:
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, boot=boot, inspect=True)
for _, pt, pte_idx, pte_cnt, _ in ctx.next(size):
for pte_off in range(pte_cnt): assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
ctx = PageTableTraverseContext(self.dev, self.root_page_table, vaddr, create_pts=True, boot=boot)
for paddr, psize in paddrs:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize, paddr=paddr):
for pte_off in range(pte_cnt):
assert not pt.valid(pte_idx + pte_off), f"PTE already mapped: {pt.entry(pte_idx + pte_off):#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, aspace=aspace, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
@ -240,8 +247,8 @@ class MemoryManager:
# Move to a smaller size and try again.
nxt_range += 1
if nxt_range == len(self.palloc_ranges):
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise MemoryError(f"Failed to allocate memory. (total allocation size={size:#x}, current try={self.palloc_ranges[nxt_range-1]})")
for paddr, _ in paddrs: self.pfree(paddr)
raise MemoryError(f"Failed to allocate memory (OOM). Request size={size:#x} ({self.palloc_ranges[nxt_range-1]})")
continue
rem_size -= self.palloc_ranges[nxt_range][0]
@ -251,7 +258,7 @@ class MemoryManager:
assert self.va_allocator is not None, "must be set"
self.unmap_range(vm.va_addr, vm.size)
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
for paddr, _ in vm.paddrs: self.pfree(paddr)
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False, ptable=False) -> int:
assert self.dev.is_booting == boot, "During booting, only boot memory can be allocated"

View file

@ -2,8 +2,8 @@ from __future__ import annotations
import ctypes, time, array, struct, itertools, dataclasses
from typing import cast, Any
from tinygrad.runtime.autogen import nv, nv_570 as nv_gpu, pci
from tinygrad.helpers import to_mv, lo32, hi32, DEBUG, round_up, round_down, mv_address, fetch, wait_cond, ceildiv
from tinygrad.runtime.support.system import System
from tinygrad.helpers import lo32, hi32, DEBUG, round_up, round_down, fetch, wait_cond, ceildiv
from tinygrad.runtime.support.system import System, MMIOInterface
from tinygrad.runtime.support.elf import elf_loader
@dataclasses.dataclass(frozen=True)
@ -16,14 +16,17 @@ class NV_IP:
def fini_hw(self): pass # Finalize hw for this IP
class NVRpcQueue:
def __init__(self, gsp:NV_GSP, va:int, completion_q_va:int|None=None):
self.tx = nv.msgqTxHeader.from_address(va)
wait_cond(lambda: self.tx.entryOff, value=0x1000, msg="RPC queue not initialized")
def __init__(self, gsp:NV_GSP, view:MMIOInterface, completion_q_view:MMIOInterface|None=None):
self.tx_view = view.view(fmt='I')
wait_cond(lambda: self.tx_view[getattr(nv.msgqTxHeader, 'entryOff').offset // 4], value=0x1000, msg="RPC queue not initialized")
self.tx = nv.msgqTxHeader.from_buffer_copy(bytes(view[:ctypes.sizeof(nv.msgqTxHeader)]))
if completion_q_va is not None: self.rx = nv.msgqRxHeader.from_address(completion_q_va + nv.msgqTxHeader.from_address(completion_q_va).rxHdrOff)
if completion_q_view is not None:
comp_tx = nv.msgqTxHeader.from_buffer_copy(bytes(completion_q_view[:ctypes.sizeof(nv.msgqTxHeader)]))
self.rx_view = completion_q_view.view(comp_tx.rxHdrOff, fmt='I')
self.gsp, self.va, self.queue_va, self.seq = gsp, va, va + self.tx.entryOff, 0
self.queue_mv = to_mv(self.queue_va, self.tx.msgSize * self.tx.msgCount)
self.gsp, self.view, self.seq = gsp, view, 0
self.queue_mv = view.view(self.tx.entryOff, self.tx.msgSize * self.tx.msgCount)
def _checksum(self, data:bytes):
if (pad_len:=(-len(data)) % 8): data += b'\x00' * pad_len
@ -40,10 +43,11 @@ class NVRpcQueue:
phdr.checkSum = self._checksum(bytes(phdr) + msg)
msg = (bytes(phdr) + msg).ljust(phdr.elemCount * self.tx.msgSize, b'\x00')
off, first = self.tx.writePtr * self.tx.msgSize, min(len(msg), len(self.queue_mv) - self.tx.writePtr * self.tx.msgSize)
wp = self.tx_view[getattr(nv.msgqTxHeader, 'writePtr').offset // 4]
off, first = wp * self.tx.msgSize, min(len(msg), len(self.queue_mv) - wp * self.tx.msgSize)
self.queue_mv[off:off+first] = msg[:first]
if first < len(msg): self.queue_mv[:len(msg)-first] = msg[first:]
self.tx.writePtr = (self.tx.writePtr + phdr.elemCount) % self.tx.msgCount
self.tx_view[getattr(nv.msgqTxHeader, 'writePtr').offset // 4] = (wp + phdr.elemCount) % self.tx.msgCount
System.memory_barrier()
self.seq += 1
@ -56,20 +60,20 @@ class NVRpcQueue:
def read_resp(self):
System.memory_barrier()
while self.rx.readPtr != self.tx.writePtr:
off = self.rx.readPtr * self.tx.msgSize
hdr = nv.rpc_message_header_v.from_address(self.queue_va + off + 0x30)
msg = self.queue_mv[off + 0x50 : off + 0x50 + hdr.length]
while self.rx_view[0] != self.tx_view[getattr(nv.msgqTxHeader, 'writePtr').offset // 4]:
off = self.rx_view[0] * self.tx.msgSize
hdr = nv.rpc_message_header_v.from_buffer_copy(bytes(self.queue_mv[off + 0x30 : off + 0x30 + ctypes.sizeof(nv.rpc_message_header_v)]))
msg = bytes(self.queue_mv[off + 0x50 : off + 0x50 + hdr.length])
# Handling special functions
if hdr.function == nv.NV_VGPU_MSG_EVENT_GSP_RUN_CPU_SEQUENCER: self.gsp.run_cpu_seq(msg)
elif hdr.function == nv.NV_VGPU_MSG_EVENT_OS_ERROR_LOG:
print(f"nv {self.gsp.nvdev.devfmt}: GSP LOG: {msg[12:].tobytes().rstrip(bytes([0])).decode('utf-8')}")
print(f"nv {self.gsp.nvdev.devfmt}: GSP LOG: {msg[12:].rstrip(bytes([0])).decode('utf-8')}")
self.gsp.nvdev.is_err_state |= hdr.function in {nv.NV_VGPU_MSG_EVENT_OS_ERROR_LOG, nv.NV_VGPU_MSG_EVENT_MMU_FAULT_QUEUED}
# Update the read pointer
self.rx.readPtr = (self.rx.readPtr + round_up(hdr.length, self.tx.msgSize) // self.tx.msgSize) % self.tx.msgCount
self.rx_view[0] = (self.rx_view[0] + round_up(hdr.length, self.tx.msgSize) // self.tx.msgSize) % self.tx.msgCount
System.memory_barrier()
if DEBUG >= 3:
@ -79,7 +83,7 @@ class NVRpcQueue:
if hdr.rpc_result != 0: raise RuntimeError(f"RPC call {hdr.function} failed with result {hdr.rpc_result}")
yield hdr.function, msg
def wait_resp(self, cmd:int, timeout:int=10000) -> memoryview:
def wait_resp(self, cmd:int, timeout:int=10000) -> bytes:
start_time = int(time.perf_counter() * 1000)
while (int(time.perf_counter() * 1000) - start_time) < timeout:
if (msg:=next((message for func, message in self.read_resp() if func == cmd), None)) is not None: return msg
@ -289,7 +293,7 @@ class NV_FLCN_COT(NV_IP):
self.nvdev.include("src/nvidia/arch/nvalloc/common/inc/fsp/fsp_mctp_format.h")
self.nvdev.include("src/nvidia/arch/nvalloc/common/inc/fsp/fsp_emem_channels.h")
self.fmc_boot_args, self.fmc_boot_args_sysmem = self.nvdev._alloc_boot_struct(nv.GSP_FMC_BOOT_PARAMS())
self.fmc_boot_args_view, self.fmc_boot_args_sysmem = self.nvdev._alloc_boot_struct(nv.GSP_FMC_BOOT_PARAMS())
self.init_fmc_image()
def init_fmc_image(self):
@ -302,9 +306,10 @@ class NV_FLCN_COT(NV_IP):
def init_hw(self):
self.falcon = 0x00110000
self.fmc_boot_args.bootGspRmParams = nv.GSP_ACR_BOOT_GSP_RM_PARAMS(gspRmDescOffset=self.nvdev.gsp.wpr_meta_sysmem,
boot_args = nv.GSP_ACR_BOOT_GSP_RM_PARAMS(gspRmDescOffset=self.nvdev.gsp.wpr_meta_sysmem,
gspRmDescSize=ctypes.sizeof(nv.GspFwWprMeta), target=nv.GSP_DMA_TARGET_COHERENT_SYSTEM, bIsGspRmBoot=True)
self.fmc_boot_args.gspRmParams = nv.GSP_RM_PARAMS(bootArgsOffset=self.nvdev.gsp.libos_args_sysmem[0], target=nv.GSP_DMA_TARGET_COHERENT_SYSTEM)
rm_args = nv.GSP_RM_PARAMS(bootArgsOffset=self.nvdev.gsp.libos_args_sysmem[0], target=nv.GSP_DMA_TARGET_COHERENT_SYSTEM)
self.fmc_boot_args_view[:ctypes.sizeof(nv.GSP_FMC_BOOT_PARAMS)] = bytes(nv.GSP_FMC_BOOT_PARAMS(bootGspRmParams=boot_args, gspRmParams=rm_args))
cot_payload = nv.NVDM_PAYLOAD_COT(version=0x2, size=ctypes.sizeof(nv.NVDM_PAYLOAD_COT), frtsVidmemOffset=0x1c00000, frtsVidmemSize=0x100000,
gspBootArgsSysmemOffset=self.fmc_boot_args_sysmem, gspFmcSysmemOffset=self.fmc_booter_sysmem[0])
@ -365,25 +370,24 @@ class NV_GSP(NV_IP):
_, self.rm_args_sysmem = self.nvdev._alloc_boot_struct(nv.GSP_ARGUMENTS_CACHED(bDmemStack=True, messageQueueInitArguments=queue_args))
# Build command queue header
self.cmd_q_va, self.stat_q_va = queues_view.addr + pt_size, queues_view.addr + pt_size + queue_size
# self.cmd_q_va, self.stat_q_va = queues_view.addr + pt_size, queues_view.addr + pt_size + queue_size
self.cmd_q_view, self.stat_q_view = queues_view.view(pt_size), queues_view.view(pt_size + queue_size)
cmd_q_tx = nv.msgqTxHeader(version=0, size=queue_size, entryOff=0x1000, msgSize=0x1000, msgCount=(queue_size - 0x1000) // 0x1000,
writePtr=0, flags=1, rxHdrOff=ctypes.sizeof(nv.msgqTxHeader))
to_mv(self.cmd_q_va, ctypes.sizeof(nv.msgqTxHeader))[:] = bytes(cmd_q_tx)
self.cmd_q_view[:ctypes.sizeof(nv.msgqTxHeader)] = bytes(nv.msgqTxHeader(version=0, size=queue_size, entryOff=0x1000, msgSize=0x1000,
msgCount=(queue_size - 0x1000) // 0x1000, writePtr=0, flags=1, rxHdrOff=ctypes.sizeof(nv.msgqTxHeader)))
self.cmd_q = NVRpcQueue(self, self.cmd_q_va, None)
self.cmd_q = NVRpcQueue(self, self.cmd_q_view, None)
def init_libos_args(self):
_, logbuf_sysmem = self.nvdev._alloc_sysmem((2 << 20), contiguous=True)
libos_args_view, self.libos_args_sysmem = self.nvdev._alloc_sysmem(0x1000, contiguous=True)
libos_structs = (nv.LibosMemoryRegionInitArgument * 6).from_address(libos_args_view.addr)
for i, name in enumerate(["INIT", "INTR", "RM", "MNOC", "KRNL"]):
libos_structs[i] = nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x10000,
libos_structs = [nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x10000,
id8=int.from_bytes(bytes(f"LOG{name}", 'utf-8'), 'big'), pa=logbuf_sysmem[0] + 0x10000 * i)
libos_structs[5] = nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x1000,
id8=int.from_bytes(bytes("RMARGS", 'utf-8'), 'big'), pa=self.rm_args_sysmem)
for i, name in enumerate(["INIT", "INTR", "RM", "MNOC", "KRNL"])]
libos_structs.append(nv.LibosMemoryRegionInitArgument(kind=nv.LIBOS_MEMORY_REGION_CONTIGUOUS, loc=nv.LIBOS_MEMORY_REGION_LOC_SYSMEM, size=0x1000,
id8=int.from_bytes(bytes("RMARGS", 'utf-8'), 'big'), pa=self.rm_args_sysmem))
libos_args_view[:sum(ctypes.sizeof(s) for s in libos_structs)] = b''.join(bytes(s) for s in libos_structs)
def init_gsp_image(self):
fw = fetch("https://github.com/NVIDIA/linux-firmware/raw/refs/heads/nvidia-staging/nvidia/ga102/gsp/gsp-570.144.bin", subdir="fw").read_bytes()
@ -488,8 +492,8 @@ class NV_GSP(NV_IP):
self.rpc_rm_alloc(hParent=ch_gpfifo, hClass=self.dma_class, params=None)
def init_hw(self):
self.stat_q = NVRpcQueue(self, self.stat_q_va, self.cmd_q_va)
self.cmd_q.rx = nv.msgqRxHeader.from_address(self.stat_q.va + self.stat_q.tx.rxHdrOff)
self.stat_q = NVRpcQueue(self, self.stat_q_view, self.cmd_q_view)
self.cmd_q.rx_view = self.stat_q_view.view(self.stat_q.tx.rxHdrOff, fmt='I')
self.stat_q.wait_resp(nv.NV_VGPU_MSG_EVENT_GSP_INIT_DONE)
@ -599,9 +603,9 @@ class NV_GSP(NV_IP):
header = nv.PACKED_REGISTRY_TABLE(size=hdr_size + len(entries_bytes) + len(data_bytes), numEntries=len(table))
self.cmd_q.send_rpc(nv.NV_VGPU_MSG_FUNCTION_SET_REGISTRY, bytes(header) + entries_bytes + data_bytes)
def run_cpu_seq(self, seq_buf:memoryview):
hdr = nv.rpc_run_cpu_sequencer_v17_00.from_address(mv_address(seq_buf))
cmd_iter = iter(seq_buf[ctypes.sizeof(nv.rpc_run_cpu_sequencer_v17_00):].cast('I')[:hdr.cmdIndex])
def run_cpu_seq(self, seq_buf:bytes):
hdr = nv.rpc_run_cpu_sequencer_v17_00.from_buffer_copy(seq_buf[:(hdr_sz:=ctypes.sizeof(nv.rpc_run_cpu_sequencer_v17_00))])
cmd_iter = iter(memoryview(seq_buf[hdr_sz:]).cast('I')[:hdr.cmdIndex])
for op in cmd_iter:
if op == 0x0: self.nvdev.wreg(next(cmd_iter), next(cmd_iter)) # reg write

View file

@ -151,10 +151,10 @@ class NVDev:
if data is not None: view[:size] = data
return view, paddrs
def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[ctypes.Structure, int]:
def _alloc_boot_struct(self, struct:ctypes.Structure) -> tuple[MMIOInterface, int]:
view, paddrs = self._alloc_sysmem(sz:=ctypes.sizeof(type(struct)), contiguous=True)
view[:sz] = bytes(struct)
return type(struct).from_address(view.addr), paddrs[0]
return view, paddrs[0]
def _download(self, file:str) -> str:
url = f"https://raw.githubusercontent.com/NVIDIA/open-gpu-kernel-modules/8ec351aeb96a93a4bb69ccc12a542bf8a8df2b6f/{file}"

View file

@ -284,7 +284,7 @@ class PCIIfaceBase:
# *** Remote PCI Devices
class RemoteCmd(enum.IntEnum):
PROBE, MAP_BAR, MAP_SYSMEM_FD, CFG_READ, CFG_WRITE, RESET, MMIO_READ, MMIO_WRITE, MAP_SYSMEM, SYSMEM_READ, SYSMEM_WRITE, RESIZE_BAR = range(12)
PROBE,MAP_BAR,MAP_SYSMEM_FD,CFG_READ,CFG_WRITE,RESET,MMIO_READ,MMIO_WRITE,MAP_SYSMEM,SYSMEM_READ,SYSMEM_WRITE,RESIZE_BAR,PING = range(13)
class RemoteMMIOInterface(MMIOInterface):
def __init__(self, dev:RemotePCIDevice, residx:int, nbytes:int, fmt='B', off=0, rd_cmd=RemoteCmd.MMIO_READ, wr_cmd=RemoteCmd.MMIO_WRITE):
@ -314,7 +314,9 @@ class RemotePCIDevice(PCIDevice):
host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.settimeout(getenv("REMOTE_TIMEOUT", 3))
sock.connect((host, port))
sock.settimeout(None)
return sock
@staticmethod
@ -355,7 +357,7 @@ class RemotePCIDevice(PCIDevice):
self.sock.sendall(struct.pack('<BIIQQQ', cmd, self.dev_id, idx, offset, len(data), 0) + data)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size)
paddrs_len, handle, _, _ = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM, size, int(contiguous))
paddrs = list(struct.unpack(f'<{paddrs_len // 8}Q', self._recvall(self.sock, paddrs_len)))
return RemoteMMIOInterface(self, handle, size, fmt='B', rd_cmd=RemoteCmd.SYSMEM_READ, wr_cmd=RemoteCmd.SYSMEM_WRITE), paddrs
@ -396,7 +398,7 @@ class APLRemotePCIDevice(RemotePCIDevice):
super().__init__(devpref, "usb4", sock=sock)
def alloc_sysmem(self, size:int, vaddr:int=0, contiguous:bool=False) -> tuple[MMIOInterface, list[int]]:
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, has_fd=True)
mapped_size, _, _, fd = self._rpc(self.sock, self.dev_id, RemoteCmd.MAP_SYSMEM_FD, size, int(contiguous), has_fd=True)
memview = MMIOInterface(FileIOInterface(fd=fd).mmap(0, mapped_size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, 0), mapped_size, fmt='B')
# paddrs are returned as (paddr, size) pairs until a (paddr=0, size=0) terminator in the beginning of the mapping.

View file

@ -1,6 +1,6 @@
import functools, itertools
from tinygrad.helpers import all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.uop.ops import Ops, UOp
from tinygrad.uop.ops import UOp, Invalid
# *** allreduce implementation ***
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
@ -56,7 +56,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None:
# BUFFER without unique have unique added later
if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape)
if output is None: output = UOp.unique_const(red.dtype, Invalid, red.device, red.shape).contiguous()
to = red.param_like(0)
src = buf.param_like(1)
red = src.allreduce(red.arg, red.src[1])

View file

@ -151,11 +151,14 @@ multi_pm = PatternMatcher([
else multi),
# rewrite into calls explicitly for MULTI
(UPat(Ops.CALL, name="call"), rewrite_into_call),
(UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
(UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
# we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
if root.src[0].op is not Ops.TUPLE else None),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# remove MULTI from STORE
(UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True),
lambda root,multi: UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg)),
])+replace_allreduce

View file

@ -411,12 +411,6 @@ def flatten_bufferize(x:UOp):
return ret
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
def resolve_anonymous_buffer(ctx:itertools.count, b:UOp, c:UOp) -> UOp|None:
dab = b.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx)),)+b.src[1:])
nc_src = tuple(dab if x is b else x for x in c.src)
if nc_src == c.src: return None
return dab.after(c.replace(src=nc_src))
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
@ -432,8 +426,10 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
# remove double AFTER
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:])),
# resolve anonymous buffers
(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, src=(UPat(Ops.NOOP),), name="b", allow_any_len=True), UPat(Ops.CALL, name="c"))), resolve_anonymous_buffer),
# remove invalid writes
(UPat(Ops.STORE, src=(UPat(), UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONST, arg=Invalid),))), allow_any_len=True), lambda: UOp(Ops.NOOP)),
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.NOOP, src=()))), lambda x: x),
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.END, src=(UPat(Ops.NOOP, src=()),), allow_any_len=True))), lambda x: x),
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([

View file

@ -20,7 +20,8 @@ from tinygrad.engine.allocations import transform_to_call
# TODO: this should be the only usage of Device
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
if not isinstance(device, (tuple, list)): return Device.canonicalize(device)
return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical
# *** all in scope Tensors are here. this gets relevant UOps ***
@ -61,7 +62,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)])
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
return ret
@ -115,7 +116,9 @@ class Tensor(OpMixin):
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
if device is None:
if isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
elif isinstance(data, UOp): device = data._device
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
_device:str|tuple[str, ...] = canonicalize_device(device)
del device, dtype
@ -138,24 +141,24 @@ class Tensor(OpMixin):
const = UOp.const(var.dtype, val, _device, ())
data = data.replace(src=(var.replace(src=const.src), const))
elif data is None:
data = Tensor(0, device=_device, dtype=_dtype or dtypes.default_float, requires_grad=requires_grad).uop
elif isinstance(data, get_args(PyConst)):
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
elif isinstance(data, get_args(ConstType)):
data = (UOp.unique_const if _force_unique or requires_grad else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
elif isinstance(data, InvalidType):
assert _dtype is not None
data = UOp.const(_dtype, data, _device)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
elif isinstance(data, bytes): data = _frompy(data, _dtype or dtypes.uint8)
elif isinstance(data, (list, tuple)):
if _dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): _dtype = dtypes.bool
else: _dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if _dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=_device).cast(_dtype).uop
if _dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = _frompy(data, dtypes.float32).cast(_dtype)
else: data = _frompy(data, _dtype)
elif is_numpy_ndarray(data):
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == ():
data = Tensor(data.item(), device=_device, dtype=_dtype or _from_np_dtype(data.dtype), requires_grad=requires_grad).uop
data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device)
else:
data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data)
elif isinstance(data, pathlib.Path):
@ -166,12 +169,7 @@ class Tensor(OpMixin):
if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
# data might be on a different device
if isinstance(_device, str): self.uop:UOp = data if data.device == _device else data.copy_to_device(_device)
# if device is a tuple, we should have/construct a multi-device UOp
elif isinstance(data.device, str): self.uop = Tensor(data).shard(_device).uop
else:
assert data.device == _device, f"multi-device UOp device mismatch, {data.device} != {_device}"
self.uop = data
self.uop:UOp = data if data.device == _device else data.copy_to_device(_device)
# add to all_tensors after construction succeeds
all_tensors[weakref.ref(self)] = None
@ -197,7 +195,7 @@ class Tensor(OpMixin):
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
def const_like(self, b:ConstType) -> Tensor: return Tensor(dtypes.as_const(b, self.dtype), self.device, self.dtype, requires_grad=False)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.dtype.const(b), self.device, self.dtype, requires_grad=False)
def requires_grad_(self, requires_grad=True) -> Tensor:
# make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps
@ -240,10 +238,10 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
return Tensor(param)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
return Tensor(fret.gettuple(0), device=self.device)
return Tensor(fret.gettuple(0))
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""
@ -251,7 +249,7 @@ class Tensor(OpMixin):
This API is alpha and may change.
"""
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
def callify(self, *lst:Tensor) -> Tensor:
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
@ -326,7 +324,7 @@ class Tensor(OpMixin):
"""
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
"""
return Tensor(self.uop.detach(), device=self.device, requires_grad=False)
return Tensor(self.uop.detach(), requires_grad=False)
def _buffer(self) -> Buffer:
from tinygrad.engine.realize import capturing
@ -411,10 +409,8 @@ class Tensor(OpMixin):
"""
Moves the tensor to the given device.
"""
device = canonicalize_device(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
if (device:=canonicalize_device(device)) == self.device: return self
ret = Tensor(self.uop.copy_to_device(device), requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
return ret
@ -438,8 +434,8 @@ class Tensor(OpMixin):
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
if len(devices) == 1: return self.to(devices[0])
devices = cast(tuple[str, ...], canonicalize_device(devices))
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(uop, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
@ -534,7 +530,7 @@ class Tensor(OpMixin):
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
# TODO: add test for multidevice tensor
device = canonicalize_device(device)
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
return Tensor(UOp.new_buffer(device, size, dtype), **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
"""
@ -543,7 +539,7 @@ class Tensor(OpMixin):
"""
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis))
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
@staticmethod
@ -633,29 +629,24 @@ class Tensor(OpMixin):
Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
# increment rng counter for devices
new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff)
new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high))
new_low = Tensor._device_rng_counters[device][0:1] + (num & 0xffffffff)
new_high = Tensor._device_rng_counters[device][1:2] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
Tensor._device_rng_counters[device].assign(new_low.cat(new_high))
low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff)
high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
low = Tensor._device_rng_counters[device][0:1] - (num & 0xffffffff)
high = Tensor._device_rng_counters[device][1:2] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
# threefry random bits
if num > dtypes.uint32.max:
bits_list = []
for i in range(0, num, dtypes.uint32.max):
chunk_num = min(num - i, dtypes.uint32.max)
c_low = low + (i & 0xffffffff)
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
new_key = Tensor._threefry_random_bits(Tensor._device_seeds[device], c_low, c_high)
counts0 = Tensor.arange(ceildiv(chunk_num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)
counts1 = counts0 + ceildiv(chunk_num, 2)
bits_list.append(Tensor._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
bits = Tensor.cat(*bits_list)
else:
counts0 = Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False) + low
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
bits_list = []
for i in range(0, num, dtypes.uint32.max):
chunk_num = min(num - i, dtypes.uint32.max)
c_low = low + (i & 0xffffffff)
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
new_key = Tensor._threefry_random_bits(Tensor._device_seeds[device], c_low, c_high)
counts0 = Tensor.arange(ceildiv(chunk_num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)
counts1 = counts0 + ceildiv(chunk_num, 2)
bits_list.append(Tensor._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
bits = Tensor.cat(*bits_list)
# bitcast to uint with same number of bits
_, nmant = dtypes.finfo(dt)
@ -671,7 +662,7 @@ class Tensor(OpMixin):
# ***** creation helper functions *****
@staticmethod
def full(shape:tuple[sint, ...], fill_value:PyConst, **kwargs) -> Tensor:
def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with the given value.
@ -687,6 +678,17 @@ class Tensor(OpMixin):
"""
return Tensor(fill_value, _force_unique=True, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def invalid(*shape, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return Tensor.full(argfix(*shape), Invalid, **kwargs)
@staticmethod
def zeros(*shape, **kwargs) -> Tensor:
"""
@ -798,8 +800,8 @@ class Tensor(OpMixin):
dtype = kwargs.pop("dtype", self.dtype)
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
return Tensor(stacked.multi(self.uop.axis))
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
"""
@ -1061,14 +1063,13 @@ class Tensor(OpMixin):
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
target_uops = [x.uop for x in targets]
grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
ret = []
ret:list[Tensor] = []
for x in target_uops:
if (y:=grads.get(x)) is None:
if materialize_grads: y = x.const_like(0)
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
ret.append(y)
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]
ret.append(Tensor(y))
return ret
def backward(self, gradient:Tensor|None=None) -> Tensor:
"""
@ -1167,11 +1168,6 @@ class Tensor(OpMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
# convenience
def pad_to(self, shape, *args):
if len(new_shape := argfix(shape, *args)) != self.ndim: raise ValueError(f"dim mismatch, cannot pad {self.shape} to {new_shape}")
return self.pad(tuple([None if ns is None else (0, ns-s) for s,ns in zip(self.shape, new_shape)]))
# ***** movement high level ops *****
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
@ -1710,7 +1706,7 @@ class Tensor(OpMixin):
"""
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
@ -2879,16 +2875,6 @@ class Tensor(OpMixin):
"""
return self.cast(dtypes.bool).ne(True)
def neg(self) -> Tensor:
"""
Negates the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
```
"""
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
@ -2975,7 +2961,7 @@ class Tensor(OpMixin):
y_dtype = x.dtype
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
if match_dtype and x.dtype != y.dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
@ -3163,8 +3149,6 @@ class Tensor(OpMixin):
# ***** op wrappers *****
def __invert__(self) -> Tensor: return self.bitwise_not()
# TODO: combine with UOps __floordiv__
def __floordiv__(self, x): return self.div(x, rounding_mode="floor")
def __rfloordiv__(self, x): return self.div(x, rounding_mode="floor", reverse=True)
@ -3179,7 +3163,7 @@ class Tensor(OpMixin):
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
# unlike Tensors, UOps are immutable, so these don't go in MathTraits
# unlike Tensors, UOps are immutable, so these don't go in mixin
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) # type: ignore[misc]
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) # type: ignore[misc]
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) # type: ignore[misc]
@ -3205,7 +3189,7 @@ class Tensor(OpMixin):
assert frame_pos.op is Ops.BIND, "frame_pos must be a bound Variable"
srcs = (out:=Tensor.empty(*shape, device=self.device, dtype=self.dtype), self.contiguous(), state.contiguous(), *ref_frames)
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)), device=self.device)
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
# ***** functional nn ops *****
@ -3359,9 +3343,10 @@ class Tensor(OpMixin):
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
return reductions[reduction](self)
if reduction == "none": return self
if reduction == "sum": return self.sum()
if reduction == "mean": return self.mean()
raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
"""
@ -3408,14 +3393,12 @@ class Tensor(OpMixin):
```
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
log_probs = self.log_softmax()
loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
return -unreduced.sum() / loss_mask.sum() if reduction == "mean" else -unreduced._do_reduction(reduction)
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
"""
@ -3463,7 +3446,7 @@ class Tensor(OpMixin):
print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
```
"""
weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
weight = Y.ones_like(requires_grad=False) if weight is None else weight[Y]
masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)

View file

@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
@ -472,15 +472,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
arg=dtypes.as_const(b, dtype),
arg=dtype.const(b),
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
@staticmethod
def unique_const(dtype:DType, b:ConstType, device:str|tuple[str, ...], unique=True):
def unique_const(dtype:DType, b:ConstType, device:str|tuple[str, ...], shape:tuple[sint, ...]|None=None, unique=True):
# NOTE: b is ConstType, not ConstLike, so UOps and tuples aren't allowed
assert not isinstance(b, (UOp, tuple)), "unique const only works on numbers"
ret = UOp.const(dtype, b, device)
return ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@ -727,6 +728,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return buf.view(self.size, self.dtype, 0)
if self.op is Ops.BUFFER_VIEW:
buf = self.src[0].buffer
if isinstance(buf, MultiBuffer):
mbuf = MultiBuffer.__new__(MultiBuffer)
mbuf.bufs = [b.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
return mbuf
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
if self.op is Ops.MSELECT:

View file

@ -36,7 +36,7 @@ shared_spec = PatternMatcher([
(UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
# CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
@ -88,7 +88,7 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE, Ops.NOOP)), UPat(Ops.DEVICE)), name="buf"),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
@ -239,7 +239,7 @@ program_spec = PatternMatcher([
(UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
type(x.arg) is type(x.dtype.const(x.arg))),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),

View file

@ -65,8 +65,10 @@ propagate_invalid = PatternMatcher([
(UPat(GroupOp.Binary-GroupOp.Comparison, src=[invalid_pat, UPat()]), lambda i: i),
# normalize where(cond, Invalid, val) -> where(~cond, val, Invalid)
(UPat.var("cond").where(invalid_pat, UPat.var("val")), lambda cond, i, val: cond.logical_not().where(val, i) if val.arg != Invalid else i),
(UPat.var("a").where(invalid_gate, UPat.var("c")), lambda cond,i,x,a,c: cond.where(a.where(x, c), i) if c.arg != Invalid else None),
(UPat.var("a").where(UPat.var("b"), invalid_gate), lambda cond,i,x,a,b: cond.where(a.where(b, x), i) if b.arg != Invalid else None),
# lift Invalid out # TODO: this `a is cond` is asymmetric to preserve the pattern
(UPat.var("a").where(invalid_gate, UPat.var("c")), lambda cond,i,x,a,c:
(cond if a is cond else (a.logical_not()|cond)).where(a.where(x,c), i) if c.arg != Invalid else None),
(UPat.var("a").where(UPat.var("b"), invalid_gate), lambda cond,i,x,a,b: (a|cond).where(a.where(b, x), i) if b.arg != Invalid else None),
(UPat(Ops.BITCAST, src=(invalid_pat,), name="bc"), lambda bc,i: i.cast(bc.dtype)),
(UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))),
])

View file

@ -146,7 +146,6 @@
#insts .line {
display: flex;
flex-direction: column;
cursor: pointer;
margin-bottom: 8px;
}
#insts .left {
@ -157,10 +156,6 @@
#insts .left.highlight {
background-color: rgba(0, 199, 47, 0.2);
}
#insts .wave {
color: #7aa2f7;
min-width: 2ch;
}
#insts .pc {
color: #73daca;
}

View file

@ -300,7 +300,12 @@ function setFocus(key) {
if (eventType === EventTypes.EXEC) {
const [n, _, ...rest] = e.arg.tooltipText.split("\n");
const tableData = [["Name", colored(e.arg.label)], ["Duration", formatTime(e.width)]];
data.instSt != null ? tableData.push(["Start Cycle", formatTime(e.x)], ["Timestamp", timeAtCycle(e.x)]) : tableData.push(["Start Time", formatTime(e.x)]);
if (data.instSt != null) {
const p = d3.create("p");
p.append("span").text(timeAtCycle(e.x));
p.append("span").style("margin-left", "8px").style("color", "#f0f0f566").text(formatTime(e.x));
tableData.push(["Cycle", formatTime(e.x-data.instSt)], ["Time", p.node()]);
} else tableData.push(["Start Time", formatTime(e.x)]);
html.append(() => tabulate(tableData));
let group = html.append("div").classed("args", true);
for (const r of rest) group.append("p").text(r);
@ -335,23 +340,22 @@ function setFocus(key) {
}
// instructions list renderer
let instList = document.getElementById("insts");
if (data.pcToShape.size == 0) return d3.select(instList?.parentElement).html("");
if (data.pcMap == null) return d3.select(instList?.parentElement).html("");
if (instList == null) {
let contents = "";
for (const [k, v] of data.pcToShape) {
const pcHex = v.pc.toString(16);
contents += `<div class="line" data-k="${k}"><span class="left" id="inst-${k}"><span class="wave">${v.wave}</span>
<span class="pc">${"0x"+pcHex.padStart(Math.max(4, Math.ceil(pcHex.length/4)*4), 0)}</span><span class="label">${data.pcMap[v.pc]}</span></div>`;
for (let [pc, label] of Object.entries(data.pcMap)) {
pc = parseInt(pc);
const pcHex = pc.toString(16);
contents += `<div class="line"><span class="left" id="inst-${pc}"><span class="pc">${"0x"+pcHex.padStart(Math.max(4, Math.ceil(pcHex.length/4)*4), 0)}</span><span class="label">${label}</span></span></div>`;
}
instList = d3.create("pre").append("code").classed("hljs", true).style("margin-top", "20px").attr("id", "insts").html(contents)
.on("click", e => { const line = e.target.closest(".line"); line && setFocus(line.dataset.k); }).node();
instList = d3.create("pre").append("code").classed("hljs", true).style("margin-top", "20px").attr("id", "insts").html(contents).node();
metadata.insertBefore(instList.parentElement, html.node());
}
d3.select(instList).selectAll("span").classed("highlight", false);
const instLine = document.getElementById(`inst-${key}`); instLine?.classList.add("highlight");
const instLine = document.getElementById(`inst-${e?.arg.pc}`); instLine?.classList.add("highlight");
if (instLine != null) {
const r = rect(instLine), c = rect(instList);
if (Math.max(c.top-r.bottom, r.top-c.bottom)>=-30) instLine.scrollIntoView({ block:"center" });
if (Math.max(c.top-r.bottom, r.top-c.bottom)>=-30) instList.scrollTop = instLine.offsetTop-instList.clientHeight/2+instLine.clientHeight/2;
}
}
@ -362,7 +366,7 @@ async function renderProfiler(path, opts) {
displaySelection("#profiler");
// support non realtime x axis units
formatTime = opts.unit === "ms" ? formatMicroseconds : formatCycles;
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null, pcToShape:new Map()}; focusedDevice = null; focusedShape = null; }
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
setFocus(focusedShape);
// layout once!
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
@ -451,10 +455,10 @@ async function renderProfiler(path, opts) {
}
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
let info = e.info != null ? "\n"+e.info : "", trace = null
if (info.startsWith("\nPC:")) { data.pcToShape.set(key, {wave:dnum, pc:parseInt(e.info.split(":")[1]), st:e.st}); info = ""; }
let info = e.info != null ? "\n"+e.info : "", trace = null, pc = null
if (info.startsWith("\nPC:")) { pc = parseInt(e.info.split(":")[1]); info = ""; }
if (info.startsWith("\nTB:")) { trace = info; info = ""; }
const arg = { tooltipText:" N:"+shapes.length+"\n"+formatTime(e.dur)+info, label, trace, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step };
const arg = { tooltipText:" N:"+shapes.length+"\n"+formatTime(e.dur)+info, label, pc, trace, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step };
if (e.key != null) shapeMap.set(e.key, key);
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label:opts.hideLabels ? null : label, fillColor });
@ -550,12 +554,11 @@ async function renderProfiler(path, opts) {
}
}
for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width }));
data.pcToShape = new Map([...data.pcToShape].sort((a, b) => a[1].st - b[1].st));
if (extData.pcMap != null) data.pcMap = extData.pcMap; setFocus(focusedShape);
// secondary axis mapping
let instRange = null;
for (const [k, { shapes }] of data.tracks) if (k.startsWith("WAVE")) {
const first = shapes[0].x, last = shapes.at(-1).x;
for (const [k, { shapes }] of data.tracks) if (!k.includes("Clock") && path.includes("pkts")) {
const first = shapes[0].x, last = shapes.at(-1).x+shapes.at(-1).width;
instRange = instRange == null ? [first, last] : [Math.min(first, instRange[0]), Math.max(last, instRange[1])];
}
if (instRange != null) [data.instSt, data.instEt] = instRange;

View file

@ -103,6 +103,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
@ -338,8 +339,9 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
WAVEEND, CDNA_WAVEEND)
ret:list[ProfileEvent] = []
row_ends:dict[str, Decimal] = {}
curr_barrier:dict[str, ProfileRangeEvent] = {}
@ -368,7 +370,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
if isinstance(p, (INST, INST_RDNA4, CDNA_INST)):
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4, InstOpCDNA)) else f"0x{p.op:02x}"
add(name, p, info=info)
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
if isinstance(p, (VALUINST, IMMEDIATE, WAVEEND, CDNA_WAVEEND)): add(p.__class__.__name__, p, info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
if isinstance(p, (VMEMEXEC, ALUEXEC)):
name = str(p.src).split('.')[1]