mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into new_x86_backend
This commit is contained in:
commit
cd0152efec
38 changed files with 725 additions and 291 deletions
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -70,7 +70,7 @@ jobs:
|
|||
source venv/bin/activate
|
||||
pip install $GITHUB_WORKSPACE
|
||||
cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py .
|
||||
BS=2 STEPS=10 python beautiful_mnist.py
|
||||
BS=2 STEPS=10 MAX_BUFFER_SIZE=0 python beautiful_mnist.py
|
||||
- name: Test Docs Build
|
||||
run: python -m mkdocs build --strict
|
||||
- name: Test Docs
|
||||
|
|
@ -141,7 +141,7 @@ jobs:
|
|||
sudo apt update || true
|
||||
sudo apt install -y --no-install-recommends ninja-build
|
||||
- name: Test beautiful_mnist in torch with TINY_BACKEND
|
||||
run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
|
||||
run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 MAX_BUFFER_SIZE=0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
|
||||
- name: Test some torch tests (expect failure)
|
||||
run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from extra.onnx_helpers import get_example_inputs, validate
|
|||
|
||||
def load_onnx_model(onnx_file):
|
||||
run_onnx = OnnxRunner(onnx_file)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True)
|
||||
return run_onnx_jit, run_onnx.graph_inputs
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#!/bin/bash
|
||||
export BENCHMARK=5
|
||||
export EVAL_BS=0
|
||||
export VIZ=${VIZ:--1}
|
||||
examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
|
||||
extra/viz/cli.py --profile --device "AMD" --top 20
|
||||
VIZ=${VIZ:--1} examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
|
||||
extra/viz/cli.py --profile --device "AMD" --limit 20
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import trange
|
||||
from tinygrad.nn import optim
|
||||
from extra.datasets import fetch_mnist
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
class LinearGen:
|
||||
def __init__(self):
|
||||
|
|
@ -38,14 +37,14 @@ class LinearDisc:
|
|||
return x
|
||||
|
||||
def make_batch(images):
|
||||
sample = np.random.randint(0, len(images), size=(batch_size))
|
||||
image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0
|
||||
return Tensor(image_b)
|
||||
sample = Tensor.randint(batch_size, low=0, high=images.shape[0])
|
||||
return images[sample].reshape(batch_size, 28*28).cast('float').div(127.5).sub(1.0)
|
||||
|
||||
def make_labels(bs, col, val=-2.0):
|
||||
y = np.zeros((bs, 2), np.float32)
|
||||
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||
return Tensor(y)
|
||||
y = Tensor.zeros(bs, 2)
|
||||
if col == 0: y = y + Tensor([val, 0.0])
|
||||
else: y = y + Tensor([0.0, val])
|
||||
return y
|
||||
|
||||
def train_discriminator(optimizer, data_real, data_fake):
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
|
|
@ -71,12 +70,12 @@ def train_generator(optimizer, data_fake):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# data for training and validation
|
||||
images_real = np.vstack(fetch_mnist()[::2])
|
||||
X_train, _, _, _ = mnist()
|
||||
ds_noise = Tensor.randn(64, 128, requires_grad=False)
|
||||
# parameters
|
||||
epochs, batch_size, k = 300, 512, 1
|
||||
sample_interval = epochs // 10
|
||||
n_steps = len(images_real) // batch_size
|
||||
n_steps = X_train.shape[0] // batch_size
|
||||
# models and optimizer
|
||||
generator = LinearGen()
|
||||
discriminator = LinearDisc()
|
||||
|
|
@ -84,24 +83,24 @@ if __name__ == "__main__":
|
|||
output_dir = Path(".").resolve() / "outputs"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
# optimizers
|
||||
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
|
||||
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
Tensor.training = True
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
data_real = make_batch(images_real)
|
||||
for step in range(k): # Try with k = 5 or 7.
|
||||
with Tensor.train():
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
data_real = make_batch(X_train)
|
||||
for step in range(k): # Try with k = 5 or 7.
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise).detach()
|
||||
loss_d += train_discriminator(optim_d, data_real, data_fake)
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise).detach()
|
||||
loss_d += train_discriminator(optim_d, data_real, data_fake)
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise)
|
||||
loss_g += train_generator(optim_g, data_fake)
|
||||
if (epoch + 1) % sample_interval == 0:
|
||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
|
||||
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
|
||||
data_fake = generator.forward(noise)
|
||||
loss_g += train_generator(optim_g, data_fake)
|
||||
if (epoch + 1) % sample_interval == 0:
|
||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
|
||||
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
|
||||
print("Training Completed!")
|
||||
|
|
|
|||
|
|
@ -13,12 +13,20 @@ from collections import OrderedDict
|
|||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"]
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
# memory-planned subbuffers can have multiple Buffer objects for the same memory region
|
||||
canon, _seen = {}, {}
|
||||
for ji in run.jit_cache:
|
||||
for b in ji.bufs:
|
||||
if b is not None: canon[id(b)] = _seen.setdefault((id(b.base._buf), b.offset, b.size, b.dtype), b)
|
||||
special_names = {id(canon[k]): v for k, v in special_names.items() if k in canon}
|
||||
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn: ProgramSpec = ji.prg.p
|
||||
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i,arg in enumerate(ji.bufs):
|
||||
arg = canon[id(arg)]
|
||||
key = id(arg)
|
||||
if key not in bufs:
|
||||
if key in special_names:
|
||||
|
|
|
|||
208
extra/gemm/amd_flash_attention.py
Normal file
208
extra/gemm/amd_flash_attention.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
from tinygrad import Tensor, UOp, getenv
|
||||
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
|
||||
from tinygrad.dtype import AddrSpace, dtypes
|
||||
from tinygrad.helpers import DEBUG, GlobalCounters, Context
|
||||
import math
|
||||
|
||||
B = getenv("B", 1)
|
||||
H = getenv("H", 32)
|
||||
N = getenv("N", 1024)
|
||||
D = getenv("D", 64)
|
||||
assert D % 16 == 0 and N % 16 == 0
|
||||
|
||||
BLOCK_M, BLOCK_N = 64, 64
|
||||
WARP_SIZE = 32
|
||||
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
|
||||
WAVES_M, WAVES_N = 4, 1
|
||||
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
|
||||
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
|
||||
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
|
||||
|
||||
TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)
|
||||
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)
|
||||
TD = D // (WAVES_N * LANES_PER_WAVE_N)
|
||||
LDS_PAD = 4 # pad LDS rows to reduce bank conflicts
|
||||
|
||||
WMMA_ARG = ((WMMA_M, WMMA_N, WMMA_K), 'AMD', 32)
|
||||
SCALE = 1.0 / math.sqrt(D)
|
||||
LOG2E = math.log2(math.e)
|
||||
|
||||
def warp_shfl_xor(val, offset, lane):
|
||||
"""Read val from lane ^ offset using ds_bpermute."""
|
||||
idx = ((lane ^ offset) * 4).cast(dtypes.int)
|
||||
return UOp(Ops.CUSTOM, dtypes.float, (idx, val),
|
||||
arg="__builtin_bit_cast(float, __builtin_amdgcn_ds_bpermute({0}, __builtin_bit_cast(int, {1})))")
|
||||
|
||||
def warp_reduce_max(val, lane):
|
||||
"""Tree reduce MAX across LANES_PER_WAVE_N=16 lanes."""
|
||||
for offset in [8, 4, 2, 1]:
|
||||
val = UOp(Ops.MAX, dtypes.float, (val, warp_shfl_xor(val, offset, lane)))
|
||||
return val
|
||||
|
||||
def warp_reduce_sum(val, lane):
|
||||
"""Tree reduce SUM across LANES_PER_WAVE_N=16 lanes."""
|
||||
for offset in [8, 4, 2, 1]:
|
||||
val = val + warp_shfl_xor(val, offset, lane)
|
||||
return val
|
||||
|
||||
def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
||||
block_bh = UOp.range(B * H, 0, AxisType.GLOBAL)
|
||||
block_m = UOp.range(N // BLOCK_M, 1, AxisType.GLOBAL)
|
||||
|
||||
q = q.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]
|
||||
k = k.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
|
||||
v = v.reshape(B*H, N//BLOCK_N, BLOCK_N, D)[block_bh]
|
||||
o = o.reshape(B*H, N//BLOCK_M, BLOCK_M, D)[block_bh, block_m]
|
||||
|
||||
wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
|
||||
wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
|
||||
lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
|
||||
tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
|
||||
lane_m = lane // LANES_PER_WAVE_N
|
||||
lane_n = lane % LANES_PER_WAVE_N
|
||||
|
||||
# LDS allocation: slot 0 = Q then P (shared), slot 1 = K then V
|
||||
# TODO: the memory planner should be able to find this reuse
|
||||
ELEMS_PER_THREAD = BLOCK_M * D // THREADS_PER_BLOCK
|
||||
QP_lds = UOp.placeholder((BLOCK_M, D + LDS_PAD), dtypes.half, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
KV_lds = UOp.placeholder((BLOCK_N, D + LDS_PAD), dtypes.half, slot=1, addrspace=AddrSpace.LOCAL)[:, :D]
|
||||
|
||||
# register state
|
||||
acc = UOp.placeholder((TM, TD), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
m_i = UOp.placeholder((TM,), dtypes.float, slot=3, addrspace=AddrSpace.REG)
|
||||
l_i = UOp.placeholder((TM,), dtypes.float, slot=4, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(acc.const_like(0)))
|
||||
m_i = m_i.after(m_i.store(m_i.const_like(-math.inf)))
|
||||
l_i = l_i.after(l_i.store(l_i.const_like(0)))
|
||||
|
||||
# ====== KV tile loop ======
|
||||
n_tile = UOp.range(N // BLOCK_N, 100, AxisType.REDUCE)
|
||||
|
||||
# load Q + K into LDS (Q reloaded each iteration since P overwrites slot 0)
|
||||
Q_lds = QP_lds[:, :D]
|
||||
Q_store = Q_lds.after(n_tile).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
|
||||
q.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
|
||||
K_store = KV_lds.reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
|
||||
k[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
|
||||
qk_load_barrier = UOp.barrier(UOp.group(Q_store, K_store))
|
||||
Q_lds = Q_lds.after(qk_load_barrier)
|
||||
KV_lds_k = KV_lds.after(qk_load_barrier)
|
||||
|
||||
# -- S = Q @ K^T via WMMA (re-init each n_tile) --
|
||||
S_reg = UOp.placeholder((TM, TN), dtypes.float, slot=6, addrspace=AddrSpace.REG)
|
||||
S_reg = S_reg.after(S_reg.after(n_tile).store(S_reg.const_like(0)))
|
||||
k_qk = UOp.range(D // WMMA_K, 101, AxisType.REDUCE)
|
||||
tm1 = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
|
||||
tn1 = UOp.range(TN, 201, AxisType.LOOP)
|
||||
S_frag = S_reg.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0, 2, 1)[tm1, tn1]
|
||||
q_frag = Q_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, D // WMMA_K, WMMA_K)[wave_m, tm1, lane_n, k_qk]
|
||||
k_frag = KV_lds_k.reshape(WAVES_N, TN, WMMA_N, D // WMMA_K, WMMA_K)[wave_n, tn1, lane_n, k_qk]
|
||||
qk = UOp(Ops.SHAPED_WMMA, dtypes.float, (q_frag, k_frag, S_frag.after(k_qk)), arg=WMMA_ARG)
|
||||
qk_done = S_frag.store(qk).end(tm1, tn1).end(k_qk)
|
||||
S_reg = S_reg.after(qk_done)
|
||||
|
||||
# -- softmax in registers with warp shuffles --
|
||||
S_reg = S_reg.after(S_reg.store(S_reg * SCALE))
|
||||
|
||||
# per-thread local row max over TN=4 elements, then warp reduce across 16 lanes
|
||||
m_ij = UOp.placeholder((TM,), dtypes.float, slot=7, addrspace=AddrSpace.REG)
|
||||
m_ij = m_ij.after(m_ij.after(n_tile).store(m_ij.const_like(-math.inf)))
|
||||
rm1 = UOp.range(TM, 260, AxisType.LOOP)
|
||||
rm2 = UOp.range(TN, 261, AxisType.REDUCE)
|
||||
m_ij = m_ij.after(m_ij[rm1].store(UOp(Ops.MAX, dtypes.float, (m_ij.after(rm1, rm2)[rm1], S_reg[rm1, rm2]))).end(rm2, rm1))
|
||||
# warp reduce max (in-place)
|
||||
ri_w = UOp.range(TM, 270, AxisType.LOOP)
|
||||
m_ij = m_ij.after(m_ij[ri_w].store(warp_reduce_max(m_ij[ri_w], lane)).end(ri_w))
|
||||
|
||||
# compute P = exp(S - m_ij) in S_reg (manual ranges)
|
||||
rp0a = UOp.range(TM, 275, AxisType.LOOP)
|
||||
rp0b = UOp.range(TN, 276, AxisType.LOOP)
|
||||
S_reg = S_reg.after(S_reg[rp0a, rp0b].store(((S_reg[rp0a, rp0b] - m_ij[rp0a]) * LOG2E).exp2()).end(rp0a, rp0b))
|
||||
|
||||
p_local = UOp.placeholder((TM,), dtypes.float, slot=8, addrspace=AddrSpace.REG)
|
||||
p_local = p_local.after(p_local.after(n_tile).store(p_local.const_like(0)))
|
||||
rp1 = UOp.range(TM, 290, AxisType.LOOP)
|
||||
rp2 = UOp.range(TN, 291, AxisType.REDUCE)
|
||||
p_local = p_local.after(p_local[rp1].store(p_local.after(rp1, rp2)[rp1] + S_reg[rp1, rp2]).end(rp2, rp1))
|
||||
ri_ws = UOp.range(TM, 295, AxisType.LOOP)
|
||||
p_sum = p_local.after(p_local[ri_ws].store(warp_reduce_sum(p_local[ri_ws], lane)).end(ri_ws))
|
||||
|
||||
# write P = exp(S - m_ij) to P_lds (reuses slot 0, Q no longer needed)
|
||||
P_lds = QP_lds[:, :BLOCK_N]
|
||||
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
|
||||
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
rw1 = UOp.range(TM, 296, AxisType.LOOP)
|
||||
rw2 = UOp.range(TN, 297, AxisType.LOOP)
|
||||
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
|
||||
|
||||
# -- online softmax correction --
|
||||
ri4 = UOp.range(TM, 330, AxisType.LOOP)
|
||||
m_new_val = UOp(Ops.MAX, dtypes.float, (m_i[ri4], m_ij[ri4]))
|
||||
alpha_val = ((m_i[ri4] - m_new_val) * LOG2E).exp2()
|
||||
beta_val = ((m_ij[ri4] - m_new_val) * LOG2E).exp2()
|
||||
rj4 = UOp.range(TD, 331, AxisType.LOOP)
|
||||
correction = UOp.group(
|
||||
acc[ri4, rj4].store(alpha_val * acc[ri4, rj4]).end(rj4),
|
||||
l_i[ri4].store(alpha_val * l_i[ri4] + beta_val * p_sum[ri4]),
|
||||
m_i[ri4].store(m_new_val),
|
||||
).end(ri4)
|
||||
acc = acc.after(correction)
|
||||
l_i = l_i.after(correction)
|
||||
m_i = m_i.after(correction)
|
||||
|
||||
# load V into KV_lds (must wait for QK WMMA to finish reading K from KV_lds)
|
||||
V_store = KV_lds.after(qk_done).reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid].store(
|
||||
v[n_tile].reshape(THREADS_PER_BLOCK, ELEMS_PER_THREAD)[tid])
|
||||
pv_barrier = UOp.barrier(UOp.group(P_store, V_store))
|
||||
P_lds = P_lds.after(pv_barrier)
|
||||
KV_lds_v = KV_lds.after(pv_barrier)
|
||||
|
||||
# -- acc += P @ V via WMMA --
|
||||
k_pv = UOp.range(BLOCK_N // WMMA_K, 400, AxisType.REDUCE)
|
||||
tm2 = UOp.range(TM // WMMA_ACC, 401, AxisType.LOOP)
|
||||
tn2 = UOp.range(TD, 402, AxisType.LOOP)
|
||||
acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TD).permute(0, 2, 1)[tm2, tn2]
|
||||
p_frag = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_N // WMMA_K, WMMA_K)[wave_m, tm2, lane_n, k_pv]
|
||||
v_frag = KV_lds_v.reshape(WAVES_N, TD, WMMA_N, BLOCK_N // WMMA_K, WMMA_K)[wave_n, tn2, lane_n, k_pv]
|
||||
pv = UOp(Ops.SHAPED_WMMA, dtypes.float, (p_frag, v_frag, acc_frag.after(k_pv)), arg=WMMA_ARG)
|
||||
|
||||
# end KV tile loop
|
||||
n_tile_end = acc_frag.store(pv).end(tm2, tn2).end(k_pv).barrier().end(n_tile)
|
||||
acc = acc.after(n_tile_end)
|
||||
l_i = l_i.after(n_tile_end)
|
||||
m_i = m_i.after(n_tile_end)
|
||||
|
||||
# normalize: acc /= l_i
|
||||
acc = acc.after(acc.store(acc * (1 / l_i).reshape(TM, 1).expand(TM, TD)))
|
||||
|
||||
# store output
|
||||
o = o.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TD, LANES_PER_WAVE_N)
|
||||
o = o.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TD)
|
||||
return o[tid].store(acc).end(wave_m, wave_n, lane).end(block_m, block_bh).sink(arg=KernelInfo(opts_to_apply=()))
|
||||
|
||||
if __name__ == "__main__":
|
||||
q = Tensor.rand(B, H, N, D).cast(dtypes.half)
|
||||
k = Tensor.rand(B, H, N, D).cast(dtypes.half)
|
||||
v = Tensor.rand(B, H, N, D).cast(dtypes.half)
|
||||
o = Tensor.empty(B, H, N, D, dtype=dtypes.float)
|
||||
with Context(DEBUG=0): Tensor.realize(q, k, v)
|
||||
|
||||
q_flat, k_flat, v_flat, o_flat = q.reshape(B*H, N, D), k.reshape(B*H, N, D), v.reshape(B*H, N, D), o.reshape(B*H, N, D)
|
||||
NUM_RUNS = getenv("CNT", 5)
|
||||
ets = []
|
||||
with Context(DEBUG=getenv("KDBG", 2)):
|
||||
for _ in range(NUM_RUNS):
|
||||
GlobalCounters.reset()
|
||||
tst = Tensor.custom_kernel(o_flat, q_flat, k_flat, v_flat, fxn=amd_flash_attention)[0].realize()
|
||||
ets.append(GlobalCounters.time_sum_s)
|
||||
print(f"best time: {min(ets)*1e3:.2f}ms")
|
||||
|
||||
if getenv("VERIFY", 1):
|
||||
with Context(DEBUG=0):
|
||||
ref = q.float().scaled_dot_product_attention(k.float(), v.float()).reshape(B*H, N, D).realize()
|
||||
err = (ref - tst).square().mean().item()
|
||||
print(f"mean squared error {err}")
|
||||
if err > 1e-2:
|
||||
raise RuntimeError("flash attention is wrong!")
|
||||
else:
|
||||
print("flash attention is correct!")
|
||||
|
|
@ -9,7 +9,7 @@ EXAMPLES = {
|
|||
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"plus":"test/test_tiny.py TestTiny.test_plus",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_wave_sync",
|
||||
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_lds_sync",
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,6 +1,7 @@
|
|||
A command line tool for exploring the VIZ trace.
|
||||
|
||||
After running with VIZ=-1, use `extra/viz/cli.py` to explore the saved trace files.
|
||||
1. Set VIZ to -1 to save the trace.
|
||||
2. Use `extra/viz/cli.py` to inspect the trace files.
|
||||
|
||||
## Inspect runtime profiling
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse, pathlib, sys, struct, json
|
||||
import argparse, pathlib, sys, struct, json, itertools
|
||||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, Context
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen
|
||||
|
||||
# ** generic helpers
|
||||
|
||||
|
|
@ -59,16 +59,18 @@ def decode_profile(data:bytes) -> dict:
|
|||
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
|
||||
|
||||
if __name__ == "__main__":
|
||||
Context(VIZ=0, TRACK_MATCH_STATS=0).__enter__()
|
||||
parser = argparse.ArgumentParser()
|
||||
g_mode = parser.add_argument_group("mode")
|
||||
g_mode.add_argument("--profile", action="store_true", help="View profile trace")
|
||||
g_mode.add_argument("--rewrites", action="store_true", help="View rewrites trace")
|
||||
g_common = parser.add_argument_group("common options")
|
||||
g_common.add_argument("--kernel", type=str, default=None, metavar="NAME", help="Select a kernel by name (optional name, default: only list names)")
|
||||
g_common.add_argument("--no-color", action="store_true", default=not (sys.stdin.isatty() and sys.stdout.isatty()),
|
||||
help="Disable colored output (default: true in non-interactive mode)")
|
||||
g_profile = parser.add_argument_group("profile options")
|
||||
g_profile.add_argument("--device", type=str, default=None, metavar="NAME", help="Select a device (optional name, default: only list names)")
|
||||
g_profile.add_argument("--top", type=int, default=10, metavar="N", help="Number of top kernels to show (-1 for all, default: 10)")
|
||||
g_profile.add_argument("--offset", type=int, default=0, metavar="N", help="event offset (default: 0)")
|
||||
g_profile.add_argument("--limit", type=int, default=10, metavar="N", help="events to display (-1 for all, default: 10)")
|
||||
g_rewrites = parser.add_argument_group("rewrites options")
|
||||
g_rewrites.add_argument("--select", type=str, default=None, metavar="NAME",
|
||||
help="Select an item within the chosen kernel (optional name, default: only list names)")
|
||||
|
|
@ -84,14 +86,64 @@ if __name__ == "__main__":
|
|||
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
|
||||
viz.ctxs = viz.get_rewrites(viz.trace)
|
||||
|
||||
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
|
||||
|
||||
if args.profile:
|
||||
from tabulate import tabulate
|
||||
profile = decode_profile(viz.get_profile(viz.load_pickle(args.profile_path, default=[])))
|
||||
profile = decode_profile(viz.get_profile(profile_data:=viz.load_pickle(args.profile_path, default=[])))
|
||||
viz.load_amd_counters(viz.ctxs, profile_data)
|
||||
counters = {f'{c["name"]} SQTT {s["name"]}': s["data"] for c in viz.ctxs if c["name"].startswith("Exec") for s in c["steps"]
|
||||
if s["name"].startswith("PKTS")}
|
||||
if args.device is None:
|
||||
print("Select a device:")
|
||||
for k in (*profile["layout"], *counters):
|
||||
print(f" {format_colored(k)}")
|
||||
sys.exit(0)
|
||||
|
||||
# ** SQTT printer
|
||||
if args.device is not None and (sqtt_data:=next((v for k,v in counters.items() if ansistrip(k) == args.device), None)) is not None:
|
||||
assert args.limit > 1, f"SQTT limit must be greater than 1, got {args.limit}"
|
||||
sqtt_events, has_more = viz.sqtt_timeline(*sqtt_data, max_pkts=args.offset+args.limit)
|
||||
sqtt_pkts = [e for e in sqtt_events if type(e).__name__ == "ProfileRangeEvent"]
|
||||
pc_map = next((e.arg for e in sqtt_events if type(e).__name__ == "ProfilePointEvent" and e.key == 'pcMap'), None)
|
||||
if pc_map is None:
|
||||
print(f"No SQTT instruction trace data for {args.device}")
|
||||
sys.exit(0)
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
|
||||
(('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'),
|
||||
(('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a'))
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Info'}")
|
||||
print("-" * 90)
|
||||
# start from the first packet in trace, prepare packet indexes and map dispatches
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_pc:dict[str, int] = {}
|
||||
for e in sqtt_pkts[:-args.limit]:
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if e.name.ret is not None and e.name.ret.startswith("PC:"): dispatch_to_pc[f"{e.device}-{idx}"] = int(e.name.ret.replace("PC:", ""))
|
||||
# start printing from the offset point
|
||||
for e in sqtt_pkts[-args.limit:]:
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
|
||||
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
|
||||
phase, pc = None, None
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if info.startswith("PC:"):
|
||||
dispatch_to_pc[f"{e.device}-{idx}"] = pc = int(info.replace("PC:", ""))
|
||||
phase = "DISPATCH"
|
||||
if info.startswith("LINK:"): phase, pc = "EXEC", dispatch_to_pc[info.replace("LINK:", "")]
|
||||
if pc and phase: info = f"{phase:<8} 0x{pc:05x} {pc_map[pc]}"
|
||||
print(f"{int(e.st):<12} {e.device:<20} {op_str}{' '*(22-ansilen(op_str))} {int(e.en-e.st):<4} {info}")
|
||||
# note: we only print the important packets and skip the rest
|
||||
if has_more: print(f"Selected packets {args.offset:,}-{args.offset + args.limit:,}. Use --offset and --limit to see others")
|
||||
sys.exit(0)
|
||||
|
||||
# ** Profiler printer
|
||||
agg, total, n = {}, 0, 0
|
||||
if args.device is None: print("Select a device:")
|
||||
for k,v in profile["layout"].items():
|
||||
if not optional_eq({"name":k}, args.device): continue
|
||||
print(f" {k}")
|
||||
print(f" {format_colored(k)}")
|
||||
if args.device is None: continue
|
||||
for e in v.get("events", []):
|
||||
et = e["dur"]*1e-6
|
||||
|
|
@ -99,7 +151,7 @@ if __name__ == "__main__":
|
|||
if optional_eq(e, args.kernel) and n < 10:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
name = e["name"]+(" " * (46 - ansilen(e["name"])))
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e['fmt'].replace('\n', ' | ')+" ")
|
||||
print(f"{name} {ptm}/{(et or 0)*1e3:9.2f}ms "+e.get('fmt', '').replace('\n', ' | ')+" ")
|
||||
n += 1
|
||||
else:
|
||||
a = agg.setdefault(e["name"], [0.0, 0])
|
||||
|
|
@ -108,14 +160,15 @@ if __name__ == "__main__":
|
|||
total += et
|
||||
if agg and total > 0:
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
sel = items if args.top == -1 else items[:args.top]
|
||||
sel = items if args.limit == -1 else items[args.offset:args.offset+args.limit]
|
||||
table = [[name, time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in sel]
|
||||
if args.top != -1 and (other:=items[len(sel):]):
|
||||
if args.limit != -1 and (other:=items[len(sel):]):
|
||||
other_t = total-sum(t for _, (t, _) in sel)
|
||||
table.append([f"Other ({len(other)} unique)", time_to_str(other_t, w=9), sum(c for _,(_,c) in other), f"{other_t/total*100.0:.2f}%"])
|
||||
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
|
||||
sys.exit(0)
|
||||
|
||||
# ** Graph rewrites printer
|
||||
for k in viz.ctxs:
|
||||
if not optional_eq(k, args.kernel): continue
|
||||
print(k["name"])
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ import functools
|
|||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from tinygrad.runtime.autogen.amd.rdna4.ins import s_barrier_wait, s_barrier_signal
|
||||
import tinygrad.runtime.autogen.amd.rdna3.ins as r3
|
||||
import tinygrad.runtime.autogen.amd.rdna4.ins as r4
|
||||
from tinygrad.renderer.amd.dsl import s, v
|
||||
from test.amd.helpers import TARGET_TO_ARCH
|
||||
|
||||
|
|
@ -53,12 +55,46 @@ def custom_wave_sync(A:UOp, arch:str) -> UOp:
|
|||
insts = []
|
||||
for _ in range(4):
|
||||
insts.append(s_sleep(4))
|
||||
insts += [s_barrier()] if arch == "rdna3" else [s_barrier_signal(), s_barrier_wait()]
|
||||
insts += [s_barrier()] if arch == "rdna3" else [r4.s_barrier_signal(), r4.s_barrier_wait()]
|
||||
insts += [s_nop(0)]*4
|
||||
insts.append(s_endpgm())
|
||||
sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
def custom_lds_sync(A:UOp, arch:str) -> UOp:
|
||||
A = A.flatten()
|
||||
num_threads = A.shape[0]
|
||||
threads = UOp.special(num_threads, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
isa = r4 if arch == "rdna4" else r3
|
||||
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt(lgkmcnt=0)]
|
||||
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt(lgkmcnt=0)]
|
||||
barrier = [isa.s_barrier_signal(ssrc0=-1), isa.s_barrier_wait(simm16=-1)] if arch == "rdna4" else [isa.s_barrier()]
|
||||
global_store = [isa.global_store_b32(vaddr=v[6:7], saddr=s[0:1], vsrc=v[5])] if arch == "rdna4" \
|
||||
else [isa.global_store_b32(addr=v[6], data=v[5], saddr=s[0:1])]
|
||||
insts = [
|
||||
isa.s_load_b64(s[0:1], s[0:1], soffset=NULL),
|
||||
*wait_kmcnt,
|
||||
isa.v_lshlrev_b32_e32(v[1], 2, v[0]),
|
||||
# lds[thread_idx] = thread_idx
|
||||
isa.ds_store_b32(addr=v[1], data0=v[0]),
|
||||
*wait_dscnt,
|
||||
*barrier,
|
||||
# out[threaed_idx] = thread_idx == num_threads ? -1 : lds[thread_idx + 1]
|
||||
isa.v_add_nc_u32_e32(v[2], 4, v[1]),
|
||||
isa.v_cmp_gt_u32_e32(num_threads-1, v[0]),
|
||||
isa.ds_load_b32(vdst=v[3], addr=v[2]),
|
||||
*wait_dscnt,
|
||||
isa.v_mov_b32_e32(v[4], -1),
|
||||
isa.v_cndmask_b32_e32(v[5], v[4], v[3]),
|
||||
isa.v_lshlrev_b32_e32(v[6], 2, v[0]),
|
||||
*global_store,
|
||||
isa.s_endpgm(),
|
||||
]
|
||||
sink = UOp.sink(A.base, lds, threads, wg, arg=KernelInfo("custom_lds_sync"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
|
||||
class TestCustomKernel(unittest.TestCase):
|
||||
def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch]
|
||||
|
|
@ -83,9 +119,14 @@ class TestCustomKernel(unittest.TestCase):
|
|||
ei.run({"var":i})
|
||||
self.assertTrue((a.numpy() == 1+i).all())
|
||||
|
||||
def test_wave_sync(self):
|
||||
if self.arch not in {"rdna3", "rdna4"}: self.skipTest("only rdna3 or rdna4")
|
||||
Tensor.empty(1).custom_kernel(fxn=functools.partial(custom_wave_sync, arch=self.arch))[0].realize()
|
||||
def test_lds_sync(self):
|
||||
if self.arch not in ("rdna3", "rdna4"): self.skipTest("only rdna3/rdna4")
|
||||
a = Tensor.empty(128, dtype=dtypes.int32).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_lds_sync, arch=self.arch))[0]
|
||||
a.realize()
|
||||
ref = Tensor.arange(1, 129, dtype=dtypes.int32)
|
||||
ref[127] = -1
|
||||
self.assertListEqual(a.tolist(), ref.tolist())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -10,18 +10,11 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
|
|||
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
|
||||
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
|
||||
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
|
||||
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_INST)
|
||||
print_packets, CDNA_WAVEEND, CDNA_INST)
|
||||
from test.amd.helpers import TARGET_TO_ARCH
|
||||
|
||||
import tinygrad
|
||||
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
|
||||
# INST ops for non-traced SIMDs (excluded from instruction count)
|
||||
OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LDS_STORE_64, InstOp.OTHER_LDS_STORE_128,
|
||||
InstOp.OTHER_FLAT_LOAD, InstOp.OTHER_FLAT_STORE, InstOp.OTHER_FLAT_STORE_64, InstOp.OTHER_FLAT_STORE_96,
|
||||
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
|
||||
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
|
||||
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5, InstOpRDNA4.OTHER_LDS_1, InstOpRDNA4.OTHER_LDS_2}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
|
|
@ -183,11 +176,18 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
|||
our_waves: list[tuple[int, int]] = []
|
||||
for event in events:
|
||||
wave_starts: dict[tuple[int, int, int], int] = {}
|
||||
first_timestamp:int|None = None
|
||||
for p in decode(event.blob):
|
||||
if first_timestamp is None: first_timestamp = p._time
|
||||
if isinstance(p, (WAVESTART, CDNA_WAVESTART, WAVESTART_RDNA4)): wave_starts[(p.wave, p.simd, p.cu)] = p._time
|
||||
elif isinstance(p, (WAVEEND, CDNA_WAVEEND)) and (key := (p.wave, p.simd, p.cu)) in wave_starts:
|
||||
our_waves.append((wave_starts[key], p._time))
|
||||
self.assertEqual(sorted(our_waves), sorted(roc_waves), f"wave times mismatch in {name}")
|
||||
for st in wave_starts.values():
|
||||
self.assertGreater(st, first_timestamp, "wave start must be after the first packet")
|
||||
# rocprof fails non deterministically and gives inaccurate timestamps.
|
||||
#self.assertEqual(sorted(our_waves), sorted(roc_waves), f"wave times mismatch in {name}")
|
||||
for st, et in our_waves:
|
||||
self.assertGreater(et, st, "wave end must be after start")
|
||||
|
||||
def test_rocprof_inst_times_match(self):
|
||||
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
|
||||
|
|
@ -200,8 +200,8 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
|||
our_insts: list[int] = []
|
||||
for event in events:
|
||||
for p in decode(event.blob):
|
||||
if isinstance(p, INST) and p.op not in OTHER_SIMD_OPS: our_insts.append(p._time)
|
||||
elif isinstance(p, INST_RDNA4) and p.op not in OTHER_SIMD_OPS_RDNA4: our_insts.append(p._time)
|
||||
# INST ops for non-traced SIMDs (excluded from instruction count)
|
||||
if isinstance(p, (INST, INST_RDNA4)) and not p.op.name.startswith("OTHER_"): our_insts.append(p._time)
|
||||
elif isinstance(p, VALUINST): our_insts.append(p._time)
|
||||
elif isinstance(p, IMMEDIATE): our_insts.append(p._time)
|
||||
elif isinstance(p, IMMEDIATE_MASK):
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
if (p:=kern_events.get(event.kern)) is None: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
# skip if there's no SQTT frequency data
|
||||
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
|
||||
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)[0]): continue
|
||||
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
|
||||
mean = sum(frequency) / len(frequency)
|
||||
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
|
||||
|
|
@ -93,7 +93,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
elif "WAVE" in e.device:
|
||||
# sopk/immediates don't get ALU/MEM EXEC
|
||||
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL",
|
||||
"WAVEEND"}: insts += 1
|
||||
"WAVEEND", "WAVERDY"}: insts += 1
|
||||
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
|
||||
self.assertEqual(execs, insts)
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
for name, (events, kern_events, target) in self.examples.items():
|
||||
for event in events:
|
||||
wave_barriers = {}
|
||||
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
|
||||
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target)[0]:
|
||||
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
|
||||
if not wave_barriers: continue
|
||||
for row, events in wave_barriers.items():
|
||||
|
|
|
|||
|
|
@ -39,6 +39,18 @@ class TestJit(unittest.TestCase):
|
|||
def add(a, b): return (a+b).realize()
|
||||
_simple_test(add)
|
||||
|
||||
def test_jitbeam_triggers_beam(self):
|
||||
from unittest.mock import patch
|
||||
from tinygrad.helpers import getenv as _getenv
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize()
|
||||
with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam:
|
||||
add(a, b)
|
||||
assert mock_beam.call_count == 0
|
||||
with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b)
|
||||
assert mock_beam.call_count == 1
|
||||
|
||||
def test_simple_jit_reset(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
|
|
@ -648,25 +660,6 @@ class TestJitFree(unittest.TestCase):
|
|||
fxn(Tensor([2]))
|
||||
self.assertEqual(x.item(), 8)
|
||||
|
||||
def test_replan_buffers_memory_layout(self):
|
||||
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
|
||||
|
||||
ext_tensor = Tensor([1,24,23,45,1]).contiguous()
|
||||
ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous()
|
||||
@TinyJit
|
||||
def fxn(x:Tensor):
|
||||
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
|
||||
return out.sum()
|
||||
for i in range(5):
|
||||
out = fxn(Tensor([i,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 11400+200*i)
|
||||
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4)
|
||||
fxn.captured.replan_buffers_memory_layout()
|
||||
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2)
|
||||
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 13600)
|
||||
|
||||
class TestJitGraphSplit(unittest.TestCase):
|
||||
def compute(self, device, inp):
|
||||
assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
|
||||
|
|
|
|||
|
|
@ -858,6 +858,8 @@ def _compile_sop(inst: ir3.SOP1|ir3.SOP2|ir3.SOPC|ir3.SOPK|ir4.SOP1|ir4.SOP2|ir4
|
|||
result = (hw_val >> offset) & mask
|
||||
return UOp.sink(ctx.wsgpr_dyn(sdst_off, result), *ctx.inc_pc())
|
||||
elif isinstance(inst, (ir3.SOP1, ir4.SOP1, irc.SOP1)):
|
||||
# S_BARRIER_SIGNAL: no-op in emulator, barrier sync handled by execution loop
|
||||
if isinstance(inst, ir4.SOP1) and inst.op in _BARRIER_SOP1_OPS: return UOp.sink(*ctx.inc_pc())
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
|
||||
|
|
@ -1960,6 +1962,8 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
|
||||
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
|
||||
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
|
||||
_BARRIER_SOP1_OPS: set = set()
|
||||
if hasattr(ir4.SOP1Op, 'S_BARRIER_SIGNAL'): _BARRIER_SOP1_OPS.add(ir4.SOP1Op.S_BARRIER_SIGNAL)
|
||||
_BRANCH_OPS: set[int] = {op.value for op in (ir3.SOPPOp.S_BRANCH, ir3.SOPPOp.S_CBRANCH_SCC0, ir3.SOPPOp.S_CBRANCH_SCC1,
|
||||
ir3.SOPPOp.S_CBRANCH_VCCZ, ir3.SOPPOp.S_CBRANCH_VCCNZ, ir3.SOPPOp.S_CBRANCH_EXECZ, ir3.SOPPOp.S_CBRANCH_EXECNZ)}
|
||||
|
||||
|
|
@ -2084,7 +2088,8 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner, inst = _decode_at(pc, arch)
|
||||
is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS
|
||||
is_barrier = (isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS) or \
|
||||
(isinstance(inst, (ir4.SOP1,)) and inst.op in _BARRIER_SOP1_OPS)
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst)
|
||||
if DEBUG >= 3:
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
|
|
|
|||
|
|
@ -1,42 +1,73 @@
|
|||
import unittest
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.engine.memory import memory_plan_rewrite
|
||||
|
||||
global_map = {}
|
||||
held_bufs: set[UOp] = set()
|
||||
def b(i, base=None, offset=0, pin=False, size=16):
|
||||
global global_map
|
||||
if i in global_map: return global_map[i]
|
||||
global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
|
||||
if pin: global_map[i].ref(1)
|
||||
if base is not None:
|
||||
global_map[i] = global_map[base]
|
||||
return global_map[i]
|
||||
global_map[i] = UOp.new_buffer("NULL", size, dtypes.int8)
|
||||
if pin: held_bufs.add(global_map[i])
|
||||
return global_map[i]
|
||||
|
||||
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]], copies:list[tuple[Buffer, Buffer]]|None=None):
|
||||
assigned = _internal_memory_planner(buffers, copies=copies)
|
||||
def _make_linear(buffer_lists, copies=None):
|
||||
copy_pairs = {frozenset((id(dst), id(src))) for dst, src in copies} if copies else set()
|
||||
calls = []
|
||||
for bufs in buffer_lists:
|
||||
is_copy = len(bufs) == 2 and frozenset((id(bufs[0]), id(bufs[1]))) in copy_pairs
|
||||
calls.append(UOp(Ops.CALL, dtypes.void, (UOp(Ops.COPY if is_copy else Ops.SINK), *bufs)))
|
||||
return UOp(Ops.LINEAR, src=tuple(calls))
|
||||
|
||||
taken_parts = set()
|
||||
def _get_arena(buf, linear, result):
|
||||
for orig_si, new_si in zip(linear.src, result.src):
|
||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||
if orig is buf and new.op is Ops.BUFFER_VIEW: return new.src[0]
|
||||
return None
|
||||
|
||||
def check_assign(buffer_lists, copies=None):
|
||||
linear = _make_linear(buffer_lists, copies)
|
||||
result = memory_plan_rewrite(linear, held_bufs)
|
||||
|
||||
# build mapping: original buf -> (arena, offset_bytes, nbytes) from the result
|
||||
replace_map: dict[int, tuple[UOp, int, int]] = {}
|
||||
for orig_si, new_si in zip(linear.src, result.src):
|
||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
|
||||
replace_map[id(orig)] = (new.src[0], new.arg[1] * new.dtype.itemsize, new.arg[0] * new.dtype.itemsize)
|
||||
|
||||
# verify pinned buffers are not planned
|
||||
for buf in held_bufs:
|
||||
assert id(buf) not in replace_map, "pinned buffer was planned"
|
||||
|
||||
# compute lifetimes
|
||||
first_appearance, last_appearance = {}, {}
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue
|
||||
if buf.base not in first_appearance: first_appearance[buf.base] = i
|
||||
last_appearance[buf.base] = i
|
||||
for i, bufs in enumerate(buffer_lists):
|
||||
for buf in bufs:
|
||||
if buf in held_bufs: continue
|
||||
if id(buf) not in first_appearance: first_appearance[id(buf)] = i
|
||||
last_appearance[id(buf)] = i
|
||||
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
if buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0: continue
|
||||
cur, base = assigned.get(buf, buf), assigned.get(buf.base, buf.base)
|
||||
if buf._base is not None:
|
||||
assert cur.base == base.base and cur.offset == buf.offset + base.offset, f"failed: {buf} {cur} {base} {buf.offset} {base.offset}"
|
||||
else:
|
||||
for part in taken_parts:
|
||||
assert buf.base == part[3] or part[0] != cur.base or part[1] + part[2] <= cur.offset or part[1] >= cur.offset + buf.nbytes
|
||||
if first_appearance[buf.base] == i: taken_parts.add((cur.base, cur.offset, buf.nbytes, buf.base))
|
||||
if last_appearance[buf.base] == i: taken_parts.remove((cur.base, cur.offset, buf.nbytes, buf.base))
|
||||
# verify non-overlapping: no two live buffers share the same arena region
|
||||
taken_parts: set[tuple[int, int, int, int]] = set() # (id(arena), offset, nbytes, id(buf))
|
||||
for i, bufs in enumerate(buffer_lists):
|
||||
for buf in bufs:
|
||||
if buf in held_bufs or id(buf) not in replace_map: continue
|
||||
arena, off, nb = replace_map[id(buf)]
|
||||
for part in taken_parts:
|
||||
assert id(buf) == part[3] or part[0] != id(arena) or part[1] + part[2] <= off or part[1] >= off + nb, \
|
||||
f"overlap at step {i}: [{off}, {off+nb}) conflicts with [{part[1]}, {part[1]+part[2]})"
|
||||
if first_appearance.get(id(buf)) == i: taken_parts.add((id(arena), off, nb, id(buf)))
|
||||
if last_appearance.get(id(buf)) == i: taken_parts.discard((id(arena), off, nb, id(buf)))
|
||||
|
||||
class TestMemoryPlanner(unittest.TestCase):
|
||||
def setUp(self):
|
||||
global global_map
|
||||
held_bufs.clear()
|
||||
global_map = {}
|
||||
|
||||
def test_simple_buffer(self):
|
||||
|
|
@ -140,9 +171,11 @@ class TestMemoryPlanner(unittest.TestCase):
|
|||
[b(1), b(2)],
|
||||
[b(3), b(2)],
|
||||
]
|
||||
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))])
|
||||
r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2))
|
||||
assert r1.base != r2.base
|
||||
linear = _make_linear(bs, copies=[(b(1), b(0))])
|
||||
result = memory_plan_rewrite(linear)
|
||||
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
|
||||
assert r1_arena is not None and r2_arena is not None
|
||||
assert r1_arena is not r2_arena
|
||||
|
||||
def test_copy_bufs_reuse_among_copies(self):
|
||||
bs = [
|
||||
|
|
@ -150,9 +183,11 @@ class TestMemoryPlanner(unittest.TestCase):
|
|||
[b(2), b(1)],
|
||||
[b(3), b(2)],
|
||||
]
|
||||
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0)), (b(2), b(1))])
|
||||
r1, r2 = assigned.get(b(1), b(1)), assigned.get(b(2), b(2))
|
||||
assert r1.base == r2.base
|
||||
linear = _make_linear(bs, copies=[(b(1), b(0)), (b(2), b(1))])
|
||||
result = memory_plan_rewrite(linear)
|
||||
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
|
||||
assert r1_arena is not None and r2_arena is not None
|
||||
assert r1_arena is r2_arena
|
||||
|
||||
def test_compute_bufs_reuse_among_compute(self):
|
||||
bs = [
|
||||
|
|
@ -161,9 +196,11 @@ class TestMemoryPlanner(unittest.TestCase):
|
|||
[b(3), b(2)],
|
||||
[b(4), b(3)],
|
||||
]
|
||||
assigned = _internal_memory_planner(bs, copies=[(b(1), b(0))])
|
||||
r2, r3 = assigned.get(b(2), b(2)), assigned.get(b(3), b(3))
|
||||
assert r2.base == r3.base
|
||||
linear = _make_linear(bs, copies=[(b(1), b(0))])
|
||||
result = memory_plan_rewrite(linear)
|
||||
r2_arena, r3_arena = _get_arena(b(2), linear, result), _get_arena(b(3), linear, result)
|
||||
assert r2_arena is not None and r3_arena is not None
|
||||
assert r2_arena is r3_arena
|
||||
|
||||
def test_copy_and_compute_no_cross_reuse(self):
|
||||
bs = [
|
||||
|
|
@ -171,9 +208,11 @@ class TestMemoryPlanner(unittest.TestCase):
|
|||
[b(2), b(1)],
|
||||
[b(3), b(2)],
|
||||
]
|
||||
assigned = _internal_memory_planner(bs, copies=[(b(2), b(1))])
|
||||
r0, r2 = assigned.get(b(0), b(0)), assigned.get(b(2), b(2))
|
||||
assert r0.base != r2.base
|
||||
linear = _make_linear(bs, copies=[(b(2), b(1))])
|
||||
result = memory_plan_rewrite(linear)
|
||||
r0_arena, r2_arena = _get_arena(b(0), linear, result), _get_arena(b(2), linear, result)
|
||||
assert r0_arena is not None and r2_arena is not None
|
||||
assert r0_arena is not r2_arena
|
||||
|
||||
def test_multiple_copy_bufs_with_offsets(self):
|
||||
bs = [
|
||||
|
|
|
|||
|
|
@ -74,7 +74,10 @@ class TestRealWorld(unittest.TestCase):
|
|||
def test(t, t2):
|
||||
for l in model: t = l(t, t2)
|
||||
return t.realize()
|
||||
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.0002, 37)
|
||||
|
||||
# TODO: support _offset on CL to get mem down to 0.0002
|
||||
exp_mem = 0.00037 if Device.DEFAULT == "CL" else 0.0002
|
||||
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, exp_mem, 37)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_llama(self):
|
||||
|
|
|
|||
|
|
@ -816,5 +816,58 @@ class TestUOpTags(unittest.TestCase):
|
|||
g = graph_rewrite(g, pm_plus_1)
|
||||
assert g.ssimplify() == 6
|
||||
|
||||
class TestUOpGetItem(unittest.TestCase):
|
||||
def _placeholder(self, shape, dtype=dtypes.half):
|
||||
return UOp.placeholder(shape, dtype, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
# full slices (no shrink)
|
||||
def test_full_slice(self):
|
||||
p = self._placeholder((64, 64))
|
||||
self.assertEqual(p[:, :].shape, (64, 64))
|
||||
def test_full_slice_explicit(self):
|
||||
p = self._placeholder((64, 64))
|
||||
self.assertEqual(p[0:64, 0:64].shape, (64, 64))
|
||||
|
||||
# partial slices (shrink)
|
||||
def test_shrink_cols(self):
|
||||
p = self._placeholder((64, 80))
|
||||
self.assertEqual(p[:, :64].shape, (64, 64))
|
||||
def test_shrink_rows(self):
|
||||
p = self._placeholder((80, 64))
|
||||
self.assertEqual(p[:64, :].shape, (64, 64))
|
||||
def test_shrink_both(self):
|
||||
p = self._placeholder((80, 80))
|
||||
self.assertEqual(p[:64, :64].shape, (64, 64))
|
||||
def test_shrink_start(self):
|
||||
p = self._placeholder((64, 64))
|
||||
self.assertEqual(p[8:, :].shape, (56, 64))
|
||||
def test_shrink_start_and_end(self):
|
||||
p = self._placeholder((64, 64))
|
||||
self.assertEqual(p[8:56, 4:60].shape, (48, 56))
|
||||
|
||||
# mixed slice and index
|
||||
def test_index_and_slice(self):
|
||||
p = self._placeholder((64, 80))
|
||||
r = UOp.range(64, 100)
|
||||
result = p[r, :64]
|
||||
self.assertEqual(result.shape, (64,))
|
||||
def test_slice_and_index(self):
|
||||
p = self._placeholder((80, 64))
|
||||
r = UOp.range(64, 100)
|
||||
result = p[:64, r]
|
||||
self.assertEqual(result.shape, (64,))
|
||||
def test_shrink_then_index(self):
|
||||
p = self._placeholder((64, 80))
|
||||
s = p[:, :64]
|
||||
r = UOp.range(64, 100)
|
||||
result = s[r]
|
||||
self.assertEqual(result.shape, (64,))
|
||||
|
||||
# integer index (no slice)
|
||||
def test_int_index(self):
|
||||
p = self._placeholder((64, 64))
|
||||
result = p[0]
|
||||
self.assertEqual(result.shape, (64,))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
|
|
@ -286,12 +286,12 @@ pm_render = PatternMatcher([
|
|||
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
# Where after gated load becomes alt value
|
||||
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c").logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
|
|||
|
|
@ -1,16 +1,26 @@
|
|||
from typing import TypeVar, Generic, Callable, cast, Any
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
|
||||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs
|
||||
from tinygrad.engine.schedule import linear_to_schedule
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.schedule.rangeify import mop_cleanup
|
||||
from dataclasses import dataclass, replace
|
||||
from dataclasses import dataclass
|
||||
|
||||
def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
|
||||
kept, onetime = [], []
|
||||
for si in linear.src:
|
||||
si_bufs = {b for src in si.src[1:] for b in _collect_bufs(src)}
|
||||
if not si_bufs.isdisjoint(needed):
|
||||
kept.append(si)
|
||||
needed |= si_bufs
|
||||
else: onetime.append(si)
|
||||
return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
|
||||
|
||||
class GraphException(Exception): pass
|
||||
class JitError(Exception): pass
|
||||
|
|
@ -180,7 +190,7 @@ class CapturedJit(Generic[ReturnType]):
|
|||
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
|
||||
|
||||
def __reduce__(self):
|
||||
# TODO: free_intermediates here? replan_buffers_memory_layout here?
|
||||
# TODO: free_intermediates here?
|
||||
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
@ -202,18 +212,12 @@ class CapturedJit(Generic[ReturnType]):
|
|||
def free_intermediates(self):
|
||||
depends: set[Buffer|None] = set([None])
|
||||
update_depends(depends, self.jit_cache)
|
||||
for b in depends:
|
||||
if b is not None:
|
||||
if b.is_allocated(): b.deallocate()
|
||||
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
|
||||
self.__post_init__() # reset the graph state
|
||||
|
||||
def replan_buffers_memory_layout(self):
|
||||
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
|
||||
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
|
||||
self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
|
||||
for old, new in asgn.items():
|
||||
if old.is_allocated(): new.ensure_allocated().copyin(old.as_memoryview())
|
||||
arenas = {b._base for b in depends if b is not None and b._base is not None}
|
||||
to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas}
|
||||
for b in to_free:
|
||||
if hasattr(b, '_buf'): b.deallocate()
|
||||
for a in arenas:
|
||||
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
|
||||
self.__post_init__()
|
||||
|
||||
# jit exec
|
||||
|
|
@ -272,13 +276,12 @@ def _prepare_jit_inputs(args, kwargs):
|
|||
return input_buffers, var_vals, names, expected_input_info
|
||||
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False):
|
||||
assert fxn or captured, "need either a function or a CapturedJit"
|
||||
self.fxn = fxn
|
||||
self.captured: CapturedJit|None = captured
|
||||
self.cnt: int = 2 if self.fxn is None else 0
|
||||
self.prune = prune
|
||||
self.optimize = optimize
|
||||
|
||||
def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear)
|
||||
|
||||
|
|
@ -312,20 +315,32 @@ class TinyJit(Generic[ReturnType]):
|
|||
assert self.fxn is not None
|
||||
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
|
||||
self._linears: list[UOp] = []
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
if not len(self._linears): raise JitError("didn't JIT anything!")
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
|
||||
|
||||
# combine all captured linears into one and convert to ExecItems
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))))]
|
||||
# combine all captured linears into one, memory plan, and convert to ExecItems
|
||||
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
|
||||
del self._linears
|
||||
|
||||
if self.prune:
|
||||
big_linear, onetime_linear = prune_linear(big_linear, {k for k,v in buffers.items() if isinstance(v, Buffer) and v in set(input_buffers)})
|
||||
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
|
||||
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ei.run(var_vals, jit=True)
|
||||
del onetime_linear
|
||||
|
||||
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(big_linear, held_bufs))]
|
||||
del big_linear
|
||||
|
||||
# track inputs that are views of buffers
|
||||
# TODO: eventually expected_buffers should live in ExecItem
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
|
||||
|
|
@ -335,25 +350,6 @@ class TinyJit(Generic[ReturnType]):
|
|||
input_buffers.append(b)
|
||||
extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
|
||||
|
||||
# prune independent kernels (optional)
|
||||
if self.prune:
|
||||
depends = set(input_buffers)
|
||||
update_depends(depends, jit_cache)
|
||||
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
|
||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||
# sync before re-executing onetime kernels
|
||||
for dev in set(Device[b.device] for ei in onetime for b in ei.bufs if b is not None): dev.synchronize()
|
||||
# run the onetime kernels here
|
||||
for ei in onetime:
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ei.run(var_vals, jit=True)
|
||||
jit_cache = pruned
|
||||
|
||||
# memory planning (optional)
|
||||
copies = [(cast(Buffer,ji.bufs[0]),cast(Buffer,ji.bufs[1])) for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec))]
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], copies, debug_prefix="JIT ")
|
||||
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
|
||||
|
||||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
|
|
@ -361,7 +357,6 @@ class TinyJit(Generic[ReturnType]):
|
|||
for ei in jit_cache: ei.run(var_vals)
|
||||
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
|
||||
if self.optimize: self.captured.replan_buffers_memory_layout()
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
assert self.captured is not None
|
||||
|
|
|
|||
|
|
@ -1,77 +1,65 @@
|
|||
from typing import cast
|
||||
from collections import defaultdict
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import NO_MEMORY_PLANNER, DEBUG, round_up
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.support.memory import TLSFAllocator
|
||||
|
||||
def _collect_bufs(u:UOp) -> list[UOp]:
|
||||
if u.op is Ops.BUFFER: return [u]
|
||||
if u.op in {Ops.MSELECT, Ops.MSTACK}: return [b for s in u.src for b in _collect_bufs(s)]
|
||||
return []
|
||||
|
||||
def _can_plan(b:UOp, held_bufs:set[UOp]) -> bool:
|
||||
if b in held_bufs: return False
|
||||
devs = (b.device,) if isinstance(b.device, str) else b.device
|
||||
return all(not d.startswith(("DISK", "TINYFS")) and hasattr(Device[d].allocator, "_offset") for d in devs)
|
||||
|
||||
LaneKey = tuple[str, int]
|
||||
|
||||
# **************** memory planning ****************
|
||||
def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp:
|
||||
if NO_MEMORY_PLANNER: return linear
|
||||
if held_bufs is None: held_bufs = set()
|
||||
|
||||
def _internal_memory_planner(buffers:list[list[Buffer]], copies:list[tuple[Buffer, Buffer]]|None=None,
|
||||
ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
|
||||
if NO_MEMORY_PLANNER: return {}
|
||||
first_appearance, last_appearance, buf_to_opt = {}, {}, set()
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
if not ignore_checks and (buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0): continue
|
||||
if buf.base not in first_appearance: first_appearance[buf.base] = i
|
||||
last_appearance[buf.base] = i
|
||||
buf_to_opt.add(buf)
|
||||
# compute lifetimes for all plannable internal buffers
|
||||
first_appearance:dict[UOp, int] = {}
|
||||
last_appearance:dict[UOp, int] = {}
|
||||
copy_bufs: set[UOp] = set()
|
||||
for i, si in enumerate(linear.src):
|
||||
si_bufs = [b for src in si.src[1:] for b in _collect_bufs(src) if _can_plan(b, held_bufs)]
|
||||
for b in si_bufs:
|
||||
if b not in first_appearance: first_appearance[b] = i
|
||||
last_appearance[b] = i
|
||||
if si.src[0].op is Ops.COPY: copy_bufs.update(si_bufs)
|
||||
if not first_appearance: return linear
|
||||
|
||||
# Separate copy and compute buffers into different lanes and defer cross-queue frees to avoid introducing dependencies (copy->compute->copy)
|
||||
copy_dsts, copy_srcs = ({dst.base for dst,_ in copies}, {src.base for _,src in copies}) if copies else (set(), set())
|
||||
def _key(buf) -> LaneKey: return (buf.device, 1 if buf in copy_dsts or buf in copy_srcs else 0)
|
||||
buf_hold = {buf: last_appearance[buf] - first_appearance[buf] + 1 for buf in first_appearance if buf in copy_dsts or buf in copy_srcs}
|
||||
# separate copy and compute buffers into different lanes to avoid introducing dependencies (copy->compute->copy)
|
||||
def _key(b:UOp): return (b.device, 1 if b in copy_bufs else 0)
|
||||
buf_hold = {b: last_appearance[b] - first_appearance[b] + 1 for b in first_appearance if b in copy_bufs}
|
||||
|
||||
# Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
|
||||
buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
|
||||
[((last_appearance[buf] + 1 + buf_hold.get(buf, 0), False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
|
||||
total_memory = sum(round_up(buf.nbytes, BLK:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)
|
||||
# suballocation: build sorted open/close events, then alloc/free in order
|
||||
block_size = 256
|
||||
nbytes = {b: round_up(b.arg * b.dtype.itemsize, block_size) for b in first_appearance}
|
||||
events = sorted([(first_appearance[b], True, b) for b in first_appearance] +
|
||||
[(last_appearance[b] + 1 + buf_hold.get(b, 0), False, b) for b in first_appearance], key=lambda x: (x[0], x[1]))
|
||||
total_memory = sum(nbytes.values()) * 2
|
||||
|
||||
# Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
|
||||
# Also track buffer replacements for buffers that do not support suballocation.
|
||||
buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
|
||||
reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
|
||||
global_planner:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=BLK, lv2_cnt=32)))
|
||||
for (_, is_open_ev), buf in buffer_requests:
|
||||
# Check if suballocation is possible for the given buffer and device.
|
||||
if hasattr(Device[buf.device].allocator, "_offset"):
|
||||
if is_open_ev: buffer_replace[buf] = (None, global_planner[_key(buf)][1].alloc(round_up(buf.nbytes, BLK)))
|
||||
else: global_planner[_key(buf)][1].free(cast(int, buffer_replace[buf][1]))
|
||||
global_planner[_key(buf)] = (max(global_planner[_key(buf)][0], buffer_replace[buf][1] + buf.nbytes), global_planner[_key(buf)][1])
|
||||
else:
|
||||
key = (_key(buf), buf.dtype, buf.options, buf.nbytes)
|
||||
if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
|
||||
else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
|
||||
offsets:dict[UOp, int] = {}
|
||||
peaks:dict[LaneKey, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=block_size, lv2_cnt=32)))
|
||||
for _, is_open, buf in events:
|
||||
if is_open: offsets[buf] = peaks[_key(buf)][1].alloc(nbytes[buf])
|
||||
else: peaks[_key(buf)][1].free(offsets[buf])
|
||||
peaks[_key(buf)] = (max(peaks[_key(buf)][0], offsets[buf] + buf.arg * buf.dtype.itemsize), peaks[_key(buf)][1])
|
||||
arena_sizes = {key: round_up(peak, block_size) for key, (peak, _) in peaks.items()}
|
||||
|
||||
# Allocate global buffers based on the memory planner.
|
||||
global_buffers = {key: Buffer(key[0], round_up(sz, BLK), dtypes.int8) for key, (sz, _) in global_planner.items()}
|
||||
buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[_key(buf)], off) for buf,(base,off) in buffer_replace.items()}
|
||||
# build replace_map: each buffer becomes a BUFFER_VIEW into a shared per-device-lane arena
|
||||
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
|
||||
replace_map:dict[UOp, UOp] = {}
|
||||
for buf_uop, offset in offsets.items():
|
||||
assert offset % buf_uop.dtype.itemsize == 0, f"offset {offset} not aligned to {buf_uop.dtype.itemsize}"
|
||||
replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)],), (buf_uop.arg, offset // buf_uop.dtype.itemsize))
|
||||
|
||||
# Assign buffers. First, assign full buffers (not sub-buffers).
|
||||
assigned:dict[Buffer, Buffer] = {}
|
||||
for buf, (base, off) in buffer_resolve.items():
|
||||
if buf != base:
|
||||
assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)
|
||||
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
|
||||
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
|
||||
|
||||
# Now assign sub-buffers.
|
||||
for buf in buf_to_opt:
|
||||
if buf._base is not None:
|
||||
assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)
|
||||
|
||||
if DEBUG >= 1:
|
||||
ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
|
||||
omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
|
||||
if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")
|
||||
|
||||
return assigned
|
||||
|
||||
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
|
||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
|
||||
copies=[(cast(Buffer,si.bufs[0]),cast(Buffer,si.bufs[1])) for si in schedule if si.ast.op is Ops.COPY])
|
||||
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
|
||||
return linear.substitute(replace_map, name="memory plan", walk=True)
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
|||
si_lowerer = PatternMatcher([
|
||||
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
|||
schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
|
||||
return schedule
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.memory import memory_plan_rewrite
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.schedule.rangeify import get_kernel_graph
|
||||
from tinygrad.helpers import CAPTURING
|
||||
|
|
@ -163,7 +163,9 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di
|
|||
capturing[0].add_linear(linear, var_vals)
|
||||
return [], var_vals
|
||||
|
||||
held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set())
|
||||
linear = memory_plan_rewrite(linear, held_bufs)
|
||||
|
||||
# convert LINEAR to ExecItems
|
||||
schedule: list[ExecItem] = linear_to_schedule(linear)
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
return schedule, var_vals
|
||||
|
|
|
|||
|
|
@ -677,7 +677,7 @@ class ElementwiseMixin(DTypeMixin):
|
|||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
||||
```
|
||||
"""
|
||||
return self.ne(0).where((self < 0).where(self.const_like(-1), self.const_like(1)), self.const_like(0)) + self * 0
|
||||
return self.ne(0).where((self < 0).where(self.const_like(-1), self.const_like(1)), self.const_like(0))
|
||||
|
||||
def abs(self) -> Self:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -105,14 +105,25 @@ class InstOpRDNA4(Enum):
|
|||
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
|
||||
SALU = 0x0
|
||||
SMEM = 0x1
|
||||
SMEM_WR = 0x2
|
||||
JUMP = 0x3
|
||||
JUMP_NO = 0x4
|
||||
CALL = 0x5
|
||||
SALU_NO_EXEC = 0x7
|
||||
MESSAGE = 0x9
|
||||
VALU_1 = 0xa
|
||||
VALU_TRANS = 0xb
|
||||
VALU_B1 = 0xc
|
||||
VALU_B2 = 0xd
|
||||
VALU_B4 = 0xe
|
||||
VALU_B16 = 0xf
|
||||
VINTERP = 0x12
|
||||
BARRIER_WAIT = 0x13
|
||||
FLAT_RD_2 = 0x1c
|
||||
FLAT_WR_3 = 0x1d
|
||||
FLAT_WR_4 = 0x1e
|
||||
FLAT_WR_5 = 0x1f
|
||||
FLAT_WR_6 = 0x20
|
||||
VMEM_RD_1 = 0x21
|
||||
VMEM_RD_2 = 0x22
|
||||
VMEM_WR_1 = 0x23
|
||||
|
|
@ -127,18 +138,45 @@ class InstOpRDNA4(Enum):
|
|||
LDS_WR_3 = 0x2c
|
||||
LDS_WR_4 = 0x2d
|
||||
LDS_WR_5 = 0x2e
|
||||
BUF_RD_1 = 0x2f
|
||||
BUF_RD_2 = 0x30
|
||||
BUF_WR_1 = 0x31
|
||||
BUF_WR_2 = 0x32
|
||||
BUF_WR_3 = 0x33
|
||||
BUF_WR_4 = 0x34
|
||||
BUF_WR_5 = 0x35
|
||||
BUF_WR_6 = 0x36
|
||||
OTHER_LDS_1 = 0x50
|
||||
OTHER_LDS_2 = 0x51
|
||||
OTHER_LDS_3 = 0x52
|
||||
OTHER_LDS_4 = 0x53
|
||||
OTHER_LDS_5 = 0x54
|
||||
OTHER_FLAT_2 = 0x55
|
||||
OTHER_FLAT_3 = 0x56
|
||||
OTHER_FLAT_4 = 0x57
|
||||
OTHER_FLAT_5 = 0x58
|
||||
OTHER_FLAT_6 = 0x59
|
||||
LDS_DIR_LOAD = 0x6e
|
||||
LDS_PARAM_LOAD = 0x6f
|
||||
SALU_WR_EXEC = 0x72
|
||||
VALU1_WR_EXEC = 0x73
|
||||
VALU_B2_WR_EXEC = 0x74
|
||||
OTHER_LDS_6 = 0x77
|
||||
OTHER_LDS_10 = 0x78
|
||||
BARRIER_SIGNAL = 0x7a
|
||||
DYN_VGPR = 0x87
|
||||
BARRIER_JOIN = 0x8a
|
||||
WMMA_8 = 0x8c
|
||||
WMMA_16 = 0x8d
|
||||
WMMA_32 = 0x8e
|
||||
WMMA_64 = 0x8f
|
||||
VALU_DPFP = 0x92
|
||||
SALU_FLOAT3 = 0x98
|
||||
VALU_SCL_TRANS = 0x99
|
||||
SALU_2 = 0x9b
|
||||
SALU_5 = 0x9c
|
||||
OTHER_VMEM = 0xbd
|
||||
OTHER_VMEM_5 = 0xc1
|
||||
OTHER_VMEM = 0xbc # 0xbc-0xdd: vmem_other_simd
|
||||
for _i in range(34): InstOpRDNA4._value2member_map_[0xbc + _i] = InstOpRDNA4.OTHER_VMEM
|
||||
|
||||
class InstOpCDNA(Enum):
|
||||
SMEM_RD = 0
|
||||
|
|
@ -650,8 +688,8 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
|||
yield (p, InstructionInfo(pc, wave, inst))
|
||||
elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# s_delay_alu and s_wait_alu instructions are skipped
|
||||
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
|
||||
# s_delay_alu, s_wait_alu and s_barrier_wait instructions are skipped
|
||||
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU", "S_BARRIER_WAIT"}:
|
||||
wave_pc[p.wave] += inst.size()
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# assert branch always has a JUMP packet
|
||||
|
|
|
|||
|
|
@ -487,8 +487,8 @@ class AMDHIPRenderer(CStyleLanguage):
|
|||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]},"
|
||||
f" {fp8_index(x.src[0].dtype)}, {fp8_index(x.src[0].dtype)}, 0, 0, 0, 0)" if x.arg[1][2] == 128 else None),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)"),
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",),
|
||||
lambda ctx,x,y: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",),
|
||||
lambda ctx,x: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
|
||||
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",),
|
||||
lambda ctx,x,y: f"__builtin_amdgcn_cvt_f32_{('fp8', 'bf8')[fp8_index(y.dtype)]}((unsigned int){ctx[x.src[0]]}, 0)"),
|
||||
]) + base_rewrite
|
||||
|
|
|
|||
|
|
@ -224,17 +224,17 @@ class AMDLLVMRenderer(LLVMRenderer):
|
|||
(UPat(tuple(llvm_intrinsics), name="x"),
|
||||
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: barrier),
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y:
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",), lambda ctx,x:
|
||||
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
|
||||
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
|
||||
f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
|
||||
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
|
||||
]) + base_rewrite
|
||||
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
|
||||
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
|
||||
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
|
||||
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
|
||||
lambda y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
|
||||
# amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
|
||||
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
|
||||
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0
|
|||
|
||||
# alu ops, aop[<dtype>][<op>]
|
||||
u_aop = { Ops.ADD: "iadd", Ops.MUL: "imul", Ops.IDIV: "udiv", Ops.MOD: "umod", Ops.CMPLT: "ult", Ops.CMPNE: "ine", Ops.CMPEQ: "ieq", Ops.OR: "ior",
|
||||
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax"}
|
||||
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax"}
|
||||
Ops.AND: "iand", Ops.XOR: "ixor", Ops.WHERE: "bcsel", Ops.MAX: "umax", Ops.SHL: "ishl", Ops.SHR: "ushr"}
|
||||
s_aop = {**u_aop, Ops.CMPLT: "ilt", Ops.IDIV: "idiv", Ops.MOD: "irem", Ops.MAX: "imax", Ops.SHR: "ishr"}
|
||||
f_aop = { Ops.ADD: "fadd", Ops.MUL: "fmul", Ops.CMPLT: "flt", Ops.CMPNE: "fneu", Ops.CMPEQ: "feq", Ops.FDIV: "fdiv", Ops.RECIPROCAL: "frcp",
|
||||
Ops.MAX: "fmax", Ops.TRUNC: "ftrunc", Ops.SIN: "fsin", Ops.EXP2: "fexp2", Ops.LOG2: "flog2"}
|
||||
aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dtypes.sints}, **{x:f_aop for x in dtypes.floats}}
|
||||
|
|
@ -130,6 +130,8 @@ class NIRRenderer(Renderer):
|
|||
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
|
||||
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
|
||||
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
|
||||
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
|
||||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
|
|
@ -144,8 +146,8 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
|
|
|
|||
|
|
@ -317,7 +317,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
# copy on CONST is CONST
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
|
||||
# hack if a noop turned to a const
|
||||
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
|
||||
(UPat(Ops.NOOP, src=(UPat.cvar("c"),)), lambda c: c),
|
||||
# mstack on CONST is CONST
|
||||
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
|
||||
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from contextlib import ContextDecorator
|
|||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid, InvalidType
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
|
|
@ -134,13 +134,7 @@ class Tensor(OpMixin):
|
|||
if isinstance(data, UOp):
|
||||
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
|
||||
# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
|
||||
# TODO: remove this and stay in weakint
|
||||
if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
|
||||
if data.op is Ops.BIND:
|
||||
var, val = data.unbind()
|
||||
# give the bound constant a device
|
||||
const = UOp.const(var.dtype, val, _device, ())
|
||||
data = data.replace(src=(var.replace(src=const.src), const))
|
||||
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data, device=_device).uop
|
||||
elif data is None:
|
||||
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
|
||||
elif isinstance(data, get_args(ConstType)):
|
||||
|
|
@ -503,7 +497,13 @@ class Tensor(OpMixin):
|
|||
|
||||
@staticmethod
|
||||
def from_uop(y:UOp, **kwargs) -> Tensor:
|
||||
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
|
||||
# TODO: remove this and stay in weakint
|
||||
if y.dtype == dtypes.weakint: y = _index_to_concrete_int(y)
|
||||
if y.op is Ops.BIND:
|
||||
var, val = y.unbind()
|
||||
_device = canonicalize_device(kwargs.get("device"))
|
||||
const = UOp.const(var.dtype, val, _device, ())
|
||||
return Tensor(y.replace(src=(var.replace(src=const.src), const)), **kwargs, requires_grad=False)
|
||||
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
|
||||
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
|
||||
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
|
||||
|
|
@ -1822,7 +1822,7 @@ class Tensor(OpMixin):
|
|||
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
||||
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
||||
denominator = prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
||||
return numerator.div(Tensor.from_uop(denominator, device=numerator.device) if isinstance(denominator, UOp) else denominator).cast(output_dtype)
|
||||
return numerator.div(denominator).cast(output_dtype)
|
||||
|
||||
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1848,7 +1848,8 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
squares = (self - self.mean(axis=axis, keepdim=True)).square()
|
||||
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
||||
denominator = (Tensor.from_uop(n, device=self.device) if isinstance(n, UOp) else Tensor(n, device=self.device)) - correction
|
||||
denominator = Tensor(n, device=self.device) - correction
|
||||
# TODO: infer device and remove relu
|
||||
return squares.sum(axis=axis, keepdim=keepdim).div(denominator.relu())
|
||||
|
||||
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
|
||||
|
|
@ -2953,10 +2954,8 @@ class Tensor(OpMixin):
|
|||
if not isinstance(y, Tensor):
|
||||
# make y a Tensor
|
||||
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
|
||||
if y is Invalid or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
||||
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
|
||||
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
|
||||
else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
|
||||
y_dtype = x.dtype if dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, (int, InvalidType))) else None
|
||||
y = Tensor(y, x.device, y_dtype, requires_grad=False)
|
||||
|
||||
if match_dtype and x.dtype != y.dtype:
|
||||
output_dtype = least_upper_dtype(x.dtype, y.dtype)
|
||||
|
|
@ -2993,7 +2992,7 @@ class Tensor(OpMixin):
|
|||
a, b = self._broadcasted(x, reverse)
|
||||
return a + (-b)
|
||||
|
||||
def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
|
||||
def div(self, x:Tensor|ConstType|UOp, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor:
|
||||
"""
|
||||
Divides `self` by `x`.
|
||||
Equivalent to `self / x`.
|
||||
|
|
|
|||
|
|
@ -431,13 +431,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if len(idx) < len(self.shape): idx += tuple([slice(None)]*(len(self.shape)-len(idx)))
|
||||
assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
|
||||
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
|
||||
perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
|
||||
# apply SHRINK for slices that aren't the full range
|
||||
bounds = tuple((s.start or 0, s.stop if s.stop is not None else self.shape[i]) if isinstance(s, slice) else (0, self.shape[i])
|
||||
for i, s in enumerate(idx))
|
||||
src = self if all(b == (0, self.shape[i]) for i, b in enumerate(bounds)) else self.shrink(bounds)
|
||||
perm = src.permute(tuple([i for i in range(src.ndim) if i not in slice_idx] + slice_idx))
|
||||
return perm.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
|
||||
else:
|
||||
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
|
||||
return UOp.const(self.dtype.base, b, device=self._device, shape=self._shape)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.vcount == 1
|
||||
if count == 1: return self
|
||||
|
|
|
|||
|
|
@ -69,9 +69,9 @@ shared_spec = PatternMatcher([
|
|||
# ***** UOp spec in the Tensor graph *****
|
||||
|
||||
movement_ops = PatternMatcher([
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
|
||||
|
|
@ -213,7 +213,7 @@ kernel_spec = PatternMatcher([
|
|||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
|
||||
|
||||
# bufferize can be on anything
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
|
||||
# reduce must be on ranges
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.weakint, dtypes.int) for y in x.src[1:])),
|
||||
|
|
@ -286,7 +286,7 @@ full_spec = PatternMatcher([
|
|||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint), (UPat(), UPat()), arg=None), lambda: True),
|
||||
|
||||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK)), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
|
|
|
|||
|
|
@ -28,19 +28,19 @@ z3_renderer = PatternMatcher([
|
|||
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
|
||||
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool), lambda ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
# constants
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
|
||||
(UPat(Ops.CONST, arg=Invalid), lambda ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
|
||||
(UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0]), None)),
|
||||
# casts from floats create new variables
|
||||
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
|
||||
create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
|
||||
# A comparison between floats introduces a new bool variable
|
||||
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats)), lambda ctx: (z3.Bool(f"float_cmp{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
# casts from bool/int to int/bool
|
||||
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),), name="c"), lambda x,c,ctx: (z3.If(ctx[1][x], 1, 0), None)),
|
||||
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),), name="c"), lambda x,c,ctx: (ctx[1][x], None)),
|
||||
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,),src=(UPat.var("x", dtypes.bool),)), lambda x,ctx: (z3.If(ctx[1][x], 1, 0), None)),
|
||||
(UPat(Ops.CAST, dtypes.ints+(dtypes.weakint,), src=(UPat.var("x", dtypes.ints+(dtypes.weakint,)),)), lambda x,ctx: (ctx[1][x], None)),
|
||||
(UPat(Ops.CAST, dtypes.bool, name="x"), lambda x,ctx: (ctx[1][x.src[0]]!=0, None)),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x,ctx: (z3_alu[x.op](*(ctx[1][s] for s in x.src)), None)),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -315,6 +315,7 @@ function setFocus(key) {
|
|||
p.append("span").style("margin-left", "8px").style("color", "#f0f0f566").text(formatTime(e.x));
|
||||
tableData.push(["Cycle", formatTime(e.x-data.instSt)], ["Time", p.node()]);
|
||||
} else tableData.push(["Start Time", formatTime(e.x)]);
|
||||
if (data.link != null) tableData.push(["Delay", `${formatTime(Math.abs(selectShape(data.link[0]).e.x - selectShape(data.link[1]).e.x))} Cycles`]);
|
||||
html.append(() => tabulate(tableData));
|
||||
let group = html.append("div").classed("args", true);
|
||||
for (const r of rest) group.append("p").text(r);
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
|||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str, max_pkts=getenv("MAX_SQTT_PKTS",50_000)) -> tuple[list[ProfileEvent], bool]:
|
||||
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
|
||||
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
|
||||
WAVEEND, CDNA_WAVEEND, WAVERDY)
|
||||
|
|
@ -365,8 +365,11 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
|||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
|
||||
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
|
||||
has_more = False
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
if len(ret) > max_pkts:
|
||||
has_more = True
|
||||
break
|
||||
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
|
||||
pair = (p._time, p.delta)
|
||||
if prev_pair is None: prev_pair = pair
|
||||
|
|
@ -393,7 +396,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
|||
else:
|
||||
add(name.replace("_ALT", ""), p, op=name)
|
||||
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
|
||||
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret
|
||||
return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in row_ends]+ret, has_more
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
|
|
@ -619,7 +622,7 @@ def get_render(query:str) -> dict:
|
|||
if fmt.startswith("prg-pkts"):
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(sqtt_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
if (events:=get_profile(sqtt_timeline(*data)[0], sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No SQTT trace on this SE."}
|
||||
return ret
|
||||
if fmt == "prg-sqtt":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue