Merge branch 'master' into new_x86_backend

This commit is contained in:
ttomsa 2026-03-23 22:50:56 +00:00 committed by GitHub
commit cd0152efec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 725 additions and 291 deletions

View file

@ -70,7 +70,7 @@ jobs:
source venv/bin/activate
pip install $GITHUB_WORKSPACE
cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py .
BS=2 STEPS=10 python beautiful_mnist.py
BS=2 STEPS=10 MAX_BUFFER_SIZE=0 python beautiful_mnist.py
- name: Test Docs Build
run: python -m mkdocs build --strict
- name: Test Docs
@ -141,7 +141,7 @@ jobs:
sudo apt update || true
sudo apt install -y --no-install-recommends ninja-build
- name: Test beautiful_mnist in torch with TINY_BACKEND
run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 MAX_BUFFER_SIZE=0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
- name: Test some torch tests (expect failure)
run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true

View file

@ -5,7 +5,7 @@ from extra.onnx_helpers import get_example_inputs, validate
def load_onnx_model(onnx_file):
run_onnx = OnnxRunner(onnx_file)
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True)
return run_onnx_jit, run_onnx.graph_inputs
if __name__ == "__main__":

View file

@ -1,6 +1,5 @@
#!/bin/bash
export BENCHMARK=5
export EVAL_BS=0
export VIZ=${VIZ:--1}
examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
extra/viz/cli.py --profile --device "AMD" --top 20
VIZ=${VIZ:--1} examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
extra/viz/cli.py --profile --device "AMD" --limit 20

View file

@ -1,12 +1,11 @@
from pathlib import Path
import numpy as np
import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.nn import optim
from extra.datasets import fetch_mnist
from tinygrad.nn.datasets import mnist
class LinearGen:
def __init__(self):
@ -38,14 +37,14 @@ class LinearDisc:
return x
def make_batch(images):
sample = np.random.randint(0, len(images), size=(batch_size))
image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0
return Tensor(image_b)
sample = Tensor.randint(batch_size, low=0, high=images.shape[0])
return images[sample].reshape(batch_size, 28*28).cast('float').div(127.5).sub(1.0)
def make_labels(bs, col, val=-2.0):
y = np.zeros((bs, 2), np.float32)
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
return Tensor(y)
y = Tensor.zeros(bs, 2)
if col == 0: y = y + Tensor([val, 0.0])
else: y = y + Tensor([0.0, val])
return y
def train_discriminator(optimizer, data_real, data_fake):
real_labels = make_labels(batch_size, 1)
@ -71,12 +70,12 @@ def train_generator(optimizer, data_fake):
if __name__ == "__main__":
# data for training and validation
images_real = np.vstack(fetch_mnist()[::2])
X_train, _, _, _ = mnist()
ds_noise = Tensor.randn(64, 128, requires_grad=False)
# parameters
epochs, batch_size, k = 300, 512, 1
sample_interval = epochs // 10
n_steps = len(images_real) // batch_size
n_steps = X_train.shape[0] // batch_size
# models and optimizer
generator = LinearGen()
discriminator = LinearDisc()
@ -84,24 +83,24 @@ if __name__ == "__main__":
output_dir = Path(".").resolve() / "outputs"
output_dir.mkdir(exist_ok=True)
# optimizers
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
Tensor.training = True
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):
data_real = make_batch(images_real)
for step in range(k): # Try with k = 5 or 7.
with Tensor.train():
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):
data_real = make_batch(X_train)
for step in range(k): # Try with k = 5 or 7.
noise = Tensor.randn(batch_size, 128)
data_fake = generator.forward(noise).detach()
loss_d += train_discriminator(optim_d, data_real, data_fake)
noise = Tensor.randn(batch_size, 128)
data_fake = generator.forward(noise).detach()
loss_d += train_discriminator(optim_d, data_real, data_fake)
noise = Tensor.randn(batch_size, 128)
data_fake = generator.forward(noise)
loss_g += train_generator(optim_g, data_fake)
if (epoch + 1) % sample_interval == 0:
fake_images = generator.forward(ds_noise).detach().numpy()
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
data_fake = generator.forward(noise)
loss_g += train_generator(optim_g, data_fake)
if (epoch + 1) % sample_interval == 0:
fake_images = generator.forward(ds_noise).detach().numpy()
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
print("Training Completed!")

View file

@ -13,12 +13,20 @@ from collections import OrderedDict
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
# memory-planned subbuffers can have multiple Buffer objects for the same memory region
canon, _seen = {}, {}
for ji in run.jit_cache:
for b in ji.bufs:
if b is not None: canon[id(b)] = _seen.setdefault((id(b.base._buf), b.offset, b.size, b.dtype), b)
special_names = {id(canon[k]): v for k, v in special_names.items() if k in canon}
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache:
fxn: ProgramSpec = ji.prg.p
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(ji.bufs):
arg = canon[id(arg)]
key = id(arg)
if key not in bufs:
if key in special_names:

View file

@ -0,0 +1,208 @@
from tinygrad import Tensor, UOp, getenv
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
from tinygrad.dtype import AddrSpace, dtypes
from tinygrad.helpers import DEBUG, GlobalCounters, Context
import math
# Problem size, overridable from the environment: q/k/v/o are created as (B, H, N, D) in the __main__ driver.
B = getenv("B", 1)
H = getenv("H", 32)
N = getenv("N", 1024)
D = getenv("D", 64)
# NOTE(review): the kernel reshapes with N//BLOCK_M and N//BLOCK_N (both 64), so N presumably has to be a
# multiple of 64, not just 16 — confirm whether this assert should be tightened.
assert D % 16 == 0 and N % 16 == 0
# Tile of query rows (BLOCK_M) and key/value rows (BLOCK_N) handled per workgroup / per KV-loop iteration.
BLOCK_M, BLOCK_N = 64, 64
WARP_SIZE = 32
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
WAVES_M, WAVES_N = 4, 1
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M  # 16 // 2 = 8
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N  # 32 * 4 * 1 = 128
# Per-thread register tile extents: TM rows, TN columns of S, TD columns of the output accumulator.
TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)  # 64 // 8 = 8
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)  # 64 // 16 = 4
TD = D // (WAVES_N * LANES_PER_WAVE_N)  # D // 16 (4 with the default D=64)
LDS_PAD = 4 # pad LDS rows to reduce bank conflicts
WMMA_ARG = ((WMMA_M, WMMA_N, WMMA_K), 'AMD', 32)
# Softmax scaling factor applied to S = Q @ K^T.
SCALE = 1.0 / math.sqrt(D)
# exp(x) is computed in the kernel as exp2(x * LOG2E).
LOG2E = math.log2(math.e)
def warp_shfl_xor(val, offset, lane):
  """Fetch `val` as held by the lane whose index is `lane ^ offset`.

  The partner lane index is multiplied by 4 and cast to int before being handed to
  `__builtin_amdgcn_ds_bpermute` (presumably because the builtin addresses lanes in bytes — confirm).
  The float payload is bit-cast to int for the permute and back to float afterwards.
  """
  partner_lane = lane ^ offset
  permute_addr = (partner_lane * 4).cast(dtypes.int)
  shuffle_src = "__builtin_bit_cast(float, __builtin_amdgcn_ds_bpermute({0}, __builtin_bit_cast(int, {1})))"
  return UOp(Ops.CUSTOM, dtypes.float, (permute_addr, val), arg=shuffle_src)
def warp_reduce_max(val, lane):
  """Reduce `val` with MAX across LANES_PER_WAVE_N=16 lanes.

  Four XOR-shuffle steps are chained (offsets 8, 4, 2, 1); each step takes the MAX of the
  current value and the value read from lane `lane ^ offset`.
  """
  shift = 8
  while shift:
    neighbour = warp_shfl_xor(val, shift, lane)
    val = UOp(Ops.MAX, dtypes.float, (val, neighbour))
    shift >>= 1
  return val
def warp_reduce_sum(val, lane):
  """Reduce `val` with SUM across LANES_PER_WAVE_N=16 lanes.

  Four XOR-shuffle steps are chained (offsets 8, 4, 2, 1); each step adds the value read
  from lane `lane ^ offset` to the running value.
  """
  shift = 8
  while shift:
    neighbour = warp_shfl_xor(val, shift, lane)
    val = val + neighbour
    shift >>= 1
  return val
def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
  """Build the UOp graph of a tiled attention kernel that writes softmax(q @ k^T * SCALE) @ v into `o`.

  One workgroup handles one (batch*head, BLOCK_M-row query tile) pair and loops over the
  N // BLOCK_N key/value tiles, keeping a running row max (m_i), running row sum (l_i) and an
  unnormalised output accumulator (acc) in registers (online softmax). Both matmuls are
  expressed as Ops.SHAPED_WMMA with WMMA_ARG.

  Args:
    o: output buffer; reshaped here to (B*H, N//BLOCK_M, BLOCK_M, D).
    q: query buffer; reshaped here to (B*H, N//BLOCK_M, BLOCK_M, D).
    k: key buffer; reshaped here to (B*H, N//BLOCK_N, BLOCK_N, D).
    v: value buffer; reshaped here to (B*H, N//BLOCK_N, BLOCK_N, D).
    (The __main__ driver in this file passes q/k/v as half and o as float, each (B*H, N, D).)

  Returns:
    The kernel sink UOp, created with KernelInfo(opts_to_apply=()).
  """
  # global axes: one index over batch*heads, one over query tiles
  block_bh = UOp.range(B * H, 0, AxisType.GLOBAL)
  block_m = UOp.range(N // BLOCK_M, 1, AxisType.GLOBAL)
  q = q.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]
  k = k.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
  v = v.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
  o = o.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]
  # local axes: wave indices within the workgroup and lane index within the wave
  wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
  wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
  lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
  # flat thread id in [0, THREADS_PER_BLOCK)
  tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
  # NOTE(review): lane_m is not referenced again in this function
  lane_m = lane // LANES_PER_WAVE_N
  lane_n = lane % LANES_PER_WAVE_N
  # LDS allocation: slot 0 = Q then P (shared), slot 1 = K then V
  # TODO: the memory planner should be able to find this reuse
  # NOTE(review): ELEMS_PER_THREAD is derived from BLOCK_M but is also used for the BLOCK_N-row K/V tiles;
  # this matches only while BLOCK_M == BLOCK_N — confirm that is intended
  ELEMS_PER_THREAD = BLOCK_M * D // THREADS_PER_BLOCK
  QP_lds = UOp.placeholder((BLOCK_M, D + LDS_PAD), dtypes.half, slot=0, addrspace=AddrSpace.LOCAL)
  KV_lds = UOp.placeholder((BLOCK_N, D + LDS_PAD), dtypes.half, slot=1, addrspace=AddrSpace.LOCAL)[:, :D]
  # register state
  # acc: unnormalised output accumulator, m_i: running row max, l_i: running row sum
  acc = UOp.placeholder((TM, TD), dtypes.float, slot=2, addrspace=AddrSpace.REG)
  m_i = UOp.placeholder((TM,), dtypes.float, slot=3, addrspace=AddrSpace.REG)
  l_i = UOp.placeholder((TM,), dtypes.float, slot=4, addrspace=AddrSpace.REG)
  # initialise: acc = 0, m_i = -inf, l_i = 0
  acc = acc.after(acc.store(acc.const_like(0)))
  m_i = m_i.after(m_i.store(m_i.const_like(-math.inf)))
  l_i = l_i.after(l_i.store(l_i.const_like(0)))
  # ====== KV tile loop ======
  n_tile = UOp.range(N // BLOCK_N, 100, AxisType.REDUCE)
  # load Q + K into LDS (Q reloaded each iteration since P overwrites slot 0)
  Q_lds = QP_lds[:, :D]
  Q_store = Q_lds.after(n_tile).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    q.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  K_store = KV_lds.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    k[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  # barrier after both stores; later reads of Q/K in LDS are ordered after it
  qk_load_barrier = UOp.barrier(UOp.group(Q_store, K_store))
  Q_lds = Q_lds.after(qk_load_barrier)
  KV_lds_k = KV_lds.after(qk_load_barrier)
  # -- S = Q @ K^T via WMMA (re-init each n_tile) --
  S_reg = UOp.placeholder((TM, TN), dtypes.float, slot=6, addrspace=AddrSpace.REG)
  S_reg = S_reg.after(S_reg.after(n_tile).store(S_reg.const_like(0)))
  k_qk = UOp.range(D // WMMA_K, 101, AxisType.REDUCE)
  tm1 = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
  tn1 = UOp.range(TN, 201, AxisType.LOOP)
  S_frag = S_reg.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0, 2, 1)[tm1, tn1]
  q_frag = Q_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, D // WMMA_K, WMMA_K)[wave_m, tm1, lane_n, k_qk]
  k_frag = KV_lds_k.reshape(WAVES_N, TN, WMMA_N, D // WMMA_K, WMMA_K)[wave_n, tn1, lane_n, k_qk]
  qk = UOp(Ops.SHAPED_WMMA, dtypes.float, (q_frag, k_frag, S_frag.after(k_qk)), arg=WMMA_ARG)
  qk_done = S_frag.store(qk).end(tm1, tn1).end(k_qk)
  S_reg = S_reg.after(qk_done)
  # -- softmax in registers with warp shuffles --
  S_reg = S_reg.after(S_reg.store(S_reg * SCALE))
  # per-thread local row max over TN=4 elements, then warp reduce across 16 lanes
  m_ij = UOp.placeholder((TM,), dtypes.float, slot=7, addrspace=AddrSpace.REG)
  m_ij = m_ij.after(m_ij.after(n_tile).store(m_ij.const_like(-math.inf)))
  rm1 = UOp.range(TM, 260, AxisType.LOOP)
  rm2 = UOp.range(TN, 261, AxisType.REDUCE)
  m_ij = m_ij.after(m_ij[rm1].store(UOp(Ops.MAX, dtypes.float, (m_ij.after(rm1, rm2)[rm1], S_reg[rm1, rm2]))).end(rm2, rm1))
  # warp reduce max (in-place)
  ri_w = UOp.range(TM, 270, AxisType.LOOP)
  m_ij = m_ij.after(m_ij[ri_w].store(warp_reduce_max(m_ij[ri_w], lane)).end(ri_w))
  # compute P = exp(S - m_ij) in S_reg (manual ranges)
  rp0a = UOp.range(TM, 275, AxisType.LOOP)
  rp0b = UOp.range(TN, 276, AxisType.LOOP)
  S_reg = S_reg.after(S_reg[rp0a, rp0b].store(((S_reg[rp0a, rp0b] - m_ij[rp0a]) * LOG2E).exp2()).end(rp0a, rp0b))
  # per-thread row sum of P over TN elements, then warp reduce sum -> p_sum
  p_local = UOp.placeholder((TM,), dtypes.float, slot=8, addrspace=AddrSpace.REG)
  p_local = p_local.after(p_local.after(n_tile).store(p_local.const_like(0)))
  rp1 = UOp.range(TM, 290, AxisType.LOOP)
  rp2 = UOp.range(TN, 291, AxisType.REDUCE)
  p_local = p_local.after(p_local[rp1].store(p_local.after(rp1, rp2)[rp1] + S_reg[rp1, rp2]).end(rp2, rp1))
  ri_ws = UOp.range(TM, 295, AxisType.LOOP)
  p_sum = p_local.after(p_local[ri_ws].store(warp_reduce_sum(p_local[ri_ws], lane)).end(ri_ws))
  # write P = exp(S - m_ij) to P_lds (reuses slot 0, Q no longer needed)
  P_lds = QP_lds[:, :BLOCK_N]
  P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
  P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
  rw1 = UOp.range(TM, 296, AxisType.LOOP)
  rw2 = UOp.range(TN, 297, AxisType.LOOP)
  P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
  # -- online softmax correction --
  # m_new = max(m_i, m_ij); alpha rescales the old state, beta rescales this tile's contribution
  ri4 = UOp.range(TM, 330, AxisType.LOOP)
  m_new_val = UOp(Ops.MAX, dtypes.float, (m_i[ri4], m_ij[ri4]))
  alpha_val = ((m_i[ri4] - m_new_val) * LOG2E).exp2()
  beta_val = ((m_ij[ri4] - m_new_val) * LOG2E).exp2()
  rj4 = UOp.range(TD, 331, AxisType.LOOP)
  correction = UOp.group(
    acc[ri4, rj4].store(alpha_val * acc[ri4, rj4]).end(rj4),
    l_i[ri4].store(alpha_val * l_i[ri4] + beta_val * p_sum[ri4]),
    m_i[ri4].store(m_new_val),
  ).end(ri4)
  acc = acc.after(correction)
  l_i = l_i.after(correction)
  m_i = m_i.after(correction)
  # load V into KV_lds (must wait for QK WMMA to finish reading K from KV_lds)
  V_store = KV_lds.after(qk_done).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
    v[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
  pv_barrier = UOp.barrier(UOp.group(P_store, V_store))
  P_lds = P_lds.after(pv_barrier)
  KV_lds_v = KV_lds.after(pv_barrier)
  # -- acc += P @ V via WMMA --
  k_pv = UOp.range(BLOCK_N // WMMA_K, 400, AxisType.REDUCE)
  tm2 = UOp.range(TM // WMMA_ACC, 401, AxisType.LOOP)
  tn2 = UOp.range(TD, 402, AxisType.LOOP)
  acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TD).permute(0, 2, 1)[tm2, tn2]
  p_frag = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_N // WMMA_K, WMMA_K)[wave_m, tm2, lane_n, k_pv]
  v_frag = KV_lds_v.reshape(WAVES_N, TD, WMMA_N, BLOCK_N // WMMA_K, WMMA_K)[wave_n, tn2, lane_n, k_pv]
  pv = UOp(Ops.SHAPED_WMMA, dtypes.float, (p_frag, v_frag, acc_frag.after(k_pv)), arg=WMMA_ARG)
  # end KV tile loop
  n_tile_end = acc_frag.store(pv).end(tm2, tn2).end(k_pv).barrier().end(n_tile)
  acc = acc.after(n_tile_end)
  l_i = l_i.after(n_tile_end)
  m_i = m_i.after(n_tile_end)
  # normalize: acc /= l_i
  acc = acc.after(acc.store(acc * (1 / l_i).reshape(TM, 1).expand(TM, TD)))
  # store output
  # same reshape/permute pattern as P_write above, with TD in place of TN, so o[tid] is a (TM, TD) view
  o = o.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TD, LANES_PER_WAVE_N)
  o = o.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TD)
  return o[tid].store(acc).end(wave_m, wave_n, lane).end(block_m, block_bh).sink(arg=KernelInfo(opts_to_apply=()))
# Driver: benchmark the custom kernel and (optionally) verify it against scaled_dot_product_attention.
if __name__ == "__main__":
  # random half-precision inputs; the output buffer is float
  q = Tensor.rand(B, H, N, D).cast(dtypes.half)
  k = Tensor.rand(B, H, N, D).cast(dtypes.half)
  v = Tensor.rand(B, H, N, D).cast(dtypes.half)
  o = Tensor.empty(B, H, N, D, dtype=dtypes.float)
  # realize the inputs up front with DEBUG=0
  with Context(DEBUG=0): Tensor.realize(q, k, v)
  # the kernel indexes a merged batch*head axis
  q_flat, k_flat, v_flat, o_flat = q.reshape(B*H, N, D), k.reshape(B*H, N, D), v.reshape(B*H, N, D), o.reshape(B*H, N, D)
  # number of timed runs, overridable with CNT
  NUM_RUNS = getenv("CNT", 5)
  ets = []
  # DEBUG level for the timed runs comes from KDBG (default 2)
  with Context(DEBUG=getenv("KDBG", 2)):
    for _ in range(NUM_RUNS):
      # reset so time_sum_s reflects only this run
      GlobalCounters.reset()
      tst = Tensor.custom_kernel(o_flat, q_flat, k_flat, v_flat, fxn=amd_flash_attention)[0].realize()
      ets.append(GlobalCounters.time_sum_s)
  print(f"best time: {min(ets)*1e3:.2f}ms")
  # VERIFY=0 skips the comparison against the reference implementation
  if getenv("VERIFY", 1):
    with Context(DEBUG=0):
      # reference is computed in float from the same (half-rounded) inputs
      ref = q.float().scaled_dot_product_attention(k.float(), v.float()).reshape(B*H, N, D).realize()
      err = (ref - tst).square().mean().item()
    print(f"mean squared error {err}")
    # fail when the mean squared error exceeds 1e-2
    if err > 1e-2:
      raise RuntimeError("flash attention is wrong!")
    else:
      print("flash attention is correct!")

View file

@ -9,7 +9,7 @@ EXAMPLES = {
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
"plus":"test/test_tiny.py TestTiny.test_plus",
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_wave_sync",
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_lds_sync",
}
if __name__ == "__main__":

View file

@ -1,6 +1,7 @@
A command line tool for exploring the VIZ trace.
After running with VIZ=-1, use `extra/viz/cli.py` to explore the saved trace files.
1. Set VIZ to -1 to save the trace.
2. Use `extra/viz/cli.py` to inspect the trace files.
## Inspect runtime profiling

View file

@ -1,9 +1,9 @@
#!/usr/bin/env python3
import argparse, pathlib, sys, struct, json
import argparse, pathlib, sys, struct, json, itertools
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.uop.ops import RewriteTrace
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, Context
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen
# ** generic helpers
@ -59,16 +59,18 @@ def decode_profile(data:bytes) -> dict:
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
if __name__ == "__main__":
Context(VIZ=0, TRACK_MATCH_STATS=0).__enter__()
parser = argparse.ArgumentParser()
g_mode = parser.add_argument_group("mode")
g_mode.add_argument("--profile", action="store_true", help="View profile trace")
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
g_common = parser.add_argument_group("common options")
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
g_common.add_argument("--no-color", action="store_true", default=not (sys.stdin.isatty() and sys.stdout.isatty()),
help="Disable colored output (default: true in non-interactive mode)")
g_profile = parser.add_argument_group("profile options")
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
g_profile.add_argument("--top", type=int, default=10, metavar="N", help="Number of top kernels to show (-1 for all, default: 10)")
g_profile.add_argument("--offset", type=int, default=0, metavar="N", help="event offset (default: 0)")
g_profile.add_argument("--limit", type=int, default=10, metavar="N", help="events to display (-1 for all, default: 10)")
g_rewrites = parser.add_argument_group("rewrites options")
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
help="Select an item within the chosen kernel (optional name, default: only list names)")
@ -84,14 +86,64 @@ if __name__ == "__main__":
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
viz.ctxs = viz.get_rewrites(viz.trace)
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
if args.profile:
from tabulate import tabulate
profile = decode_profile(viz.get_profile(viz.load_pickle(args.profile_path, default=[])))
profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[])))
viz.load_amd_counters(viz.ctxs, profile_data)
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
if s["name"].startswith("PKTS")}
if args.device is None:
print("Select a device:")
for k in (*profile["layout"], *counters):
print(f" {format_colored(k)}")
sys.exit(0)
# ** SQTT printer
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
assert args.limit > 1, f"SQTT limit must be greater than 1, got {args.limit}"
sqtt_events, has_more = viz.sqtt_timeline(*sqtt_data, max_pkts=args.offset+args.limit)
sqtt_pkts = [e for e in sqtt_events if type(e).__name__ == "ProfileRangeEvent"]
pc_map = next((e.arg for e in sqtt_events if type(e).__name__ == "ProfilePointEvent" and e.key == 'pcMap'), None)
if pc_map is None:
print(f"No SQTT instruction trace data for {args.device}")
sys.exit(0)
# modern terminals support 24-bit color
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
(('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'),
(('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a'))
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Info'}")
print("-" * 90)
# start from the first packet in trace, prepare packet indexes and map dispatches
pkt_idxs:dict[str, itertools.count] = {}
dispatch_to_pc:dict[str, int] = {}
for e in sqtt_pkts[:-args.limit]:
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
if e.name.ret is not None and e.name.ret.startswith("PC:"): dispatch_to_pc[f"{e.device}-{idx}"] = int(e.name.ret.replace("PC:", ""))
# start printing from the offset point
for e in sqtt_pkts[-args.limit:]:
op_name, info = e.name.display_name, e.name.ret or ""
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
phase, pc = None, None
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
if info.startswith("PC:"):
dispatch_to_pc[f"{e.device}-{idx}"] = pc = int(info.replace("PC:", ""))
phase = "DISPATCH"
if info.startswith("LINK:"): phase, pc = "EXEC", dispatch_to_pc[info.replace("LINK:", "")]
if pc and phase: info = f"{phase:<8} 0x{pc:05x} {pc_map[pc]}"
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}")
# note: we only print the important packets and skip the rest
if has_more: print(f"Selected packets {args.offset:,}-{args.offset + args.limit:,}. Use --offset and --limit to see others")
sys.exit(0)
# ** Profiler printer
agg, total, n = {}, 0, 0
if args.device is None: print("Select a device:")
for k,v in profile["layout"].items():
if not optional_eq({"name":k}, args.device): continue
print(f" {k}")
print(f" {format_colored(k)}")
if args.device is None: continue
for e in v.get("events", []):
et = e["dur"]*1e-6
@ -99,7 +151,7 @@ if __name__ == "__main__":
if optional_eq(e, args.kernel) and n < 10:
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
name = e["name"]+(" " * (46 - ansilen(e["name"])))
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ")
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
n += 1
else:
a = agg.setdefault(e["name"], [0.0, 0])
@ -108,14 +160,15 @@ if __name__ == "__main__":
total += et
if agg and total > 0:
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
sel = items if args.top == -1 else items[:args.top]
sel = items if args.limit == -1 else items[args.offset:args.offset+args.limit]
table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in sel]
if args.top != -1 and (other:=items[len(sel):]):
if args.limit != -1 and (other:=items[len(sel):]):
other_t = total-sum(t for _, (t, _) in sel)
table.append([f"Other ({len(other)} unique)", time_to_str(other_t, w=9), sum(c for _,(_,c) in other), f"{other_t/total*100.0:.2f}%"])
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
sys.exit(0)
# ** Graph rewrites printer
for k in viz.ctxs:
if not optional_eq(k, args.kernel): continue
print(k["name"])

View file

@ -3,8 +3,10 @@ import functools
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.dtype import AddrSpace
from tinygrad.runtime.autogen.amd.rdna3.ins import *
from tinygrad.runtime.autogen.amd.rdna4.ins import s_barrier_wait, s_barrier_signal
import tinygrad.runtime.autogen.amd.rdna3.ins as r3
import tinygrad.runtime.autogen.amd.rdna4.ins as r4
from tinygrad.renderer.amd.dsl import s, v
from test.amd.helpers import TARGET_TO_ARCH
@ -53,12 +55,46 @@ def custom_wave_sync(A:UOp, arch:str) -> UOp:
insts = []
for _ in range(4):
insts.append(s_sleep(4))
insts += [s_barrier()] if arch == "rdna3" else [s_barrier_signal(), s_barrier_wait()]
insts += [s_barrier()] if arch == "rdna3" else [r4.s_barrier_signal(), r4.s_barrier_wait()]
insts += [s_nop(0)]*4
insts.append(s_endpgm())
sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_lds_sync(A:UOp, arch:str) -> UOp:
  """Build a hand-written AMD program that passes thread indices between threads through LDS.

  Each thread stores its index into its own LDS slot, waits on a barrier, then loads the slot
  of the next thread and writes it to the output; the last thread writes -1 instead.

  Args:
    A: output buffer UOp; it is flattened and its length is used as the thread count.
    arch: "rdna4" selects the rdna4 instruction module and wait/barrier/store encodings;
      any other value uses the rdna3 ones.

  Returns:
    An Ops.PROGRAM UOp for device "AMD" wrapping the instruction list.
  """
  A = A.flatten()
  num_threads = A.shape[0]
  threads = UOp.special(num_threads, "lidx0")
  wg = UOp.special(1, "gidx0")
  # NOTE(review): the LDS size is fixed at 512 bytes while num_threads comes from A — confirm callers always pass 128 elements
  lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
  isa = r4 if arch == "rdna4" else r3
  # rdna4 uses separate wait instructions; rdna3 uses s_waitcnt(lgkmcnt=0) for both cases
  wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt(lgkmcnt=0)]
  wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt(lgkmcnt=0)]
  # rdna4 splits the barrier into signal + wait
  barrier = [isa.s_barrier_signal(ssrc0=-1), isa.s_barrier_wait(simm16=-1)] if arch == "rdna4" else [isa.s_barrier()]
  # the global store takes different operand names on the two architectures
  global_store = [isa.global_store_b32(vaddr=v[6:7], saddr=s[0:1], vsrc=v[5])] if arch == "rdna4" \
    else [isa.global_store_b32(addr=v[6], data=v[5], saddr=s[0:1])]
  insts = [
    # presumably loads the output buffer pointer into s[0:1] — TODO confirm
    isa.s_load_b64(s[0:1], s[0:1], soffset=NULL),
    *wait_kmcnt,
    # v[1] = thread_idx << 2 (byte offset of this thread's 4-byte slot)
    isa.v_lshlrev_b32_e32(v[1], 2, v[0]),
    # lds[thread_idx] = thread_idx
    isa.ds_store_b32(addr=v[1], data0=v[0]),
    *wait_dscnt,
    *barrier,
    # out[thread_idx] = thread_idx == num_threads-1 ? -1 : lds[thread_idx + 1]
    isa.v_add_nc_u32_e32(v[2], 4, v[1]),
    isa.v_cmp_gt_u32_e32(num_threads-1, v[0]),
    isa.ds_load_b32(vdst=v[3], addr=v[2]),
    *wait_dscnt,
    isa.v_mov_b32_e32(v[4], -1),
    isa.v_cndmask_b32_e32(v[5], v[4], v[3]),
    # v[6] = thread_idx << 2 (byte offset into the output)
    isa.v_lshlrev_b32_e32(v[6], 2, v[0]),
    *global_store,
    isa.s_endpgm(),
  ]
  sink = UOp.sink(A.base, lds, threads, wg, arg=KernelInfo("custom_lds_sync"))
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
class TestCustomKernel(unittest.TestCase):
def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch]
@ -83,9 +119,14 @@ class TestCustomKernel(unittest.TestCase):
ei.run({"var":i})
self.assertTrue((a.numpy() == 1+i).all())
def test_wave_sync(self):
if self.arch not in {"rdna3", "rdna4"}: self.skipTest("only rdna3 or rdna4")
Tensor.empty(1).custom_kernel(fxn=functools.partial(custom_wave_sync, arch=self.arch))[0].realize()
def test_lds_sync(self):
  """Run custom_lds_sync on a 128-element int32 buffer and check the result is [1, 2, ..., 127, -1]."""
  supported_archs = ("rdna3", "rdna4")
  if self.arch not in supported_archs:
    self.skipTest("only rdna3/rdna4")
  kernel_fxn = functools.partial(custom_lds_sync, arch=self.arch)
  out_buf = Tensor.empty(128, dtype=dtypes.int32).contiguous().realize()
  result = Tensor.custom_kernel(out_buf, fxn=kernel_fxn)[0]
  result.realize()
  # expected: every thread sees its successor's index, the last one gets -1
  expected = Tensor.arange(1, 129, dtype=dtypes.int32)
  expected[127] = -1
  self.assertListEqual(result.tolist(), expected.tolist())
if __name__ == "__main__":
unittest.main()

View file

@ -10,18 +10,11 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_INST)
print_packets, CDNA_WAVEEND, CDNA_INST)
from test.amd.helpers import TARGET_TO_ARCH
import tinygrad
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
# INST ops for non-traced SIMDs (excluded from instruction count)
OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LDS_STORE_64, InstOp.OTHER_LDS_STORE_128,
InstOp.OTHER_FLAT_LOAD, InstOp.OTHER_FLAT_STORE, InstOp.OTHER_FLAT_STORE_64, InstOp.OTHER_FLAT_STORE_96,
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5, InstOpRDNA4.OTHER_LDS_1, InstOpRDNA4.OTHER_LDS_2}
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER
@ -183,11 +176,18 @@ class SQTTExamplesTestBase(unittest.TestCase):
our_waves: list[tuple[int, int]] = []
for event in events:
wave_starts: dict[tuple[int, int, int], int] = {}
first_timestamp:int|None = None
for p in decode(event.blob):
if first_timestamp is None: first_timestamp = p._time
if isinstance(p, (WAVESTART, CDNA_WAVESTART, WAVESTART_RDNA4)): wave_starts[(p.wave, p.simd, p.cu)] = p._time
elif isinstance(p, (WAVEEND, CDNA_WAVEEND)) and (key := (p.wave, p.simd, p.cu)) in wave_starts:
our_waves.append((wave_starts[key], p._time))
self.assertEqual(sorted(our_waves), sorted(roc_waves), f"wave times mismatch in {name}")
for st in wave_starts.values():
self.assertGreater(st, first_timestamp, "wave start must be after the first packet")
# rocprof fails non deterministically and gives inaccurate timestamps.
#self.assertEqual(sorted(our_waves), sorted(roc_waves), f"wave times mismatch in {name}")
for st, et in our_waves:
self.assertGreater(et, st, "wave end must be after start")
def test_rocprof_inst_times_match(self):
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
@ -200,8 +200,8 @@ class SQTTExamplesTestBase(unittest.TestCase):
our_insts: list[int] = []
for event in events:
for p in decode(event.blob):
if isinstance(p, INST) and p.op not in OTHER_SIMD_OPS: our_insts.append(p._time)
elif isinstance(p, INST_RDNA4) and p.op not in OTHER_SIMD_OPS_RDNA4: our_insts.append(p._time)
# INST ops for non-traced SIMDs (excluded from instruction count)
if isinstance(p, (INST, INST_RDNA4)) and not p.op.name.startswith("OTHER_"): our_insts.append(p._time)
elif isinstance(p, VALUINST): our_insts.append(p._time)
elif isinstance(p, IMMEDIATE): our_insts.append(p._time)
elif isinstance(p, IMMEDIATE_MASK):

View file

@ -79,7 +79,7 @@ class TestSQTTMapBase(unittest.TestCase):
if (p:=kern_events.get(event.kern)) is None: continue
with self.subTest(example=name, kern=event.kern):
# skip if there's no SQTT frequency data
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)[0]): continue
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
mean = sum(frequency) / len(frequency)
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
@ -93,7 +93,7 @@ class TestSQTTMapBase(unittest.TestCase):
elif "WAVE" in e.device:
# sopk/immediates don't get ALU/MEM EXEC
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL",
"WAVEEND"}: insts += 1
"WAVEEND", "WAVERDY"}: insts += 1
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
self.assertEqual(execs, insts)
@ -101,7 +101,7 @@ class TestSQTTMapBase(unittest.TestCase):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
wave_barriers = {}
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target)[0]:
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
if not wave_barriers: continue
for row, events in wave_barriers.items():

View file

@ -39,6 +39,18 @@ class TestJit(unittest.TestCase):
def add(a, b): return (a+b).realize()
_simple_test(add)
def test_jitbeam_triggers_beam(self):
  from unittest.mock import patch
  from tinygrad.helpers import getenv as _getenv

  # stand-in for beam_search: hands back its first argument untouched
  def _passthrough(k, *args, **kwargs): return k
  # getenv replacement: reports JITBEAM as 1, defers every other key to the real getenv
  def _getenv_with_jitbeam(k, d=0): return 1 if k == "JITBEAM" else _getenv(k, d)

  @TinyJit
  def add(a, b): return (a+b).realize()

  lhs = Tensor.ones(10, 10).contiguous().realize()
  rhs = Tensor.ones(10, 10).contiguous().realize()
  with patch("tinygrad.codegen.opt.search.beam_search", wraps=_passthrough) as mock_beam:
    # first call: beam_search is expected not to be reached
    add(lhs, rhs)
    assert mock_beam.call_count == 0
    # second call runs with the jit module's getenv patched so JITBEAM reads as 1
    with patch("tinygrad.engine.jit.getenv", side_effect=_getenv_with_jitbeam):
      add(lhs, rhs)
    assert mock_beam.call_count == 1
def test_simple_jit_reset(self):
@TinyJit
def add(a, b): return (a+b).realize()
@ -648,25 +660,6 @@ class TestJitFree(unittest.TestCase):
fxn(Tensor([2]))
self.assertEqual(x.item(), 8)
def test_replan_buffers_memory_layout(self):
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
ext_tensor = Tensor([1,24,23,45,1]).contiguous()
ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous()
@TinyJit
def fxn(x:Tensor):
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
return out.sum()
for i in range(5):
out = fxn(Tensor([i,1,2,3,4]))
self.assertEqual(out.item(), 11400+200*i)
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4)
fxn.captured.replan_buffers_memory_layout()
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2)
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 13600)
class TestJitGraphSplit(unittest.TestCase):
def compute(self, device, inp):
assert inp.device == device, f"Input device {inp.device} does not match expected {device}"

View file

@ -858,6 +858,8 @@ def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4
result = (hw_val >> offset) & mask
return UOp.sink(ctx.wsgpr_dyn(sdst_off, result), *ctx.inc_pc())
elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)):
# S_BARRIER_SIGNAL: no-op in emulator, barrier sync handled by execution loop
if isinstance(inst, ir4.SOP1) and inst.op in _BARRIER_SOP1_OPS: return UOp.sink(*ctx.inc_pc())
sdst_off = ctx.inst_field(type(inst).sdst)
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
@ -1960,6 +1962,8 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
_BARRIER_SOP1_OPS: set = set()
if hasattr(ir4.SOP1Op, 'S_BARRIER_SIGNAL'): _BARRIER_SOP1_OPS.add(ir4.SOP1Op.S_BARRIER_SIGNAL)
_BRANCH_OPS: set[int] = {op.value for op in (ir3.SOPPOp.S_BRANCH, ir3.SOPPOp.S_CBRANCH_SCC0, ir3.SOPPOp.S_CBRANCH_SCC1,
ir3.SOPPOp.S_CBRANCH_VCCZ, ir3.SOPPOp.S_CBRANCH_VCCNZ, ir3.SOPPOp.S_CBRANCH_EXECZ, ir3.SOPPOp.S_CBRANCH_EXECNZ)}
@ -2084,7 +2088,8 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
if pc not in program:
prev_len = len(_canonical_runner_cache)
runner, inst = _decode_at(pc, arch)
is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS
is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
(isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst)
if DEBUG >= 3:
msg = f"[emu] PC={pc - lib}: {inst!r}"

View file

@ -1,42 +1,73 @@
import unittest
from tinygrad import dtypes
from tinygrad.device import Buffer
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.memory import memory_plan_rewrite
global_map = {}
held_bufs: set[UOp] = set()
def b(i, base=None, offset=0, pin=False, size=16):
global global_map
if i in global_map: return global_map[i]
global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
if pin: global_map[i].ref(1)
if base is not None:
global_map[i] = global_map[base]
return global_map[i]
global_map[i] = UOp.new_buffer("NULL", size, dtypes.int8)
if pin: held_bufs.add(global_map[i])
return global_map[i]
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None):
assigned = _internal_memory_planner(buffers, copies=copies)
def _make_linear(buffer_lists, copies=None):
  """Build a LINEAR UOp holding one CALL per entry of `buffer_lists`.

  Each CALL has a head UOp followed by the entry's buffers. The head is COPY when the entry
  has exactly two buffers and that pair (in either order, compared by identity) is listed in
  `copies`; otherwise the head is SINK.
  """
  # unordered identity pairs, so (dst, src) and (src, dst) both match
  copy_pairs = set()
  for dst, src in (copies or ()):
    copy_pairs.add(frozenset((id(dst), id(src))))

  def _call_for(bufs):
    is_copy = len(bufs) == 2 and frozenset((id(bufs[0]), id(bufs[1]))) in copy_pairs
    head = UOp(Ops.COPY if is_copy else Ops.SINK)
    return UOp(Ops.CALL, dtypes.void, (head, *bufs))

  return UOp(Ops.LINEAR, src=tuple(_call_for(bufs) for bufs in buffer_lists))
taken_parts = set()
def _get_arena(buf, linear, result):
  """Return the arena `buf` was placed in, or None if it was never rewritten.

  Walks the calls of `linear` and `result` in lockstep (skipping each call's head at src[0])
  and returns src[0] of the first BUFFER_VIEW found in `result` at a position where `linear`
  holds `buf` itself.
  """
  arenas = (
    new.src[0]
    for orig_si, new_si in zip(linear.src, result.src)
    for orig, new in zip(orig_si.src[1:], new_si.src[1:])
    if orig is buf and new.op is Ops.BUFFER_VIEW
  )
  return next(arenas, None)
def check_assign(buffer_lists, copies=None):
linear = _make_linear(buffer_lists, copies)
result = memory_plan_rewrite(linear, held_bufs)
# build mapping: original buf -> (arena, offset_bytes, nbytes) from the result
replace_map: dict[int, tuple[UOp, int, int]] = {}
for orig_si, new_si in zip(linear.src, result.src):
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
replace_map[id(orig)] = (new.src[0], new.arg[1] * new.dtype.itemsize, new.arg[0] * new.dtype.itemsize)
# verify pinned buffers are not planned
for buf in held_bufs:
assert id(buf) not in replace_map, "pinned buffer was planned"
# compute lifetimes
first_appearance, last_appearance = {}, {}
for i,u in enumerate(buffers):
for buf in u:
if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue
if buf.base not in first_appearance: first_appearance[buf.base] = i
last_appearance[buf.base] = i
for i, bufs in enumerate(buffer_lists):
for buf in bufs:
if buf in held_bufs: continue
if id(buf) not in first_appearance: first_appearance[id(buf)] = i
last_appearance[id(buf)] = i
for i,u in enumerate(buffers):
for buf in u:
if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue
cur, base = assigned.get(buf, buf), assigned.get(buf.base, buf.base)
if buf._base is not None:
assert cur.base == base.base and cur.offset == buf.offset + base.offset, f"failed: {buf} {cur} {base} {buf.offset} {base.offset}"
else:
for part in taken_parts:
assert buf.base == part[3] or part[0] != cur.base or part[1] + part[2] <= cur.offset or part[1] >= cur.offset + buf.nbytes
if first_appearance[buf.base] == i: taken_parts.add((cur.base, cur.offset, buf.nbytes, buf.base))
if last_appearance[buf.base] == i: taken_parts.remove((cur.base, cur.offset, buf.nbytes, buf.base))
# verify non-overlapping: no two live buffers share the same arena region
taken_parts: set[tuple[int, int, int, int]] = set() # (id(arena), offset, nbytes, id(buf))
for i, bufs in enumerate(buffer_lists):
for buf in bufs:
if buf in held_bufs or id(buf) not in replace_map: continue
arena, off, nb = replace_map[id(buf)]
for part in taken_parts:
assert id(buf) == part[3] or part[0] != id(arena) or part[1] + part[2] <= off or part[1] >= off + nb, \
f"overlap at step {i}: [{off}, {off+nb}) conflicts with [{part[1]}, {part[1]+part[2]})"
if first_appearance.get(id(buf)) == i: taken_parts.add((id(arena), off, nb, id(buf)))
if last_appearance.get(id(buf)) == i: taken_parts.discard((id(arena), off, nb, id(buf)))
class TestMemoryPlanner(unittest.TestCase):
def setUp(self):
global global_map
held_bufs.clear()
global_map = {}
def test_simple_buffer(self):
@ -140,9 +171,11 @@ class TestMemoryPlanner(unittest.TestCase):
[b(1), b(2)],
[b(3), b(2)],
]
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))])
r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2))
assert r1.base != r2.base
linear = _make_linear(bs, copies=[(b(1), b(0))])
result = memory_plan_rewrite(linear)
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
assert r1_arena is not None and r2_arena is not None
assert r1_arena is not r2_arena
def test_copy_bufs_reuse_among_copies(self):
bs = [
@ -150,9 +183,11 @@ class TestMemoryPlanner(unittest.TestCase):
[b(2), b(1)],
[b(3), b(2)],
]
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0)), (b(2), b(1))])
r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2))
assert r1.base == r2.base
linear = _make_linear(bs, copies=[(b(1), b(0)), (b(2), b(1))])
result = memory_plan_rewrite(linear)
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
assert r1_arena is not None and r2_arena is not None
assert r1_arena is r2_arena
def test_compute_bufs_reuse_among_compute(self):
bs = [
@ -161,9 +196,11 @@ class TestMemoryPlanner(unittest.TestCase):
[b(3), b(2)],
[b(4), b(3)],
]
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))])
r2, r3 = assigned.get(b(2), b(2)), assigned.get(b(3), b(3))
assert r2.base == r3.base
linear = _make_linear(bs, copies=[(b(1), b(0))])
result = memory_plan_rewrite(linear)
r2_arena, r3_arena = _get_arena(b(2), linear, result), _get_arena(b(3), linear, result)
assert r2_arena is not None and r3_arena is not None
assert r2_arena is r3_arena
def test_copy_and_compute_no_cross_reuse(self):
bs = [
@ -171,9 +208,11 @@ class TestMemoryPlanner(unittest.TestCase):
[b(2), b(1)],
[b(3), b(2)],
]
assigned = _internal_memory_planner(bs, copies=[(b(2), b(1))])
r0, r2 = assigned.get(b(0), b(0)), assigned.get(b(2), b(2))
assert r0.base != r2.base
linear = _make_linear(bs, copies=[(b(2), b(1))])
result = memory_plan_rewrite(linear)
r0_arena, r2_arena = _get_arena(b(0), linear, result), _get_arena(b(2), linear, result)
assert r0_arena is not None and r2_arena is not None
assert r0_arena is not r2_arena
def test_multiple_copy_bufs_with_offsets(self):
bs = [

View file

@ -74,7 +74,10 @@ class TestRealWorld(unittest.TestCase):
def test(t, t2):
for l in model: t = l(t, t2)
return t.realize()
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.0002, 37)
# TODO: support _offset on CL to get mem down to 0.0002
exp_mem = 0.00037 if Device.DEFAULT == "CL" else 0.0002
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, exp_mem, 37)
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
def test_llama(self):

View file

@ -816,5 +816,58 @@ class TestUOpTags(unittest.TestCase):
g = graph_rewrite(g, pm_plus_1)
assert g.ssimplify() == 6
class TestUOpGetItem(unittest.TestCase):
  # Shape checks for indexing a LOCAL placeholder UOp with slices, range UOps and ints.
  def _placeholder(self, shape, dtype=dtypes.half):
    return UOp.placeholder(shape, dtype, slot=0, addrspace=AddrSpace.LOCAL)

  def _assert_shape(self, indexed, expected):
    self.assertEqual(indexed.shape, expected)

  # full slices (no shrink)
  def test_full_slice(self):
    buf = self._placeholder((64, 64))
    self._assert_shape(buf[:, :], (64, 64))

  def test_full_slice_explicit(self):
    buf = self._placeholder((64, 64))
    self._assert_shape(buf[0:64, 0:64], (64, 64))

  # partial slices (shrink)
  def test_shrink_cols(self):
    buf = self._placeholder((64, 80))
    self._assert_shape(buf[:, :64], (64, 64))

  def test_shrink_rows(self):
    buf = self._placeholder((80, 64))
    self._assert_shape(buf[:64, :], (64, 64))

  def test_shrink_both(self):
    buf = self._placeholder((80, 80))
    self._assert_shape(buf[:64, :64], (64, 64))

  def test_shrink_start(self):
    buf = self._placeholder((64, 64))
    self._assert_shape(buf[8:, :], (56, 64))

  def test_shrink_start_and_end(self):
    buf = self._placeholder((64, 64))
    self._assert_shape(buf[8:56, 4:60], (48, 56))

  # mixed slice and index
  def test_index_and_slice(self):
    buf = self._placeholder((64, 80))
    rng = UOp.range(64, 100)
    self._assert_shape(buf[rng, :64], (64,))

  def test_slice_and_index(self):
    buf = self._placeholder((80, 64))
    rng = UOp.range(64, 100)
    self._assert_shape(buf[:64, rng], (64,))

  def test_shrink_then_index(self):
    buf = self._placeholder((64, 80))
    shrunk = buf[:, :64]
    rng = UOp.range(64, 100)
    self._assert_shape(shrunk[rng], (64,))

  # integer index (no slice)
  def test_int_index(self):
    buf = self._placeholder((64, 64))
    self._assert_shape(buf[0], (64,))
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -286,12 +286,12 @@ pm_render = PatternMatcher([
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# Where after gated load becomes alt value
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -1,16 +1,26 @@
from typing import TypeVar, Generic, Callable, cast, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.engine.schedule import linear_to_schedule
from tinygrad.nn.state import get_parameters
from tinygrad.schedule.rangeify import mop_cleanup
from dataclasses import dataclass, replace
from dataclasses import dataclass
def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
  """Split the calls of `linear` into those touching needed buffers and the rest.

  Calls are visited once, in order. A call is kept when any buffer collected from its
  sources (everything after src[0]) is already in the needed set; all of that call's buffers
  then become needed too, so later calls sharing a buffer with it are kept as well.

  Args:
    linear: UOp whose `src` is the sequence of calls to partition.
    needed: buffers that seed the needed set. It is copied, so the caller's set is left
      unchanged.

  Returns:
    `(kept, onetime)`: two copies of `linear` whose `src` holds the kept calls and the
    remaining calls respectively, each in original order.
  """
  # work on a private copy: growing the argument in place would leak into the caller
  needed = set(needed)
  kept, onetime = [], []
  for si in linear.src:
    si_bufs = {b for src in si.src[1:] for b in _collect_bufs(src)}
    if si_bufs.isdisjoint(needed):
      onetime.append(si)
    else:
      kept.append(si)
      needed |= si_bufs
  return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
class GraphException(Exception): pass
class JitError(Exception): pass
@ -180,7 +190,7 @@ class CapturedJit(Generic[ReturnType]):
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
def __reduce__(self):
# TODO: free_intermediates here? replan_buffers_memory_layout here?
# TODO: free_intermediates here?
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
def __post_init__(self):
@ -202,18 +212,12 @@ class CapturedJit(Generic[ReturnType]):
def free_intermediates(self):
depends: set[Buffer|None] = set([None])
update_depends(depends, self.jit_cache)
for b in depends:
if b is not None:
if b.is_allocated(): b.deallocate()
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
self.__post_init__() # reset the graph state
def replan_buffers_memory_layout(self):
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
for old, new in asgn.items():
if old.is_allocated(): new.ensure_allocated().copyin(old.as_memoryview())
arenas = {b._base for b in depends if b is not None and b._base is not None}
to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas}
for b in to_free:
if hasattr(b, '_buf'): b.deallocate()
for a in arenas:
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
self.__post_init__()
# jit exec
@ -272,13 +276,12 @@ def _prepare_jit_inputs(args, kwargs):
return input_buffers, var_vals, names, expected_input_info
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False):
assert fxn or captured, "need either a function or a CapturedJit"
self.fxn = fxn
self.captured: CapturedJit|None = captured
self.cnt: int = 2 if self.fxn is None else 0
self.prune = prune
self.optimize = optimize
def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear)
@ -312,20 +315,32 @@ class TinyJit(Generic[ReturnType]):
assert self.fxn is not None
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
self._linears: list[UOp] = []
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
if not len(self._linears): raise JitError("didn't JIT anything!")
_check_no_non_tensor_return(ret)
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
# combine all captured linears into one and convert to ExecItems
jit_cache = [ei.lower() for ei in linear_to_schedule(UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))))]
# combine all captured linears into one, memory plan, and convert to ExecItems
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
del self._linears
if self.prune:
big_linear, onetime_linear = prune_linear(big_linear, {k for k,v in buffers.items() if isinstance(v, Buffer) and v in set(input_buffers)})
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
ei.run(var_vals, jit=True)
del onetime_linear
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
jit_cache = [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(big_linear, held_bufs))]
del big_linear
# track inputs that are views of buffers
# TODO: eventually expected_buffers should live in ExecItem
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
@ -335,25 +350,6 @@ class TinyJit(Generic[ReturnType]):
input_buffers.append(b)
extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
# prune independent kernels (optional)
if self.prune:
depends = set(input_buffers)
update_depends(depends, jit_cache)
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
# sync before re-executing onetime kernels
for dev in set(Device[b.device] for ei in onetime for b in ei.bufs if b is not None): dev.synchronize()
# run the onetime kernels here
for ei in onetime:
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
ei.run(var_vals, jit=True)
jit_cache = pruned
# memory planning (optional)
copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))]
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ")
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
@ -361,7 +357,6 @@ class TinyJit(Generic[ReturnType]):
for ei in jit_cache: ei.run(var_vals)
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
if self.optimize: self.captured.replan_buffers_memory_layout()
elif self.cnt >= 2:
# jit exec
assert self.captured is not None

View file

@ -1,77 +1,65 @@
from typing import cast
from collections import defaultdict
from tinygrad.engine.realize import ExecItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops
from tinygrad.device import Device
from tinygrad.helpers import NO_MEMORY_PLANNER, DEBUG, round_up
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.runtime.support.memory import TLSFAllocator
def _collect_bufs(u:UOp) -> list[UOp]:
  """Return the BUFFER UOps found at or below `u`, in source order.

  A BUFFER yields itself; MSELECT and MSTACK are searched recursively through their sources;
  any other op yields an empty list.
  """
  if u.op is Ops.BUFFER: return [u]
  if u.op not in {Ops.MSELECT, Ops.MSTACK}: return []
  found: list[UOp] = []
  for s in u.src: found.extend(_collect_bufs(s))
  return found
def _can_plan(b:UOp, held_bufs:set[UOp]) -> bool:
  """Tell whether buffer `b` may be handed to the memory planner.

  Returns False when `b` is in `held_bufs`, when any of its devices has a name starting with
  "DISK" or "TINYFS", or when any device's allocator has no `_offset` attribute.
  """
  if b in held_bufs: return False
  # b.device is either a single device name or an iterable of names
  devs = b.device if not isinstance(b.device, str) else (b.device,)
  for d in devs:
    if d.startswith(("DISK", "TINYFS")): return False
    if not hasattr(Device[d].allocator, "_offset"): return False
  return True
LaneKey = tuple[str, int]
# **************** memory planning ****************
def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp:
if NO_MEMORY_PLANNER: return linear
if held_bufs is None: held_bufs = set()
def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None,
ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
if NO_MEMORY_PLANNER: return {}
first_appearance, last_appearance, buf_to_opt = {}, {}, set()
for i,u in enumerate(buffers):
for buf in u:
if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue
if buf.base not in first_appearance: first_appearance[buf.base] = i
last_appearance[buf.base] = i
buf_to_opt.add(buf)
# compute lifetimes for all plannable internal buffers
first_appearance:dict[UOp, int] = {}
last_appearance:dict[UOp, int] = {}
copy_bufs: set[UOp] = set()
for i, si in enumerate(linear.src):
si_bufs = [b for src in si.src[1:] for b in _collect_bufs(src) if _can_plan(b, held_bufs)]
for b in si_bufs:
if b not in first_appearance: first_appearance[b] = i
last_appearance[b] = i
if si.src[0].op is Ops.COPY: copy_bufs.update(si_bufs)
if not first_appearance: return linear
# Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy)
copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set())
def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0)
buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs}
# separate copy and compute buffers into different lanes to avoid introducing dependencies (copy->compute->copy)
def _key(b:UOp): return (b.device, 1 if b in copy_bufs else 0)
buf_hold = {b: last_appearance[b] - first_appearance[b] + 1 for b in first_appearance if b in copy_bufs}
# Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
[((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
# suballocation: build sorted open/close events, then alloc/free in order
block_size = 256
nbytes = {b: round_up(b.arg * b.dtype.itemsize, block_size) for b in first_appearance}
events = sorted([(first_appearance[b], True, b) for b in first_appearance] +
[(last_appearance[b] + 1 + buf_hold.get(b, 0), False, b) for b in first_appearance], key=lambda x: (x[0], x[1]))
total_memory = sum(nbytes.values()) * 2
# Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
# Also track buffer replacements for buffers that do not support suballocation.
buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32)))
for (_, is_open_ev), buf in buffer_requests:
# Check if suballocation is possible for the given buffer and device.
if hasattr(Device[buf.device].allocator, "_offset"):
if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK)))
else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1]))
global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1])
else:
key = (_key(buf), buf.dtype, buf.options, buf.nbytes)
if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
offsets:dict[UOp, int] = {}
peaks:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=block_size, lv2_cnt=32)))
for _, is_open, buf in events:
if is_open: offsets[buf] = peaks[_key(buf)][1].alloc(nbytes[buf])
else: peaks[_key(buf)][1].free(offsets[buf])
peaks[_key(buf)] = (max(peaks[_key(buf)][0], offsets[buf] + buf.arg * buf.dtype.itemsize), peaks[_key(buf)][1])
arena_sizes = {key: round_up(peak, block_size) for key, (peak, _) in peaks.items()}
# Allocate global buffers based on the memory planner.
global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()}
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()}
# build replace_map: each buffer becomes a BUFFER_VIEW into a shared per-device-lane arena
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
replace_map:dict[UOp, UOp] = {}
for buf_uop, offset in offsets.items():
assert offset % buf_uop.dtype.itemsize == 0, f"offset {offset} not aligned to {buf_uop.dtype.itemsize}"
replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)],), (buf_uop.arg, offset // buf_uop.dtype.itemsize))
# Assign buffers. First, assign full buffers (not sub-buffers).
assigned:dict[Buffer, Buffer] = {}
for buf, (base, off) in buffer_resolve.items():
if buf != base:
assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
# Now assign sub-buffers.
for buf in buf_to_opt:
if buf._base is not None:
assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)
if DEBUG >= 1:
ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")
return assigned
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
copies=[(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY])
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
return linear.substitute(replace_map, name="memory plan", walk=True)

View file

@ -127,7 +127,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
si_lowerer = PatternMatcher([
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),

View file

@ -83,7 +83,7 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
return schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.memory import memory_plan_rewrite
from tinygrad.engine.realize import capturing
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.helpers import CAPTURING
@ -163,7 +163,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di
capturing[0].add_linear(linear, var_vals)
return [], var_vals
held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set())
linear = memory_plan_rewrite(linear, held_bufs)
# convert LINEAR to ExecItems
schedule: list[ExecItem] = linear_to_schedule(linear)
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
return schedule, var_vals

View file

@ -677,7 +677,7 @@ class ElementwiseMixin(DTypeMixin):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
```
"""
return self.ne(0).where((self < 0).where(self.const_like(-1), self.const_like(1)), self.const_like(0)) + self * 0
return self.ne(0).where((self < 0).where(self.const_like(-1), self.const_like(1)), self.const_like(0))
def abs(self) -> Self:
"""

View file

@ -105,14 +105,25 @@ class InstOpRDNA4(Enum):
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
SALU = 0x0
SMEM = 0x1
SMEM_WR = 0x2
JUMP = 0x3
JUMP_NO = 0x4
CALL = 0x5
SALU_NO_EXEC = 0x7
MESSAGE = 0x9
VALU_1 = 0xa
VALU_TRANS = 0xb
VALU_B1 = 0xc
VALU_B2 = 0xd
VALU_B4 = 0xe
VALU_B16 = 0xf
VINTERP = 0x12
BARRIER_WAIT = 0x13
FLAT_RD_2 = 0x1c
FLAT_WR_3 = 0x1d
FLAT_WR_4 = 0x1e
FLAT_WR_5 = 0x1f
FLAT_WR_6 = 0x20
VMEM_RD_1 = 0x21
VMEM_RD_2 = 0x22
VMEM_WR_1 = 0x23
@ -127,18 +138,45 @@ class InstOpRDNA4(Enum):
LDS_WR_3 = 0x2c
LDS_WR_4 = 0x2d
LDS_WR_5 = 0x2e
BUF_RD_1 = 0x2f
BUF_RD_2 = 0x30
BUF_WR_1 = 0x31
BUF_WR_2 = 0x32
BUF_WR_3 = 0x33
BUF_WR_4 = 0x34
BUF_WR_5 = 0x35
BUF_WR_6 = 0x36
OTHER_LDS_1 = 0x50
OTHER_LDS_2 = 0x51
OTHER_LDS_3 = 0x52
OTHER_LDS_4 = 0x53
OTHER_LDS_5 = 0x54
OTHER_FLAT_2 = 0x55
OTHER_FLAT_3 = 0x56
OTHER_FLAT_4 = 0x57
OTHER_FLAT_5 = 0x58
OTHER_FLAT_6 = 0x59
LDS_DIR_LOAD = 0x6e
LDS_PARAM_LOAD = 0x6f
SALU_WR_EXEC = 0x72
VALU1_WR_EXEC = 0x73
VALU_B2_WR_EXEC = 0x74
OTHER_LDS_6 = 0x77
OTHER_LDS_10 = 0x78
BARRIER_SIGNAL = 0x7a
DYN_VGPR = 0x87
BARRIER_JOIN = 0x8a
WMMA_8 = 0x8c
WMMA_16 = 0x8d
WMMA_32 = 0x8e
WMMA_64 = 0x8f
VALU_DPFP = 0x92
SALU_FLOAT3 = 0x98
VALU_SCL_TRANS = 0x99
SALU_2 = 0x9b
SALU_5 = 0x9c
OTHER_VMEM = 0xbd
OTHER_VMEM_5 = 0xc1
OTHER_VMEM = 0xbc # 0xbc-0xdd: vmem_other_simd
for _i in range(34): InstOpRDNA4._value2member_map_[0xbc + _i] = InstOpRDNA4.OTHER_VMEM
class InstOpCDNA(Enum):
SMEM_RD = 0
@ -650,8 +688,8 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
yield (p, InstructionInfo(pc, wave, inst))
elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
inst = pc_map[pc:=wave_pc[p.wave]]
# s_delay_alu and s_wait_alu instructions are skipped
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
# s_delay_alu, s_wait_alu and s_barrier_wait instructions are skipped
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU", "S_BARRIER_WAIT"}:
wave_pc[p.wave] += inst.size()
inst = pc_map[pc:=wave_pc[p.wave]]
# assert branch always has a JUMP packet

View file

@ -487,8 +487,8 @@ class AMDHIPRenderer(CStyleLanguage):
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]},"
f" {fp8_index(x.src[0].dtype)}, {fp8_index(x.src[0].dtype)}, 0, 0, 0, 0)" if x.arg[1][2] == 128 else None),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)"),
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",),
lambda ctx,x,y: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",),
lambda ctx,x: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",),
lambda ctx,x,y: f"__builtin_amdgcn_cvt_f32_{('fp8', 'bf8')[fp8_index(y.dtype)]}((unsigned int){ctx[x.src[0]]}, 0)"),
]) + base_rewrite

View file

@ -224,17 +224,17 @@ class AMDLLVMRenderer(LLVMRenderer):
(UPat(tuple(llvm_intrinsics), name="x"),
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BARRIER), lambda ctx: barrier),
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y:
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",), lambda ctx,x:
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
]) + base_rewrite
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
(UPat(Ops.CAST, dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(UPat(Ops.CAST, dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
# amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),

View file

@ -17,8 +17,8 @@ def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0
# alu ops, aop[<dtype>][<op>]
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"}
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"}
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax", Ops.SHL: "ishl", Ops.SHR: "ushr"}
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax", Ops.SHR: "ishr"}
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIPROCAL: "frcp",
Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"}
aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}}
@ -130,6 +130,8 @@ class NIRRenderer(Renderer):
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
@ -144,8 +146,8 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True),
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),

View file

@ -317,7 +317,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
# hack if a noop turned to a const
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
(UPat(Ops.NOOP, src=(UPat.cvar("c"),)), lambda c: c),
# mstack on CONST is CONST
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),

View file

@ -5,7 +5,7 @@ from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid, InvalidType
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
@ -134,13 +134,7 @@ class Tensor(OpMixin):
if isinstance(data, UOp):
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
# TODO: remove this and stay in weakint
if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
if data.op is Ops.BIND:
var, val = data.unbind()
# give the bound constant a device
const = UOp.const(var.dtype, val, _device, ())
data = data.replace(src=(var.replace(src=const.src), const))
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data, device=_device).uop
elif data is None:
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
elif isinstance(data, get_args(ConstType)):
@ -503,7 +497,13 @@ class Tensor(OpMixin):
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
# TODO: remove this and stay in weakint
if y.dtype == dtypes.weakint: y = _index_to_concrete_int(y)
if y.op is Ops.BIND:
var, val = y.unbind()
_device = canonicalize_device(kwargs.get("device"))
const = UOp.const(var.dtype, val, _device, ())
return Tensor(y.replace(src=(var.replace(src=const.src), const)), **kwargs, requires_grad=False)
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
@ -1822,7 +1822,7 @@ class Tensor(OpMixin):
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
denominator = prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return numerator.div(Tensor.from_uop(denominator, device=numerator.device) if isinstance(denominator, UOp) else denominator).cast(output_dtype)
return numerator.div(denominator).cast(output_dtype)
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
@ -1848,7 +1848,8 @@ class Tensor(OpMixin):
"""
squares = (self - self.mean(axis=axis, keepdim=True)).square()
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
denominator = (Tensor.from_uop(n, device=self.device) if isinstance(n, UOp) else Tensor(n, device=self.device)) - correction
denominator = Tensor(n, device=self.device) - correction
# TODO: infer device and remove relu
return squares.sum(axis=axis, keepdim=keepdim).div(denominator.relu())
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
@ -2953,10 +2954,8 @@ class Tensor(OpMixin):
if not isinstance(y, Tensor):
# make y a Tensor
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
if y is Invalid or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
y_dtype = x.dtype if dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, (int, InvalidType))) else None
y = Tensor(y, x.device, y_dtype, requires_grad=False)
if match_dtype and x.dtype != y.dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
@ -2993,7 +2992,7 @@ class Tensor(OpMixin):
a, b = self._broadcasted(x, reverse)
return a + (-b)
def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
"""
Divides `self` by `x`.
Equivalent to `self / x`.

View file

@ -431,13 +431,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if len(idx) < len(self.shape): idx += tuple([slice(None)]*(len(self.shape)-len(idx)))
assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
# apply SHRINK for slices that aren't the full range
bounds = tuple((s.start or 0, s.stop if s.stop is not None else self.shape[i]) if isinstance(s, slice) else (0, self.shape[i])
for i, s in enumerate(idx))
src = self if all(b == (0, self.shape[i]) for i, b in enumerate(bounds)) else self.shrink(bounds)
perm = src.permute(tuple([i for i in range(src.ndim) if i not in slice_idx] + slice_idx))
return perm.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
else:
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
return UOp.const(self.dtype.base, b, device=self._device, shape=self._shape)
def broadcast(self, count:int):
assert self.dtype.vcount == 1
if count == 1: return self

View file

@ -69,9 +69,9 @@ shared_spec = PatternMatcher([
# ***** UOp spec in the Tensor graph *****
movement_ops = PatternMatcher([
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# inputs to movement ops
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
@ -213,7 +213,7 @@ kernel_spec = PatternMatcher([
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True), lambda: True),
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.weakint, dtypes.int) for y in x.src[1:])),
@ -286,7 +286,7 @@ full_spec = PatternMatcher([
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint), (UPat(), UPat()), arg=None), lambda: True),
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
(UPat((Ops.MSELECT, Ops.MSTACK)), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),

View file

@ -28,19 +28,19 @@ z3_renderer = PatternMatcher([
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool), lambda ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
# constants
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, arg=Invalid), lambda ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)),
# casts from floats create new variables
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats)), lambda ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
# casts from bool/int to int/bool
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),)), lambda x,ctx: (z3.If(ctx[1][x], 1, 0), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),)), lambda x,ctx: (ctx[1][x], None)),
(UPat(Ops.CAST, dtypes.bool, name="x"), lambda x,ctx: (ctx[1][x.src[0]]!=0, None)),
(UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
])

View file

@ -315,6 +315,7 @@ function setFocus(key) {
p.append("span").style("margin-left", "8px").style("color", "#f0f0f566").text(formatTime(e.x));
tableData.push(["Cycle", formatTime(e.x-data.instSt)], ["Time", p.node()]);
} else tableData.push(["Start Time", formatTime(e.x)]);
if (data.link != null) tableData.push(["Delay", `${formatTime(Math.abs(selectShape(data.link[0]).e.x - selectShape(data.link[1]).e.x))} Cycles`]);
html.append(() => tabulate(tableData));
let group = html.append("div").classed("args", true);
for (const r of rest) group.append("p").text(r);

View file

@ -337,7 +337,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
def sqtt_timeline(data:bytes, lib:bytes, target:str, max_pkts=getenv("MAX_SQTT_PKTS",50_000)) -> tuple[list[ProfileEvent], bool]:
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
WAVEEND, CDNA_WAVEEND, WAVERDY)
@ -365,8 +365,11 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
has_more = False
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
if len(ret) > max_pkts:
has_more = True
break
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
pair = (p._time, p.delta)
if prev_pair is None: prev_pair = pair
@ -393,7 +396,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
else:
add(name.replace("_ALT", ""), p, op=name)
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret, has_more
# ** SQTT OCC only unpacks wave start, end time and SIMD location
@ -619,7 +622,7 @@ def get_render(query:str) -> dict:
if fmt.startswith("prg-pkts"):
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(sqtt_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
if (events:=get_profile(sqtt_timeline(*data)[0], sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No SQTT trace on this SE."}
return ret
if fmt == "prg-sqtt":