Merge remote-tracking branch 'upstream/master' into new_x86_backend

This commit is contained in:
ttomsa 2026-02-21 16:47:52 +00:00
commit b5db91bfdf
95 changed files with 1649 additions and 1764 deletions

View file

@ -21,6 +21,9 @@ jobs:
# the 3 minute timeout should not be raised
testmacpytest:
name: Mac pytest
env:
CI: ""
CAPTURE_PROCESS_REPLAY: "0"
runs-on: [self-hosted, macOS]
timeout-minutes: 3
defaults:
@ -41,22 +44,14 @@ jobs:
run: |
echo "CACHEDB=/tmp/pytest-db-ci.db" >> $GITHUB_ENV
rm -f /tmp/pytest-db-ci*
# TODO: remove this step once all old caches are migrated
- name: Migrate old huggingface cache (symlinks break onnxruntime 1.24+)
run: |
cd ~/Library/Caches/tinygrad/downloads/models 2>/dev/null || exit 0
for old_dir in models--*; do
[ -d "$old_dir" ] || continue
repo_id=$(echo "$old_dir" | sed 's/models--//; s/--/\//g')
snapshot=$(ls -1 "$old_dir/snapshots" 2>/dev/null | head -1)
[ -n "$snapshot" ] || continue
mkdir -p "$repo_id"
cp -RLn "$old_dir/snapshots/$snapshot/"* "$repo_id/" 2>/dev/null || true
done
- name: Run pytest -nauto
run: |
source /tmp/tinygrad_pytest_ci/bin/activate
pytest -nauto --durations=20
- name: openpilot compile3 0.10.1 driving_vision
run: FLOAT16=1 CL=1 IMAGE=2 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: IMAGE=1 openpilot compile3 0.10.1 driving_vision
run: FLOAT16=1 CL=1 IMAGE=1 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
testmacbenchmark:
name: Mac Benchmark
@ -515,7 +510,7 @@ jobs:
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=230 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
# - name: Run 10 CIFAR training steps w BF16
# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
# TODO: too slow
@ -525,8 +520,9 @@ jobs:
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Test full tinyfs load
run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
# this needs to be mocked and testable on a local machine
#- name: Test full tinyfs load
# run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View file

@ -1,7 +1,7 @@
name: Unit Tests
env:
# increment this when downloads substantially change to avoid the internet
CACHE_VERSION: '16'
CACHE_VERSION: '17'
CAPTURE_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: ${{ github.workspace }}
@ -664,6 +664,28 @@ jobs:
- name: Run LLVM test
run: AMD_LLVM=1 python test/device/test_amd_llvm.py
# CI job "Linux (am)": runs test_tiny and the HCQ device tests on an ubuntu-24.04 runner
# with a 15 minute timeout.
testmockam:
name: Linux (am)
runs-on: ubuntu-24.04
timeout-minutes: 15
# NOTE(review): these presumably select the AMD backend on the mock GPU through the PCI
# ("am") interface -- confirm against the runtime's handling of AMD / MOCKGPU / AMD_IFACE.
env:
AMD: 1
MOCKGPU: 1
AMD_IFACE: PCI
steps:
- name: Checkout Code
uses: actions/checkout@v4
# Uses the repository-local composite action; `key`, `deps` and `amd` are its inputs.
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: mockam
deps: testing_unit
amd: 'true'
- name: Run test_tiny on MOCKAM
run: python test/test_tiny.py
- name: Run test_hcq on MOCKAM
run: python -m pytest test/device/test_hcq.py
testamd:
strategy:
fail-fast: false

View file

@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
Group UOps into kernels.
::: tinygrad.schedule.rangeify.get_rangeify_map
::: tinygrad.schedule.rangeify.get_rangeify
options:
members: false
show_labels: false

View file

@ -19,8 +19,8 @@ cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
EVAL_BS = getenv("EVAL_BS", BS)
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}"
class UnsyncedBatchNorm:
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):

View file

@ -65,17 +65,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
else:
# pad data with training mean
img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))
# broken out
#img_tensor = Tensor(img.tobytes(), device='CPU')
#storage_tensor = X[idx].contiguous().realize().lazydata.base.realized
#storage_tensor._copyin(img_tensor.numpy())
# faster
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
# ideal
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
X[idx].flatten().assign(img.tobytes())
q_out.put(idx)
q_out.put(None)
@ -264,8 +254,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
x = random_brightness_augmentation(x)
x = gaussian_noise(x)
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = x.tobytes()
Y[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = y.tobytes()
X[idx].flatten().assign(x.tobytes())
Y[idx].flatten().assign(y.tobytes())
queue_out.put(idx)
queue_out.put(None)
@ -379,12 +369,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
clipped_match_idxs = np.clip(match_idxs, 0, None)
clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]
boxes[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = anchor.tobytes()
boxes[idx].flatten().assign(clipped_boxes.tobytes())
labels[idx].flatten().assign(clipped_labels.tobytes())
matches[idx].flatten().assign(match_idxs.tobytes())
anchors[idx].flatten().assign(anchor.tobytes())
imgs[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
imgs[idx].flatten().assign(img.tobytes())
queue_out.put(idx)
queue_out.put(None)

View file

@ -1285,6 +1285,7 @@ def train_llama3():
from extra.models.llama import Transformer
from examples.llama3 import MODEL_PARAMS
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
from examples.mlperf.optim import GradAccClipAdamW
BENCHMARK = getenv("BENCHMARK")
@ -1370,13 +1371,13 @@ def train_llama3():
# prevents memory spike on device 0
v.realize()
optim = AdamW(get_parameters(model), lr=0.0,
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
optim = GradAccClipAdamW(get_parameters(model), lr=0.0,
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, grad_acc=grad_acc)
# init grads
for p in optim.params:
p.grad = p.zeros_like().contiguous().realize()
grads = [p.grad for p in optim.params]
grads: list[Tensor] = [p.grad for p in optim.params]
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
@ -1407,25 +1408,11 @@ def train_llama3():
@TinyJit
def optim_step():
for p in optim.params:
p.grad.assign(p.grad / grad_acc)
# L2 norm grad clip
# https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
# https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
if not getenv("DISABLE_GRAD_CLIP_NORM"):
total_norm = Tensor(0.0, dtype=dtypes.float32, device=optim.params[0].device)
for g in grads:
total_norm += g.float().square().sum()
total_norm = total_norm.sqrt().contiguous().realize()
for g in grads:
g.assign((g * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(g.dtype)).realize()
optim.step()
scheduler.step()
for g in grads:
g.assign(g.zeros_like().contiguous()).realize()
g.assign(g.zeros_like())
lr = optim.lr
Tensor.realize(lr, *grads)

24
examples/mlperf/optim.py Normal file
View file

@ -0,0 +1,24 @@
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.nn.optim import LAMB
from tinygrad.helpers import FUSE_OPTIM
class GradAccClipAdamW(LAMB):
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, fused=FUSE_OPTIM):
super().__init__(params, lr, b1, b2, eps, weight_decay, adam=True, fused=FUSE_OPTIM)
self.grad_acc, self.clip_norm = grad_acc, clip_norm
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
if self.fused:
grads[0] = grads[0] / self.grad_acc
total_norm = grads[0].float().square().sum().sqrt()
grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
else:
total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
for g in grads:
total_norm += g.float().square().sum()
total_norm = total_norm.sqrt()
for i in range(len(grads)):
grads[i] = grads[i] / self.grad_acc
grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)
return super()._step(params, grads)

View file

@ -11,9 +11,10 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-0}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
export GBS=$((BS * GRADIENT_ACC_STEPS))
export MODEL="llama3"
@ -30,8 +31,11 @@ export SEED=${SEED:-5760}
export DATA_SEED=${DATA_SEED:-5760}
export JITBEAM=${JITBEAM:-3}
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=1
export FAKEDATA=1 BENCHMARK=10 LLAMA_LAYERS=2
export FAKEDATA=1 BENCHMARK=10
if [ -z "$FULL_LAYERS" ]; then
export LLAMA_LAYERS=2
fi
python3 examples/mlperf/model_train.py

View file

@ -11,9 +11,10 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-0}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
export GBS=$((BS * GRADIENT_ACC_STEPS))
export MODEL="llama3"
@ -30,6 +31,6 @@ export SEED=${SEED:-$RANDOM}
export DATA_SEED=${DATA_SEED:-5760}
export JITBEAM=${JITBEAM:-3}
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=1
python3 examples/mlperf/model_train.py

View file

@ -236,8 +236,6 @@ class SMICtx:
case _: return metrics.SmuMetrics.AverageSocketPower, metrics.SmuMetrics.dGPU_W_MAX
def get_mem_usage(self, dev):
return 0
usage = 0
pt_stack = [dev.mm.root_page_table]
while len(pt_stack) > 0:
@ -246,8 +244,8 @@ class SMICtx:
entry = pt.entries[i]
if (entry & am.AMDGPU_PTE_VALID) == 0: continue
if pt.lv!=am.AMDGPU_VM_PTB and not dev.gmc.is_pte_huge_page(pt.lv, entry):
pt_stack.append(AMPageTableEntry(dev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1))
if pt.lv < am.AMDGPU_VM_PDB0 and not dev.gmc.is_pte_huge_page(pt.lv, entry):
pt_stack.append(AMPageTableEntry(dev, dev.xgmi2paddr(entry & 0x0000FFFFFFFFF000), lv=pt.lv+1))
continue
if (entry & am.AMDGPU_PTE_SYSTEM) != 0: continue
usage += (1 << ((9 * (3-pt.lv)) + 12))

View file

@ -41,9 +41,13 @@ class Attention:
self.n_rep = self.n_heads // self.n_kv_heads
self.max_context = max_context
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
if getenv("WQKV"):
self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
else:
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
@ -51,9 +55,8 @@ class Attention:
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
if getenv("WQKV"):
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
xqkv = x @ self.wqkv.T
xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
xqkv = self.wqkv(x)
xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
else:
xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
@ -200,14 +203,14 @@ class Transformer:
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.tok_embeddings(tokens).contiguous()
freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos+seqlen, :, :, :]
if self.max_context != 0 and seqlen > 1:
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1)
else: mask = None
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h))
logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
if math.isnan(temperature): return logits
return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)

View file

@ -4,6 +4,8 @@ from typing import Generator
from tinygrad.helpers import temp, unwrap, DEBUG
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from tinygrad.runtime.autogen import rocprof
from tinygrad.renderer.amd.dsl import Inst
from test.amd.disasm import disasm
@dataclasses.dataclass(frozen=True)
class InstExec:
@ -44,8 +46,8 @@ class OccEvent(WaveSlot):
RunKey = tuple[str, int]
class _ROCParseCtx:
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]):
self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]]):
self.sqtt_evs, self.disasms = iter(sqtt_evs), {k:{k2:(disasm(v2), v2.size()) for k2,v2 in v.items()} for k,v in disasms.items()}
self.inst_execs:dict[RunKey, list[WaveExec]] = {}
self.occ_events:dict[RunKey, list[OccEvent]] = {}
@ -71,7 +73,7 @@ class _ROCParseCtx:
self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
ev.end_time, insts_blob))
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx:
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]]) -> _ROCParseCtx:
ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms)
@rocprof.rocprof_trace_decoder_se_data_callback_t

View file

@ -47,24 +47,21 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
# delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch))[:2]
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2]
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch))[:3]
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3]
# unshuffle dq
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch))[0]
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0]
return None, None, dq.uop, dk.uop, dv.uop
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch), grad_fxn=grad)[:2]
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2]
return attn.transpose(1, 2)
@functools.cache
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str):
B, N, H, D = q.shape
H_KV = k.shape[2]
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_fwd_causal.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
@ -95,9 +92,7 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str):
B, N, H, D = o.shape
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_pre.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
@ -128,10 +123,7 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str):
B, N, H, D = q.shape
H_KV = k.shape[2]
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_causal.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
@ -162,9 +154,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str):
B, N, H, D = dq_out.shape
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_post.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]

View file

@ -23,7 +23,8 @@ if __name__ == "__main__":
kernel_count = GlobalCounters.kernel_count
assert kernel_count > 0, "No kernels, test failed"
expected_kernels = 228
# NOTE: this is 124 on torch 2.10.0
expected_kernels = 332
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
assert kernel_count <= expected_kernels, f"{expectation}"

View file

@ -26,7 +26,7 @@ class TestKernelFusionRegression(unittest.TestCase):
def fn():
x = torch.randn(128, 128, device=device)
return (x + 1.0) * 2.0 - 0.5
self._check_kernel_count(fn, 6)
self._check_kernel_count(fn, 5)
def test_relu_fusion(self):
def fn():
@ -50,14 +50,14 @@ class TestKernelFusionRegression(unittest.TestCase):
def fn():
x = torch.randn(64, 64, device=device)
return (x * 2.0).sum()
self._check_kernel_count(fn, 7)
self._check_kernel_count(fn, 5)
def test_matmul_elementwise_fusion(self):
def fn():
x = torch.randn(32, 32, device=device)
w = torch.randn(32, 32, device=device)
return torch.nn.functional.relu(x @ w + 1.0)
self._check_kernel_count(fn, 6)
self._check_kernel_count(fn, 7)
def test_pooling_fusion(self):
def fn():
@ -71,7 +71,7 @@ class TestKernelFusionRegression(unittest.TestCase):
identity = torch.randn(1, 8, 16, 16, device=device)
out = x + identity
return torch.nn.functional.relu(out)
self._check_kernel_count(fn, 6)
self._check_kernel_count(fn, 7)
def test_inplace_add_relu_fusion(self):
def fn():
@ -79,7 +79,7 @@ class TestKernelFusionRegression(unittest.TestCase):
y = torch.randn(1, 16, 32, 32, device=device)
x += y
return torch.nn.functional.relu(x)
self._check_kernel_count(fn, 6)
self._check_kernel_count(fn, 7)
def test_conv_bn_add_relu_fusion(self):
def fn():
@ -92,7 +92,7 @@ class TestKernelFusionRegression(unittest.TestCase):
out = bn(conv(x))
out += identity
return torch.nn.functional.relu(out)
self._check_kernel_count(fn, 16)
self._check_kernel_count(fn, 17)
def test_multiple_inplace_ops_fusion(self):
def fn():
@ -138,7 +138,7 @@ class TestKernelFusionRegression(unittest.TestCase):
loss.backward()
optimizer.step()
return loss
self._check_kernel_count(fn, 33)
self._check_kernel_count(fn, 28)
if __name__ == "__main__":
unittest.main()

View file

@ -74,7 +74,7 @@ testing_minimal = [
"hypothesis>=6.148.9",
"z3-solver<4.15.4", # 4.15.4 has a segfault when creating many z3.Context()
]
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "ggml-python"]
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "gguf"]
testing = [
"tinygrad[testing_unit]",
"pillow",

View file

@ -1,19 +1,9 @@
"""Shared test helpers for AMD tests."""
import ctypes
from dataclasses import dataclass
from tinygrad.helpers import unwrap
from tinygrad.runtime.autogen import llvm
from tinygrad.runtime.support.elf import elf_loader
@dataclass
class KernelInfo:
code: bytes
src: str
global_size: tuple[int, int, int]
local_size: tuple[int, int, int]
buf_idxs: list[int] # indices into shared buffer pool
buf_sizes: list[int] # sizes for each buffer index
ARCH_TO_TARGET:dict[str, list[str]] = {
"rdna3":["gfx1100"],
"rdna4":["gfx1200"],

View file

@ -6,7 +6,6 @@ from tinygrad import Device
from test.mockgpu.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
from tinygrad.renderer.amd import decode_inst
from test.amd.helpers import KernelInfo
import tinygrad
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"
if not REMU_PATH.exists(): REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.dylib"
@ -22,6 +21,15 @@ def _vals_equal(a: int, b: int) -> bool:
if a == b: return True
return _is_f32_nan(a) and _is_f32_nan(b)
# Record describing one compiled kernel to be replayed through the emulators:
# its code bytes, its source text, its launch dimensions and the buffers it uses.
@dataclass
class KernelSnapshot:
# raw kernel code bytes
code: bytes
# source text of the kernel
src: str
# launch dimensions as 3-tuples
global_size: tuple[int, int, int]
local_size: tuple[int, int, int]
buf_idxs: list[int] # indices into shared buffer pool
buf_sizes: list[int] # sizes for each buffer index
@dataclass
class StateSnapshot:
pc: int
@ -285,7 +293,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
return True, f"Completed {gx*gy*gz} workgroups", total_steps
def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int, int], max_steps: int = 1000,
def compare_emulators_multi_kernel(kernels: list[KernelSnapshot], buf_pool: dict[int, int], max_steps: int = 1000,
debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] | None = None) -> tuple[bool, str]:
"""Run all kernels through both emulators with shared buffer pool."""
if buf_data is None: buf_data = {}
@ -349,7 +357,7 @@ def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list,
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), max_steps, debug, trace_len)
return ok, msg
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelSnapshot], dict[int, int], dict[int, bytes]]:
"""Compile a tinygrad operation and extract all kernels with their buffer mappings."""
from tinygrad import Tensor
from tinygrad.runtime.support.elf import elf_loader
@ -387,7 +395,7 @@ def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int],
buf_pool[buf_id] = b.nbytes
buf_idxs.append(buf_id)
buf_sizes.append(b.nbytes)
kernels.append(KernelInfo(
kernels.append(KernelSnapshot(
code=bytes(sec.content),
src=lowered.prg.p.src,
global_size=tuple(lowered.prg.p.global_size),

View file

@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.UNK_60}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_STORE}
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER

View file

@ -14,7 +14,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
addr_table = amd_decode(prg.lib, target)
disasm_map = {addr+prg.base:(disasm(inst), inst.size()) for addr,inst in addr_table.items()}
disasm_map = {addr+prg.base:inst for addr,inst in addr_table.items()}
rctx = roc_decode([sqtt], {prg.tag:disasm_map})
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
@ -30,7 +30,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
rocprof_inst = next(rwaves_iter[info.wave][0])
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc][0]} != {info.pc}:{disasm(info.inst)}"
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
# special handling for s_endpgm, it marks the wave completion.
if info.inst == s_endpgm():
completed_wave = list(rwaves_iter[info.wave].pop(0))
@ -72,7 +72,6 @@ class TestSQTTMapBase(unittest.TestCase):
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
@unittest.skip("this doesn't work")
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
if __name__ == "__main__":

View file

@ -67,6 +67,7 @@ class TestGemmLarge(unittest.TestCase):
if not is_cdna4():
self.skipTest("very slow on non mi350x")
def test_tiny(self): verify_asm_gemm(1, 256, 256, 64)
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096)

View file

@ -10,7 +10,9 @@ def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = t.schedule()
asts = [s for s in schedule if s.ast.op is Ops.SINK]
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
len(asts)
# NOT SUPPORTED ANYMORE
#assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):

View file

@ -266,7 +266,7 @@ class TestCustomKernel(unittest.TestCase):
The custom_addmul kernel should be at index 3.
"""
from tinygrad.engine.schedule import create_schedule
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.rangeify import get_rangeify
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
A2 = (A + 1).contiguous() # kernel 0: depends on A
@ -277,8 +277,7 @@ class TestCustomKernel(unittest.TestCase):
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
big_sink = result.uop.sink()
tensor_map = get_rangeify_map(big_sink)
sched_sink = big_sink.substitute(tensor_map)
sched_sink = get_rangeify(big_sink)
schedule, _ = create_schedule(sched_sink)
# Find the custom_addmul kernel position

View file

@ -115,7 +115,7 @@ class TestImageDType(unittest.TestCase):
tst = data.numpy()
it = data.cast(dtypes.imagef((9,27,4))).realize()
# the underlying UOp is identical
self.assertIs(it.uop.base.realized, data.uop.base.realized)
#self.assertIs(it.uop.base.realized, data.uop.base.realized)
np.testing.assert_equal(tst, it.numpy())
def test_image_and_back_wrong_shape(self):

View file

@ -332,7 +332,6 @@ class TestJit(unittest.TestCase):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
#@unittest.expectedFailure # requires contiguous folding
def test_jit_random_after_unrealized_random(self):
@TinyJit
def f(): return Tensor.rand()
@ -476,7 +475,7 @@ class TestJit(unittest.TestCase):
b = f(Tensor([2.0]))
assert abs((a - b).item()) > 0.5
def test_jit_init_with_empty_different_size(self):
def test_jit_init_empty(self):
@TinyJit
def f(x:Tensor) -> Tensor: return (x + 1).realize()
@ -485,9 +484,16 @@ class TestJit(unittest.TestCase):
# scalar const input is not allowed
with self.assertRaises(JitError):
f(Tensor(2.0)).item()
# list input has different view structure than empty(1)
with self.assertRaises(JitError):
f(Tensor([2.0])).item()
# self.assertEqual(f(Tensor([2.0])).item(), 1.0) # TODO: wrong output, should be 3.0. currently depends on empty value
def test_jit_init_empty_alt(self):
@TinyJit
def f(a:Tensor, b:Tensor) -> Tensor: return b.assign(a+1)
for i in range(4):
a = Tensor([i])
b = Tensor.empty_like(a)
c = f(a, b)
self.assertEqual(c.item(), i+1)
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
@ -645,8 +651,8 @@ class TestJitFree(unittest.TestCase):
def test_replan_buffers_memory_layout(self):
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
ext_tensor = Tensor([1,24,23,45,1])
ext_tensor_2 = Tensor([2,2,2,2,2])
ext_tensor = Tensor([1,24,23,45,1]).contiguous()
ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous()
@TinyJit
def fxn(x:Tensor):
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
@ -654,9 +660,9 @@ class TestJitFree(unittest.TestCase):
for i in range(5):
out = fxn(Tensor([i,1,2,3,4]))
self.assertEqual(out.item(), 11400+200*i)
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4)
fxn.captured.replan_buffers_memory_layout()
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2)
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 13600)

View file

@ -3,8 +3,7 @@ import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
@ -254,100 +253,6 @@ class TestLinearizer(unittest.TestCase):
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
assert end_range < uops.index(u)
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
sizes = [x.src[0].arg for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
if assert_same_length:
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
# TODO: add these back after uop symbolic
# for i in range(len(dims)):
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
# for i in range(len(loop_idxs)):
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# no-op
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
# prefer group_dim strategy when possible
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# 2 -> 2
_assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# collapse onto the left-most axis
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
# collapse on left-most available axis (the left most is too small)
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
# TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
# We should check if the returned indices are correct, for all cases.
# (65536, 2) -> (32768, 4)
dims, expected_limited_dims = (65536,2), (32768, 4)
idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
def match_div(): raise RuntimeError("match_div")
def match_mod(): raise RuntimeError("match_mod")
flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
pm = PatternMatcher([
(flat_idx_pattern//dims[1], match_div),
(flat_idx_pattern%dims[1], match_mod)
])
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[0], pm)
self.assertIn("match_div", str(error.exception))
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[1], pm)
self.assertIn("match_mod", str(error.exception))
# # variable too large
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
# shrink so that the dims do not collapse

View file

@ -94,7 +94,7 @@ class TestMultiTensor(unittest.TestCase):
def _test_shard_op(self, op, out, n=4):
t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
r = op(t).realize()
assert t.uop.is_realized, "shard didn't realize"
#assert t.uop.is_realized, "shard didn't realize"
self.assertEqual(r.tolist(), out)
def test_shard_reshape(self): self._test_shard_op(lambda t:t.reshape(2, 2), [[1.,1.],[1.,1.]])
def test_shard_elementwise(self): self._test_shard_op(lambda t:(t+t).reshape(2, 2), [[2.,2.],[2.,2.]])
@ -654,54 +654,6 @@ class TestMultiTensor(unittest.TestCase):
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
assert isinstance(jf.jit_cache[5].prg, graph_d1)
@unittest.skip("no longer supports uneven shard")
def test_uneven_shard(self):
for N in range(1, 6):
X = Tensor.rand(4, 1, 257).contiguous().realize()
n = X.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
X.shard_(devices, 2)
np.testing.assert_equal(X.numpy(), n)
np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
@unittest.skip("no longer supports uneven shard")
def test_uneven_multiple_zeros(self):
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
for N in (1, 2, 3, 4):
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
# make sure something is computed on each device
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
np.testing.assert_equal(X.numpy(), data)
@unittest.skip("no longer supports uneven shard")
def test_uneven_shard_with_empty(self):
N = 4
X = Tensor.rand(16, 1, 3).contiguous().realize()
np_x = X.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
# test empty shard
np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
# test reshape with empty shard
np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
@unittest.skip("no longer supports uneven shard")
def test_multiple_uneven_shard(self):
N = 4
X = Tensor.rand(4, 1, 257).contiguous().realize()
Y = Tensor.rand(4, 1, 257).contiguous().realize()
np_x, np_y = X.numpy(), Y.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
X.shard_(devices, 2)
Y.shard_(devices, 2)
np.testing.assert_equal(X.numpy(), np_x)
np.testing.assert_equal(Y.numpy(), np_y)
np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
def test_bn_ast_on_devices(self):
t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
bn = nn.BatchNorm2d(64)
@ -752,34 +704,7 @@ class TestMultiTensor(unittest.TestCase):
# test no left join
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7)).schedule()
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
def reshape_helper(t0, t, t_axis):
assert t.uop.axis == t_axis
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
# ok to reshape as long as elements remain on same device
reshape_helper(t0, t0.reshape(2, 2, 42, 3, 5), 2)
# split to the right
reshape_helper(t0, t0.reshape(2, 2, 6, 7, 15), 2)
# split off and merge to the right
reshape_helper(t0, t0.reshape(4, 6, 105), 1)
# really blend the axes together
reshape_helper(t0, t0.reshape(4, 30, 21), 1)
# split off 1-shape
reshape_helper(t0, t0.reshape(4, 1, 42, 15), 2)
reshape_helper(t0, t0.reshape(4, 6, 1, 7, 15), 1)
# assert if cannot maintain shard axis without moving items between devices
with self.assertRaises(AssertionError): t0.reshape(4, 7, 6, 15)
# assert for degenerate reshape
with self.assertRaises(AssertionError): t0.reshape(4, 5, 7, 15)
# assert for cannot maintain axis
with self.assertRaises(AssertionError): t0.reshape(4, 3, 2, 7, 15)
t0.reshape((26*15,7)).contiguous().schedule()
# it doesn't work like this anymore
# NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
@ -849,16 +774,6 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(rab.device, devices_4)
self.assertEqual(rab.uop.axis, 0)
@unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
t2 = Tensor.rand_like(t)
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.uop.axis, t2.uop.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src))
def test_rand_like_none_shard(self):
t = Tensor.empty((16, 16)).shard(devices_2)
t2 = Tensor.rand_like(t)
@ -894,6 +809,14 @@ class TestMultiTensor(unittest.TestCase):
t2.realize()
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_full_like_shrink_on_shard_axis(self):
t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
out = Tensor.full_like(t, 2)[:, :8]
sched = out.schedule()
self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink
run_schedule(sched)
self.assertEqual(out.tolist(), [[2]*8]*16)
def test_dropout_on_shard(self):
with Tensor.train():
X = Tensor.ones(256).to(devices_2)
@ -910,15 +833,6 @@ class TestMultiTensor(unittest.TestCase):
assert set(unique) == {0, 2}, unique
assert 200 < counts[0] < 312, counts[0]
@unittest.skip("no longer supports uneven shard")
def test_dropout_on_uneven_shard_axis(self):
with Tensor.train():
X = Tensor.ones(256).shard(devices_3, axis=0)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
assert set(unique) == {0, 2}, unique
assert 100 < counts[0] < 156, counts[0]
@unittest.skip("TODO: this requires forced_realize to be deleted.")
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
@ -926,13 +840,15 @@ class TestMultiTensor(unittest.TestCase):
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])
@unittest.skip("this is unreliable on OSX")
def test_clone(self):
t = Tensor.rand(16, 16).shard(devices_2, axis=None)
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
t = Tensor.rand(16, 16).shard(devices_2, axis=0)
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
for axis in (None, 0):
t = Tensor.arange(16).reshape(4, 4).shard(devices_2, axis=axis).contiguous().realize()
t_clone = t.clone().realize()
self.assertEqual(t_clone.device, t.device)
self.assertEqual(t_clone.uop.axis, axis)
self.assertEqual(t_clone.tolist(), t.tolist())
t_clone += 1
self.assertNotEqual(t_clone.tolist(), t.tolist())
@unittest.skip("RANGEIFY doesn't support multi const folding")
def test_multi_const_folding(self):
@ -981,18 +897,18 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
with self.assertRaises(AssertionError):
# sharded axis shrink on non-device boundary is not allowed
a = t.shrink(((0, 3), (0, 8)))
a.schedule()
with self.assertRaises(AssertionError):
# cannot shrink sharded and non-sharded axis at the same time
a = t.shrink(((0, 2), (2, 4)))
a = t.shrink(((0, 3), (0, 8))).contiguous()
a.schedule()
a = t.shrink(((0, 2), (2, 4)))
assert a.shape == (2, 2)
ref = Tensor.arange(64).reshape(8, 8).shrink(((0, 2), (2, 4)))
np.testing.assert_equal(a.numpy(), ref.numpy())
a = t.shrink(((0, 2), (0, 8)))
a = t.shrink(((0, 2), (0, 8))).contiguous()
a.schedule()
assert a.shape == (2, 8)
p = a.pad(((0, 6), (0, 0)))
p = a.pad(((0, 6), (0, 0))).contiguous()
p.schedule()
assert p.shape == (8, 8)
@ -1042,24 +958,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
@unittest.skip("no longer supports uneven shard")
def test_uneven(self):
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
a = t.shrink(((0, 2), None))
b = t.shrink(((2, 3), None))
na = t.numpy()[0:2]
nb = t.numpy()[2:3]
np.testing.assert_equal(a.numpy(), na)
np.testing.assert_equal(b.numpy(), nb)
np.testing.assert_equal((a+1).numpy(), na+1)
np.testing.assert_equal((b+1).numpy(), nb+1)
np.testing.assert_equal((1+a).numpy(), 1+na)
np.testing.assert_equal((1+b).numpy(), 1+nb)
np.testing.assert_equal((a+a).numpy(), na+na)
np.testing.assert_equal((b+b).numpy(), nb+nb)
def test_add_two_partitions(self):
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

View file

@ -8,7 +8,8 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.renderer.nir import NIRRenderer
if getenv("TINY_BACKEND"):
TINY_BACKEND = getenv("TINY_BACKEND")
if TINY_BACKEND:
import tinygrad.nn.torch # noqa: F401 # pylint: disable=unused-import
torch.set_default_device("tiny")
@ -760,6 +761,7 @@ class TestOps(unittest.TestCase):
data = [[1,-8,1],[32,1,6]]
tor = torch.tensor(data, dtype=torch.int)
ten = Tensor(data, dtype=dtypes.int32)
# NOTE: this breaks assigns because it's folded to 0!
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)

View file

@ -1,230 +0,0 @@
import unittest
import numpy as np
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops
class TestOuterworldReduce(unittest.TestCase):
  """Builds an ADD reduction by hand over a REDUCE-typed range used as a tensor index."""
  def test_reduce(self):
    ones = Tensor.ones(5, 5).contiguous()
    rng = UOp.range(5, -1, AxisType.REDUCE)
    row = ones[rng].uop
    # TODO: syntax for this
    reduce_uop = UOp(Ops.REDUCE, dtype=row.dtype, src=(row, rng), arg=Ops.ADD)
    summed = Tensor(reduce_uop)
    # each of the 5 outputs is the sum of five ones
    self.assertListEqual(summed.tolist(), [5.,5.,5.,5.,5.])
# TODO: delete test_outerworld_range?
class TestOuterRange(unittest.TestCase):
  """Accumulation loops written with an OUTER-typed range whose result is stored back into a realized tensor."""

  def test_simple_range(self):
    """Fold ten ones into a scalar accumulator; the result must be 10."""
    a = Tensor.ones(10).contiguous()
    acc = Tensor.zeros().contiguous()
    Tensor.realize(a, acc)

    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
    out.realize()
    # unittest assertion (matches test_inner_range) instead of a bare assert, which -O would strip
    self.assertEqual(out.item(), 10.0)

  def test_inner_range(self):
    """Fold the columns of a 10x10 matrix of ones into a length-10 accumulator."""
    a = Tensor.ones(10, 10).contiguous()
    acc = Tensor.zeros(10).contiguous()
    Tensor.realize(a, acc)

    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
    out.realize()
    self.assertEqual(out.tolist(), [10.0]*10)

  def test_range_matmul(self):
    """Chain three matmuls through an outer range and compare against the eager chain."""
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()

    # 3 matmuls in "scan"
    ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
    ref.realize()

    # 3 matmuls with outer world range
    i = UOp.range(3, -100, AxisType.OUTER)
    vec_i = Tensor(vec.uop.after(i))
    comp = vec_i.contiguous() @ mats[i]
    store = vec_i.uop.store(comp.uop).end(i)
    out = Tensor(vec.uop.after(store))
    out.realize()

    # TODO: testing allclose
    # the max-diff message is only computed on failure, as with the original assert
    if not Tensor.allclose(ref, out, atol=1e-5):
      self.fail(f"max diff {(ref-out).abs().max().item()}")
class TestOuterScan(unittest.TestCase):
  """Scan-style chained matmuls where every intermediate result is kept."""

  def _test_scan(self):
    """Return (vec, mats, ref) where ref stacks the three successive products vec @ mats[0] @ ... ."""
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()

    # 3 matmuls in "scan"
    partials = []
    running = vec
    for k in range(3):
      running = running @ mats[k]
      partials.append(running)
    ref = Tensor.stack(*partials)
    ref.realize()
    return vec, mats, ref

  def test_uop_scan_matmul(self):
    vec, mats, ref = self._test_scan()

    # 3 matmuls with SCAN
    step = UOp.range(3, -100, AxisType.OUTER)
    buf = Tensor.empty(3, 1, 10)
    # first step reads vec, later steps read the previous slot of the output buffer
    is_first = step.eq(0)
    phi = Tensor(is_first.where(vec.uop, buf[(step-1).maximum(0)].uop))
    product = phi @ mats[step]
    write = buf[step].uop.store(product.uop).end(step)
    result = Tensor(buf.uop.after(write))
    result.realize()

    # TODO: testing allclose
    assert Tensor.allclose(ref, result, atol=1e-5), f"max diff {(ref-result).abs().max().item()}"
class TestOuterworld(unittest.TestCase):
  """Ranges created without an explicit axis type, used both as an index and as an expand size."""

  def test_range_plus_1(self):
    grid = Tensor.arange(100).reshape(10,10).realize()

    # passthrough ranges
    rng = UOp.range(10, -1)
    row = grid[rng] + 1
    assert row.shape == (10,)
    rebuilt = row.reshape(1, 10).expand(rng, 10).contiguous().realize()

    same = (grid+1==rebuilt).all().item()
    self.assertTrue(same)

  def test_range_plus_1_transpose(self):
    grid = Tensor.arange(100).reshape(10,10).realize()

    # passthrough ranges
    rng = UOp.range(10, -1)
    row = grid[rng] + 1
    assert row.shape == (10,)
    # the range goes on the second axis here, so the result is the transpose
    rebuilt = row.reshape(10, 1).expand(10, rng).contiguous().realize()

    same = ((grid+1).T==rebuilt).all().item()
    self.assertTrue(same)

  def test_flip_range(self):
    src = Tensor.rand(10, 10).realize()

    # passthrough ranges
    rng = UOp.range(10, -1)
    row = src[9-rng]
    rebuilt = row.reshape(1, 10).expand(rng, 10).contiguous().realize()

    same = (src.flip(0)==rebuilt).all().item()
    self.assertTrue(same)

  def test_vmap(self):
    def double_sum(x):
      return x.sum(axis=0)*2

    x = Tensor.ones(3, 10, 2).contiguous()

    # vmap across axis 0
    rng = UOp.range(3, -1)
    per_slice = double_sum(x[rng])
    out = per_slice.reshape(1, 2).expand(rng, 2).contiguous()

    # 3x2 grid of 20
    out.realize()
    self.assertTrue((out==20).all().item())

  def test_fancy_vmap(self):
    def add(x, y):
      return x+y

    x = Tensor.arange(9).reshape(3,3).contiguous()
    y = Tensor.arange(9).reshape(3,3).contiguous()

    rng = UOp.range(3, -1)
    summed = add(x[:,rng], y[rng,:])
    # TODO: this should support flatten
    out = summed.reshape(1, 3).expand(rng, 3).contiguous().realize()
    expected = [[0,4,8],[4,8,12],[8,12,16]]
    self.assertListEqual(expected, out.tolist())
class TestVmap(unittest.TestCase):
  """vmap written as: index with a range, pad the slice into place, then ADD-reduce over the range."""

  def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
    """Batched matmul via a range; `fuse` doubles the result, `grad` also checks mats.grad."""
    x = Tensor.ones(1, 10).contiguous().requires_grad_()
    mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()

    ref = x @ mats
    if fuse:
      ref = ref * 2

    # vmap across axis 0
    rng = UOp.range(3, -1, axis_type)
    per_slice = x @ mats[rng]
    pad_spec = ((rng,(3-rng)-1), None)
    placed = per_slice.reshape(1, 10).pad(pad_spec)
    out = Tensor(placed.uop.reduce(rng, arg=Ops.ADD))
    if fuse:
      out = out * 2
    if grad:
      out.mean().backward()
      expected_grad = (2./30) if fuse else (1./30)
      np.testing.assert_allclose(mats.grad.numpy(), expected_grad)
    out.realize()

    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-6), f"max diff {(ref-out).abs().max().item()}"

  def test_vmap_inner_fuse(self):
    self.test_vmap_inner(fuse=True)

  def test_vmap_outer(self):
    self.test_vmap_inner(AxisType.OUTER)

  def test_vmap_outer_fuse(self):
    self.test_vmap_inner(AxisType.OUTER, fuse=True)

  def test_vmap_inner_grad(self):
    self.test_vmap_inner(grad=True)

  def test_vmap_inner_fuse_grad(self):
    self.test_vmap_inner(fuse=True, grad=True)

  def test_vmap_outer_grad(self):
    self.test_vmap_inner(AxisType.OUTER, grad=True)

  def test_vmap_convs(self):
    layers = [
      nn.Conv2d(1, 8, 3), Tensor.relu,
      nn.Conv2d(8, 8, 3), Tensor.relu]
    img = Tensor.randn(4, 1, 16, 16)
    img.realize(*nn.state.get_parameters(layers))
    rng = UOp.range(4, -1, AxisType.OUTER)
    # run the layers on one image at a time, then pad each result back to its batch slot
    single = img[rng:rng+1].sequential(layers)
    placed = single.pad(((rng,(4-rng)-1), None, None, None))
    out = Tensor(placed.uop.reduce(rng, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)

  def test_vmap_gemm(self):
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    img = Tensor.randn(4, 16)
    img.realize(*nn.state.get_parameters(layers))
    rng = UOp.range(4, -1, AxisType.OUTER)
    single = img[rng:rng+1].sequential(layers)
    placed = single.pad(((rng,(4-rng)-1), None))
    out = Tensor(placed.uop.reduce(rng, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)

  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
  def test_vmap_gemm_grad(self):
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    layer_tensors = nn.state.get_parameters(layers)
    img = Tensor.randn(4, 16)
    img.realize(*layer_tensors)
    for param in layer_tensors:
      param.requires_grad_()
    rng = UOp.range(4, -1, AxisType.OUTER)
    single = img[rng:rng+1].sequential(layers)
    placed = single.pad(((rng,(4-rng)-1), None))
    out = Tensor(placed.uop.reduce(rng, arg=Ops.ADD))
    out.mean().backward()
    vmap_grads = [param.grad for param in layer_tensors]
    out.realize(*vmap_grads)
    out_grads = [g.numpy() for g in vmap_grads]

    # compute reference grads
    for param in layer_tensors:
      param.grad = None
    img.sequential(layers).mean().backward()
    eager_grads = [param.grad for param in layer_tensors]
    out.realize(*eager_grads)
    ref_grads = [g.numpy() for g in eager_grads]

    # compare
    for got, want in zip(out_grads, ref_grads):
      np.testing.assert_allclose(got, want, atol=1e-6)
if __name__ == '__main__':
unittest.main()

View file

@ -1,19 +0,0 @@
import unittest
from tinygrad import Tensor
class TestOuterCall(unittest.TestCase):
  """Tensor.call with a function that assigns into one of its parameters."""

  def test_outer_call_assign(self):
    """Calling fxn = pa.assign(pa+pb) on (a, b) must leave every element of `a` equal to 1."""
    a = Tensor.zeros(10,10).contiguous()
    b = Tensor.ones(10,10).contiguous()
    Tensor.realize(a,b)
    pa = a.as_param(0)
    pb = b.as_param(1)
    out = Tensor.call(a, b, fxn=pa.assign(pa+pb))
    out.realize()
    # unittest assertion instead of a bare assert; the leftover debug print of `a` is dropped
    self.assertTrue((a == 1).all().item())
if __name__ == '__main__':
unittest.main()

View file

@ -1,148 +0,0 @@
import unittest
from tinygrad import Tensor, nn, Variable, UOp
# outerworld range should support three things
# 1. full optimizer steps (test_model_bound_range)
# 2. gradient accumulation (you want to end the range before running the optimizer)
# 3. stacked linear layers
class Model:
  """A single bias-free linear layer taking 64 features to 8."""
  def __init__(self):
    self.w = nn.Linear(64, 8, bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    return self.w(x)
def get_model_and_opt():
  """Seed the RNG with 1337, then return a fresh Model and an SGD optimizer (lr=0.1, weight_decay=0) over its parameters."""
  Tensor.manual_seed(1337)
  model = Model()
  params = nn.state.get_parameters(model)
  return model, nn.optim.SGD(params, lr=0.1, weight_decay=0)
class TestOuterworldRange(unittest.TestCase):
  """Trains the same seeded model in several ways and compares per-step losses to an eager baseline.

  setUpClass builds the data and records the baseline losses in `cls.losses`;
  each test collects its own losses and hands them to `_compare`.
  """
  STEPS = 5
  BS = 20

  @classmethod
  def setUpClass(cls):
    Tensor.manual_seed(1338)
    # it learns to compute mean
    cls.X = Tensor.randn(cls.STEPS, cls.BS, 64).contiguous().realize()
    cls.Y = cls.X.reshape(cls.STEPS, cls.BS, 8, 8).mean(axis=-1).contiguous().realize()
    cls.losses = cls._get_model_baseline()

  def _compare(self, losses):
    """Assert that `losses` matches the baseline losses step by step, to 5 decimal places."""
    for i,(x,y) in enumerate(zip(self.losses, losses)):
      self.assertAlmostEqual(x, y, places=5, msg=f"mismatch at {i} in {self.losses} vs {losses}")

  @classmethod
  @Tensor.train()
  def _get_model_baseline(cls):
    """Run the plain per-step training loop and return the list of loss values."""
    m, opt = get_model_and_opt()
    losses = []
    for i in range(cls.STEPS):
      opt.zero_grad()
      loss = (m(cls.X[i]) - cls.Y[i]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    return losses

  @Tensor.train()
  def test_model_grad_acc(self):
    """Gradient accumulation over two half-batches per step."""
    m, opt = get_model_and_opt()
    losses = []
    # both values are constant across steps, so compute them once
    sub_batch_size = self.BS//2
    scaling_factor = self.BS//sub_batch_size
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = 0
      for j in range(0, self.BS, sub_batch_size):
        sub_loss = (m(self.X[i][j:j+sub_batch_size]) - self.Y[i][j:j+sub_batch_size]).square().mean() / scaling_factor
        sub_loss.backward()
        loss += sub_loss
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)

  @Tensor.train()
  def test_model_variable(self):
    """Same loop as the baseline, indexing the data with a bound Variable instead of an int."""
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)

  @Tensor.train()
  def test_model_scheduled(self):
    """Keep the loss tensors and only read them out after the loop, via one stack."""
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())

  @Tensor.train()
  def test_model_scheduled_setitem(self):
    """Like test_model_scheduled, but each loss is written into a preallocated tensor."""
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      # TODO: this shouldn't realize
      losses[i] = loss.requires_grad_(False)
    self._compare(losses.tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable(self):
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable_setitem(self):
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses[vib] = loss.requires_grad_(False)
    self._compare(losses.tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_bound_range(self):
    m, opt = get_model_and_opt()
    # TODO: should ranges be unique so you don't have to pass in the -1?
    rng = UOp.range(self.STEPS, -1)
    vib = Variable('i', 0, self.STEPS-1).bind(rng)
    loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
    loss.backward()
    losses = Tensor.empty(self.STEPS)
    losses[vib] = loss
    losses.realize(*opt.schedule_step())
if __name__ == "__main__":
unittest.main()

View file

@ -1,4 +1,4 @@
import unittest, struct, contextlib, statistics, time, gc
import unittest, struct, contextlib, statistics, gc
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
@ -20,7 +20,7 @@ def helper_collect_profile(*devs):
cpu_events.clear()
profile_list = []
with Context(VIZ=1, PROFILE=1):
with Context(PROFILE=1):
yield profile_list
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()
@ -170,30 +170,19 @@ class TestProfiler(unittest.TestCase):
for (i1, d1), (i2, d2) in pairs:
assert abs(jitter_matrix[i1][i2]) < 0.5, "jitter should be less than 0.5us"
@unittest.skip("this test is flaky")
def test_cpu_profile(self):
def test_fxn(err=False):
time.sleep(0.1)
if err: raise Exception()
time.sleep(0.1)
with helper_collect_profile(dev:=TestProfiler.d0) as profile:
with cpu_profile("test_1", dev.device):
with cpu_profile("test_1", dev):
test_fxn(err=False)
with self.assertRaises(Exception):
with cpu_profile("test_2", dev.device):
with cpu_profile("test_2", dev):
test_fxn(err=True)
range_events = [p for p in profile if isinstance(p, ProfileRangeEvent)]
range_events = [p for p in profile if isinstance(p, ProfileRangeEvent) and p.device == dev]
self.assertEqual(len(range_events), 2)
# record start/end time up to exit (error or success)
for e in range_events:
self.assertGreater(e.en, e.st)
e1, e2 = range_events
self.assertEqual([e1.name, e2.name], ["test_1", "test_2"])
# TODO: this is flaky
#self.assertLess(e1.st, e2.st)
#self.assertGreater(e1.en-e1.st, e2.en-e2.st)
@unittest.skip("this test is flaky")
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")

View file

@ -2,7 +2,7 @@
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
import unittest, functools
import gc, unittest, functools
import numpy as np
from typing import cast
from hypothesis import assume, given, settings, strategies as strat
@ -168,13 +168,13 @@ class TestSchedule(unittest.TestCase):
a = Tensor.full((4,), 4.0).contiguous().realize()
b = Tensor.full((4,), 2.0).contiguous().realize()
expr = (a*b)/b
check_schedule(expr, 0)
run_schedule(check_schedule(expr, 1))
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
def test_div_collapse_const(self):
a = Tensor.full((4,), 4.0).contiguous().realize()
expr = a/a
check_schedule(expr, 0)
run_schedule(check_schedule(expr, 1))
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
def test_div_collapse(self):
@ -747,7 +747,7 @@ class TestSchedule(unittest.TestCase):
p = P[0]
p = p.pad(((1, 0), ))
p = p.repeat([2])
run_schedule(check_schedule(p, 3))
run_schedule(check_schedule(p, 4)) # TODO: this is high
tiny_ret = p.numpy()
P = np.ones((3, 3), dtype=np.float32)
@ -775,11 +775,12 @@ class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
def test_conv2d_fused_half(self): _test_conv2d(4, dtype=dtypes.half)
@unittest.skip("TODO: this is consistently creating non reproducible failures")
def test_schedule_mem_used_with_inputs(self):
gc.collect()
base = GlobalCounters.mem_used
x = Tensor.ones(256).contiguous().realize()
(x+Tensor.ones(256).contiguous()).schedule()
gc.collect()
self.assertEqual(GlobalCounters.mem_used-base, 1024)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@ -840,10 +841,9 @@ class TestSchedule(unittest.TestCase):
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
run_schedule(check_schedule(casted_view, 0))
self.assertIsNone(casted_view.uop.base.realized)
run_schedule(check_schedule(casted_view, 1))
realized_const_view = casted_view.contiguous()
run_schedule(check_schedule(realized_const_view, 1))
run_schedule(check_schedule(realized_const_view, 0))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
@ -1036,7 +1036,7 @@ class TestSchedule(unittest.TestCase):
idx = Tensor([1,2,5,6], dtype=dtypes.int32)
flat_base[idx] = Tensor([99,99,99,99])
base.assign(flat_base.reshape(4, 4))
sched = check_schedule(base, 2)
sched = check_schedule(base, 6) # TODO: this is high
run_schedule(sched)
expected = list(range(16))
for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
@ -1235,11 +1235,11 @@ class TestView(unittest.TestCase):
bv = b.pad(((0, 2),))[-2:]
# this becomes a late a*0
late_mul = a*bv
check_schedule(late_mul, 0)
run_schedule(check_schedule(late_mul, 2))
# the arange doesn't realize
self.assertIsNone(b.uop.base.realized)
#self.assertIsNone(b.uop.base.realized)
# mul doesn't realize
self.assertIsNone(late_mul.uop.base.realized)
#self.assertIsNone(late_mul.uop.base.realized)
self.assertEqual(late_mul.tolist(), [0, 0])
# SINK has two branches:
@ -1252,20 +1252,21 @@ class TestView(unittest.TestCase):
bv = b.pad(((0, 2),))[-2:]
late_mul = a*bv
other_child = b+2
s = check_schedule([late_mul, other_child], 2)
s = check_schedule([late_mul, other_child], 3)
# the arange becomes a BUFFER
self.assertIs(b.uop.base.op, Ops.BUFFER)
# NOTE: no longer checked
# mul still collapses
self.assertIs(late_mul.uop.base.op, Ops.CONST)
#self.assertIs(late_mul.uop.base.op, Ops.CONST)
run_schedule(s)
self.assertEqual(other_child.tolist(), [2, 3, 4])
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
class TestCopyFolding(unittest.TestCase):
def test_const_copy_is_free(self):
b = Tensor(1).to("CPU")
check_schedule(b, 0, filter_sink=False)
assert b.item() == 1
b = Tensor(1).to("CPU") * 4
run_schedule(check_schedule(b, 1, filter_sink=False))
assert b.item() == 4
def test_one_hot_with_copy(self):
y = Tensor([1, 2, 3]).to("CPU")
@ -1273,16 +1274,16 @@ class TestCopyFolding(unittest.TestCase):
check_schedule(x, 3, filter_sink=False)
def test_const_copy_multi(self):
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"])
check_schedule(x, 0, filter_sink=False)
self.assertEqual(x.item(), 1)
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) * 2
run_schedule(check_schedule(x, 2, filter_sink=False))
self.assertEqual(x.item(), 2.0)
def test_late_const_copy_folding(self):
a = Tensor.arange(3).realize()
zeros = Tensor.zeros(3).realize()
b = (a*zeros).to("CPU")
run_schedule(check_schedule(b, 0, filter_sink=False))
self.assertListEqual(b.tolist(), [0, 0, 0])
b = (a*zeros).to("CPU") + 1
run_schedule(check_schedule(b, 1, filter_sink=False))
self.assertListEqual(b.tolist(), [1, 1, 1])
self.assertEqual(b.device, "CPU")
def test_alu_after_copy(self):
@ -1321,7 +1322,7 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.ones(4, 4).contiguous().realize()
# use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
run_schedule(check_schedule(a, 0, filter_sink=False))
run_schedule(check_schedule(a, 2, filter_sink=False))
self.assertListEqual(a.tolist(), [[1.]*4]*4)
def test_clone(self):

View file

@ -113,6 +113,12 @@ class TestFloatUOps(TestUOps):
def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: a!=b)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support NaN comparison correctly")
def test_cmpne_nan(self): # NaN != x for any x (IEEE 754), fixes #14095
for a, b in [(math.nan, 1.0), (1.0, math.nan), (math.nan, math.nan)]:
self.assertTrue(_test_single_value(
[dtypes.as_const(a, dtypes.float32), dtypes.as_const(b, dtypes.float32)],
Ops.CMPNE, (dtypes.float32, dtypes.float32)))
# MOD isn't tested on floats
def test_where(self):

View file

@ -76,7 +76,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "Can't handle async update on CPU device")
@unittest.skipIf(Device.DEFAULT in {"CPU"} or getenv("AMD_IFACE", "") == "PCI", "Can't handle async update on CPU/MOCKAM device")
def test_wait_late_set(self):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
@ -538,7 +538,7 @@ class TestHCQ(unittest.TestCase):
np.testing.assert_equal(cpu_buffer.numpy(), local_buf.numpy(), "failed")
@unittest.skipUnless(MOCKGPU, "Emulate this on MOCKGPU to check the path in CI")
@unittest.skipUnless(MOCKGPU and getenv("AMD_IFACE", "") != "PCI", "Emulate this on MOCKGPU to check the path in CI")
def test_on_device_hang(self):
if not hasattr(self.d0, 'on_device_hang'): self.skipTest("device does not have on_device_hang")

View file

@ -56,10 +56,12 @@ class TestOnnxRunner(unittest.TestCase):
output = runner({'inp': Tensor([1, 2, 3, 4])})['output']
_check_ast_count(0, output)
@unittest.skip("const folding is removed")
def test_const_fold_from_disk(self):
self._test_const_fold_unary_op(True)
self._test_const_fold_binary_op(True)
@unittest.skip("const folding is removed")
def test_const_fold_from_memory(self):
self._test_const_fold_unary_op(False)
# TODO: understand this and fix this, bitcast related

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, functools, base64, codecs
from dataclasses import replace
from typing import Callable, Any
@ -8,7 +8,6 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
try:
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.engine.realize import get_program
from tinygrad.uop.ops import UOp, Ops, KernelInfo
@ -43,14 +42,6 @@ class ProcessReplayWarning(Warning): pass
# *** replay the function and convert return values to string
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
def to_str(ret:UOp) -> str:
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
# the ast.arg is non None if we are inside of search.py
sink_arg = ast.arg or KernelInfo()
@ -68,8 +59,6 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {}
replayers["get_program"] = replay_get_program
# disable this for speed, does it ever find things?
#replayers["get_rangeify_map"] = replay_get_rangeify_map
# *** run replayers on captured rows and print diffs

View file

122
test/mockgpu/am/amdriver.py Normal file
View file

@ -0,0 +1,122 @@
import ctypes, ctypes.util, mmap, functools
from test.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile
from test.mockgpu.am.amgpu import MockAMGPU, VRAM_SIZE
# Sizes in bytes used for the doorbell and BAR5 entries of the emulated sysfs `resource` file below.
DOORBELL_SIZE = 0x2000
BAR5_SIZE = (512 << 20)
# Name of the single mock device listed under /sys/bus/pci/devices.
PCIBUS = "mock:am:0"
# Raw libc handle with an explicit mmap prototype: the address argument and the return value are void pointers.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
libc.mmap.restype = ctypes.c_void_p
# Text lines of the emulated `resource` file, three hex fields per line (presumably start, end, flags -- TODO confirm).
# Entries 0, 2 and 5 are sized from VRAM_SIZE, DOORBELL_SIZE and BAR5_SIZE; every other entry is all zeros.
_empty_bar = "0x0000000000000000 0x0000000000000000 0x0000000000000000"
_resource_lines = [
f"0x0000000000000000 0x{VRAM_SIZE-1:016x} 0x0000000000000000", _empty_bar,
f"0x0000000000000000 0x{DOORBELL_SIZE-1:016x} 0x0000000000000000", _empty_bar, _empty_bar,
f"0x0000000000000000 0x{BAR5_SIZE-1:016x} 0x0000000000000000", _empty_bar,
]
class PagemapFileDesc(VirtFileDesc):
  """Emulated /proc/self/pagemap: every 8-byte entry read hands out a fresh fake physical page."""
  def __init__(self, fd, gpu):
    super().__init__(fd)
    self.gpu = gpu

  def seek(self, offset):
    # the offset is stored as an absolute position
    self.off = offset

  def read_contents(self, size=None):
    """Return `size // 8` little-endian entries (one entry when size is None or 0), starting at the current offset.

    For each entry the next free system-memory physical page of the GPU model is taken and recorded in
    `gpu._sysmem_map` as physical page -> virtual address. Bit 63 is set in every returned entry and the
    low bits carry the physical page frame number.
    """
    first_page = self.off // 8
    count = (size or 8) // 8
    chunks = []
    for page in range(first_page, first_page + count):
      phys = self.gpu._next_sysmem_paddr
      self.gpu._next_sysmem_paddr += 0x1000
      self.gpu._sysmem_map[phys] = page * 0x1000
      chunks.append(((1 << 63) | (phys // 0x1000)).to_bytes(8, 'little'))
    payload = b"".join(chunks)
    self.off += len(payload)
    return payload
class PCIBarFileDesc(VirtFileDesc):
  """Emulated PCI BAR resource file whose mappings are backed by a memfd."""
  def __init__(self, fd, memfd, driver=None):
    super().__init__(fd)
    self.memfd = memfd
    self.driver = driver

  def mmap(self, start, sz, prot, flags, fd, off):
    """Map the backing memfd (not the `fd` argument) and return the mapped address.

    When a driver was supplied, the mapped range is registered with it: the first callback does nothing,
    the second one runs the driver's queue emulation.
    """
    mapped = libc.mmap(start, sz, prot, flags, self.memfd, off)
    if self.driver is None: return mapped
    def _first_cb(mv, idx): return None
    def _second_cb(mv, idx): return self.driver._emulate_execute()
    self.driver.track_address(mapped, mapped + sz, _first_cb, _second_cb)
    return mapped
class PCIMMIOBarFileDesc(VirtFileDesc):
  """Emulated BAR5 resource file: mmap returns addresses inside a region that already exists."""
  def __init__(self, fd, bar5_addr):
    super().__init__(fd)
    self.bar5_addr = bar5_addr

  def mmap(self, start, sz, prot, flags, fd, off):
    # nothing is mapped here; only `off` is used, every other argument is ignored
    base = self.bar5_addr
    return base + off
class PCIConfigFileDesc(VirtFileDesc):
  """Emulated PCI `config` file backed by a 256-byte zero-initialised buffer."""
  def __init__(self, fd):
    super().__init__(fd)
    self.data = bytearray(256)

  def read_contents(self, size=None):
    # a missing (or zero) size means "everything from the current offset to the end of the buffer"
    begin = self.off
    length = size or len(self.data) - begin
    return bytes(self.data[begin:begin + length])

  def write_contents(self, content):
    begin = self.off
    self.data[begin:begin + len(content)] = content

  def seek(self, offset):
    self.off = offset
class PCIEnableFileDesc(VirtFileDesc):
  """Emulated PCI `enable` file: always reads back as enabled and discards writes."""
  def __init__(self, fd):
    super().__init__(fd)

  def read_contents(self, size=None):
    return "1\n"

  def write_contents(self, content):
    # writes are accepted and ignored
    pass
class AMDriver(VirtDriver):
  """Virtual driver exposing one MockAMGPU through emulated procfs/sysfs PCI files."""
  def __init__(self):
    super().__init__()
    self.gpus:dict[int, MockAMGPU] = {}
    # re-entrancy guard for _emulate_execute
    self._executing = False
    self.gpu = self.gpus[0] = MockAMGPU(0)
    self.next_fd = 1 << 30

    # anonymous shared region standing in for BAR5; accesses are synced with the GPU's register model
    prot, flags = mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | mmap.MAP_ANONYMOUS
    self._bar5_addr = libc.mmap(0, BAR5_SIZE, prot, flags, -1, 0)
    mmio = self.gpu.mmio
    def _on_read(mv, idx): return _bar5_sync_read(mv, idx, mmio)
    def _on_write(mv, idx): return _bar5_sync_write(mv, idx, mmio)
    self.track_address(self._bar5_addr, self._bar5_addr + BAR5_SIZE, _on_read, _on_write)

    def _text(path, text): return VirtFile(path, functools.partial(TextFileDesc, text=text))
    p = f"/sys/bus/pci/devices/{PCIBUS}"
    self.tracked_files += [
      _text("/proc/sys/vm/compact_unevictable_allowed", "0\n"),
      VirtFile("/proc/self/pagemap", functools.partial(PagemapFileDesc, gpu=self.gpu)),
      VirtFile("/sys/bus/pci/devices", functools.partial(DirFileDesc, child_names=[PCIBUS])),
      _text(f"{p}/vendor", "0x1002\n"),
      _text(f"{p}/device", "0x74a1\n"),
      VirtFile(f"{p}/enable", PCIEnableFileDesc),
      VirtFile(f"{p}/config", PCIConfigFileDesc),
      _text(f"{p}/resource", "\n".join(_resource_lines) + "\n"),
      VirtFile(f"{p}/resource0", functools.partial(PCIBarFileDesc, memfd=self.gpu.vram_fd)),
      VirtFile(f"{p}/resource2", functools.partial(PCIBarFileDesc, memfd=self.gpu.doorbell_fd, driver=self)),
      VirtFile(f"{p}/resource5", functools.partial(PCIMMIOBarFileDesc, bar5_addr=self._bar5_addr)),
    ]

  def _alloc_fd(self):
    # hand out consecutive descriptor numbers starting at 1 << 30
    fd, self.next_fd = self.next_fd, self.next_fd + 1
    return fd

  def open(self, name, flags, mode, virtfile):
    return virtfile.fdcls(self._alloc_fd())

  def _emulate_execute(self):
    """Run every executing queue of every GPU until a full pass makes no progress; nested calls return at once."""
    if self._executing: return
    self._executing = True
    try:
      progressed = True
      while progressed:
        progressed = False
        for gpu in self.gpus.values():
          for q in gpu.queues:
            if q.executing and q.execute() > 0: progressed = True
    finally:
      self._executing = False
def _bar5_sync_read(mv, idx, mmio):
if isinstance(idx, slice):
for i in range(idx.start or 0, idx.stop or len(mv), idx.step or 1): mv[i] = mmio[i]
else: mv[idx] = mmio[idx]
def _bar5_sync_write(mv, idx, mmio):
if isinstance(idx, slice):
for i in range(idx.start or 0, idx.stop or len(mv), idx.step or 1): mmio[i] = mv[i]
else: mmio[idx] = mv[idx]

309
test/mockgpu/am/amgpu.py Normal file
View file

@ -0,0 +1,309 @@
# mypy: ignore-errors
import ctypes, ctypes.util, struct, functools, os, mmap
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.amd import AMDReg, import_asic_regs
from test.mockgpu.amd.amdgpu import AMDGPU
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
libc.mmap.restype = ctypes.c_void_p
# Size in bytes of the emulated VRAM (512 MiB).
VRAM_SIZE = 512 << 20
# Hardware IP id -> (major, minor, revision) written into the IP discovery table by MockAMGPU._preboot.
IP_VERSIONS = {
am.GC_HWIP: (12, 0, 0), am.SDMA0_HWIP: (7, 0, 0), am.MMHUB_HWIP: (4, 1, 0), am.NBIO_HWIP: (6, 3, 1),
am.MP0_HWIP: (14, 0, 2), am.MP1_HWIP: (14, 0, 2), am.HDP_HWIP: (7, 0, 0), am.OSSSYS_HWIP: (7, 0, 0),
}
def _pad(t, n=10): return t + (0,) * (n - len(t))
# Hardware IP id -> tuple of register base addresses, zero-padded to 10 entries by _pad.
# Used both as the `bases` of the register definitions and as the base-address list in the IP discovery table.
IP_BASES = {
am.GC_HWIP: _pad((0x00001260, 0x0000A000, 0x0001C000, 0x02402C00)),
am.SDMA0_HWIP: _pad((0x00001260, 0x0000A000, 0x0001C000, 0x02402C00)),
am.MMHUB_HWIP: _pad((0x0001A000, 0x02408800)),
am.NBIO_HWIP: _pad((0x00000000, 0x00000014, 0x00000D20, 0x00010400, 0x0241B000, 0x04040000)),
am.MP0_HWIP: _pad((0x00016000, 0x00DC0000, 0x00E00000, 0x00E40000, 0x0243FC00)),
am.MP1_HWIP: _pad((0x00016000, 0x00DC0000, 0x00E00000, 0x00E40000, 0x0243FC00)),
am.HDP_HWIP: _pad((0x00000F20, 0x0240A400)),
am.OSSSYS_HWIP: _pad((0x000010A0, 0x0240A000)),
}
# Hardware IP id -> hw_id value stored in the discovery table, for exactly the IPs listed in IP_VERSIONS.
IP_HWIDS = {hwip: am.hw_id_map[hwip] for hwip in IP_VERSIONS}
# Field values copied by name into the gc_info table built by MockAMGPU._preboot.
GC_INFO = dict(gc_num_se=2, gc_num_cu_per_sh=8, gc_num_sh_per_se=2, gc_num_rb_per_se=4,
gc_num_tccs=8, gc_wave_size=32, gc_max_waves_per_simd=16, gc_max_scratch_slots_per_cu=32, gc_lds_size=64)
def _build_ip_regs(prefix, hwip) -> dict[str, AMDReg]:
  """Load the register table for `prefix` at the mocked version of `hwip`, based at that IP's base addresses.

  Best effort: any failure while building or importing the table yields an empty dict.
  """
  try:
    reg_cls = functools.partial(AMDReg, bases={0: IP_BASES[hwip]})
    version = IP_VERSIONS[hwip]
    return import_asic_regs(prefix, version, cls=reg_cls)
  except Exception:
    return {}
class MockMMU:
  """Software model of GPU address translation: caches the translations found by walking page tables stored in mock VRAM."""
  def __init__(self, gpu:'MockAMGPU'):
    self.gpu = gpu
    # virtual address -> (physical address, mapping size in bytes, True when the entry has AMDGPU_PTE_SYSTEM set)
    self.tlb: dict[int, tuple[int, int, bool]] = {}
  def invalidate(self, pt_base:int, va_base:int):
    """Re-walk the page tables rooted at VRAM offset `pt_base` and replace the cached translations.

    Virtual addresses that were not cached before are announced to the GPU model: non-system entries are
    first mapped onto VRAM at that address, then the range is passed to gpu.map_range.
    """
    new_tlb: dict[int, tuple[int, int, bool]] = {}
    self._walk(pt_base, 0, 0, new_tlb, va_base)
    for va, (pa, sz, is_sys) in new_tlb.items():
      if va not in self.tlb:
        if not is_sys: self.gpu.map_vram_at(va, pa, sz)
        self.gpu.map_range(va, sz)
    self.tlb = new_tlb
  def _walk(self, pt_paddr:int, level:int, va_acc:int, out:dict, va_base:int):
    """Recursively scan one 512-entry table at VRAM offset `pt_paddr`, adding leaf mappings to `out`.

    `level` runs from 0 to 3 and selects the bit position of this table's index within the virtual address;
    `va_acc` holds the address bits accumulated from the parent levels.
    """
    shift = [39, 30, 21, 12][level]
    for i in range(512):
      # each entry is a 64-bit little-endian value
      pte = struct.unpack_from('<Q', self.gpu.vram, pt_paddr + i * 8)[0]
      if not (pte & am.AMDGPU_PTE_VALID): continue
      va, pa = va_acc | (i << shift), pte & 0x0000FFFFFFFFF000
      # the last level, or an entry flagged AMDGPU_PDE_PTE_GFX12, is a leaf covering 1 << shift bytes
      if level == 3 or (pte & am.AMDGPU_PDE_PTE_GFX12):
        out[va_base + va] = (pa, 1 << shift, bool(pte & am.AMDGPU_PTE_SYSTEM))
      else:
        self._walk(pa, level + 1, va, out, va_base)
  def paddr_to_host(self, paddr:int) -> int:
    """Translate a physical address to a host address: VRAM offsets map into the VRAM mapping, anything
    else is looked up page-wise in gpu._sysmem_map (KeyError when the page is unknown)."""
    if paddr < VRAM_SIZE: return self.gpu.vram_addr + paddr
    page, off = paddr & ~0xFFF, paddr & 0xFFF
    return self.gpu._sysmem_map[page] + off
  def addr_to_host(self, addr:int) -> int:
    """Translate a GPU address to a host address; raises ValueError when nothing maps it."""
    gmc = self.gpu.mmio.gmc
    # system aperture bounds as programmed in the registers; the register values are shifted left by 18 bits
    sys_lo = self.gpu.mmio.regs.get(gmc.reg('regMMMC_VM_SYSTEM_APERTURE_LOW_ADDR') or 0, 0) << 18
    sys_hi = self.gpu.mmio.regs.get(gmc.reg('regMMMC_VM_SYSTEM_APERTURE_HIGH_ADDR') or 0, 0) << 18
    # inside the aperture the address is treated as mc_base + physical address
    if sys_lo <= addr < sys_hi: return self.paddr_to_host(addr - self.gpu.mc_base)
    for tva, (pa, sz, is_sys) in self.tlb.items():
      if tva <= addr < tva + sz:
        # non-system mappings were mapped at their virtual address by invalidate, so the address is returned unchanged
        if not is_sys: return addr
        return self.paddr_to_host(pa + (addr - tva))
    raise ValueError(f"addr {addr:#x} not mapped (sys_aperture=[{sys_lo:#x}, {sys_hi:#x}])")
class MockIPBlock:
  """Base class for a mocked IP block: owns a named register table and proxies reads/writes to the shared register file."""
  def __init__(self, gpu:'MockAMGPU', mmio:'MockMMIOInterface', regs:dict[str, AMDReg]):
    self.gpu = gpu
    self.mmio = mmio
    self._regs = regs
    # name -> address and address -> name lookups, built from the first address of every register
    self._n2a, self._a2n = {}, {}
    for name, r in regs.items():
      self._n2a[name] = r.addr[0]
    for name, addr in self._n2a.items():
      self._a2n[addr] = name
    self.addrs = set(self._n2a.values())

  def reg(self, name) -> int|None:
    """Address of register `name`, or None when this block has no such register."""
    return self._n2a.get(name)

  def decode(self, name) -> dict:
    """Decode the current stored value of register `name` (0 when never written) into its fields."""
    definition = self._regs[name]
    current = self.mmio.regs.get(self._n2a[name], 0)
    return definition.decode(current)

  def read(self, reg:int) -> int:
    return self.mmio.regs.get(reg, 0)

  def write(self, reg:int, val:int):
    self.mmio.regs[reg] = val

  def _read_pair(self, pair) -> int:
    """Combine a (low, high) pair of 32-bit registers into one value; 0 when the low register is unknown."""
    if pair[0] is None: return 0
    low = self.mmio.regs.get(pair[0], 0)
    high = self.mmio.regs.get(pair[1], 0)
    return low | (high << 32)
class MockPSP(MockIPBlock):
  """Mock of the PSP block: answers reads of selected C2PMSG mailbox registers and acknowledges ring submissions."""
  def __init__(self, gpu, mmio):
    super().__init__(gpu, mmio, _build_ip_regs('mp', am.MP0_HWIP))
    # _sos_alive becomes True once PSP_BL__LOAD_SOSDRV is written to C2PMSG_35; _ring_wptr is the last value written to C2PMSG_67
    self._sos_alive, self._ring_wptr = False, 0
    # the register name prefix depends on the mocked MP0 version
    pref = "regMPASP_SMN_C2PMSG" if IP_VERSIONS[am.MP0_HWIP] >= (14,0,0) else "regMP0_SMN_C2PMSG"
    def r(n): return self.reg(f"{pref}_{n}")
    # register addresses; each is None when the register table has no such name
    self._c2pmsg_35, self._c2pmsg_64, self._c2pmsg_67 = r(35), r(64), r(67)
    self._c2pmsg_69, self._c2pmsg_70, self._c2pmsg_81 = r(69), r(70), r(81)
  def read(self, reg:int) -> int:
    # C2PMSG_35 always reads 0x80000000; C2PMSG_81 and C2PMSG_64 read non-zero only after the SOS driver load was requested
    if reg == self._c2pmsg_35: return 0x80000000
    if reg == self._c2pmsg_81: return 0x1 if self._sos_alive else 0x0
    if reg == self._c2pmsg_64: return 0x80000000 if self._sos_alive else 0x0
    if reg == self._c2pmsg_67: return self._ring_wptr
    return super().read(reg)
  def write(self, reg:int, val:int):
    super().write(reg, val)
    if reg == self._c2pmsg_35 and val == am.PSP_BL__LOAD_SOSDRV: self._sos_alive = True
    if reg == self._c2pmsg_67: self._ring_submit(val)
  def _ring_submit(self, new_wptr:int):
    """Handle a write of the ring write pointer: complete the frame located at the previous write pointer."""
    old_wptr = self._ring_wptr
    self._ring_wptr = new_wptr
    lo, hi = self._c2pmsg_69, self._c2pmsg_70
    if lo is None or hi is None: return
    # ring address comes from the C2PMSG_69 (low) / C2PMSG_70 (high) pair and is converted to a VRAM offset via mc_base
    ring_mc = self.mmio.regs.get(lo, 0) | (self.mmio.regs.get(hi, 0) << 32)
    ring_paddr = ring_mc - self.gpu.mc_base
    # the write pointer is scaled by 4 to get a byte offset into the ring
    frame_off = ring_paddr + old_wptr * 4
    frame = am.struct_psp_gfx_rb_frame.from_buffer_copy(bytes(self.gpu.vram[frame_off:frame_off + ctypes.sizeof(am.struct_psp_gfx_rb_frame)]))
    # signal completion by storing the frame's fence value at its fence address (only when that lies inside VRAM)
    fence_paddr = ((frame.fence_addr_hi << 32) | frame.fence_addr_lo) - self.gpu.mc_base
    if 0 <= fence_paddr < len(self.gpu.vram):
      struct.pack_into('<I', self.gpu.vram, fence_paddr, frame.fence_value)
    cmd_paddr = ((frame.cmd_buf_addr_hi << 32) | frame.cmd_buf_addr_lo) - self.gpu.mc_base
    if 0 <= cmd_paddr < len(self.gpu.vram):
      # NOTE(review): offset 864 is presumably the response status field of the command buffer -- confirm against the struct layout
      struct.pack_into('<I', self.gpu.vram, cmd_paddr + 864, 0)
class MockSMU(MockIPBlock):
  """Mock of the SMU block: tracks whether a message is pending and answers the response registers accordingly."""
  def __init__(self, gpu, mmio):
    try:
      regs = import_asic_regs('mp', (11, 0), cls=functools.partial(AMDReg, bases={0: IP_BASES[am.MP1_HWIP]}))
    except Exception:
      regs = {}
    super().__init__(gpu, mmio, regs)
    self._msg_pending = False
    # register addresses; each is None when the register table has no such name
    (self._c2pmsg_53, self._c2pmsg_54, self._c2pmsg_66,
     self._c2pmsg_75, self._c2pmsg_82, self._c2pmsg_90) = [self.reg(f"mmMP1_SMN_C2PMSG_{n}") for n in (53, 54, 66, 75, 82, 90)]

  def read(self, reg:int) -> int:
    if reg in (self._c2pmsg_90, self._c2pmsg_54):
      # a pending message reads back as 1, otherwise the stored value is returned
      if self._msg_pending: return 0x1
      return super().read(reg)
    if reg == self._c2pmsg_82:
      # defaults to 3 until something is written
      return self.mmio.regs.get(reg, 3)
    return super().read(reg)

  def write(self, reg:int, val:int):
    super().write(reg, val)
    if reg in (self._c2pmsg_66, self._c2pmsg_75):
      self._msg_pending = True
    if reg in (self._c2pmsg_90, self._c2pmsg_54) and val == 0:
      self._msg_pending = False
class MockSDMA(MockIPBlock):
  """Mock of the SDMA block: registers an SDMA queue with the GPU model when a ring buffer gets enabled."""
  def __init__(self, gpu, mmio):
    gc_regs = _build_ip_regs('gc', am.GC_HWIP)
    # only the registers whose name mentions SDMA belong to this block
    sdma_regs = {name: r for name, r in gc_regs.items() if 'SDMA' in name}
    super().__init__(gpu, mmio, sdma_regs)

  def write(self, reg:int, val:int):
    super().write(reg, val)
    name = self._a2n.get(reg, '')
    if not name.endswith('_RB_CNTL'): return
    if self._regs[name].decode(val).get('rb_enable', 0):
      # everything before the trailing _RB_CNTL identifies the queue
      self._activate_queue(name[:-len('_RB_CNTL')])

  def _activate_queue(self, prefix:str):
    """Read ring base, read-pointer and write-pointer addresses for `prefix` and hand the queue to the GPU model."""
    ring_addr = self._read_pair((self.reg(f'{prefix}_RB_BASE'), self.reg(f'{prefix}_RB_BASE_HI'))) << 8
    rptr_addr = self._read_pair((self.reg(f'{prefix}_RB_RPTR_ADDR_LO'), self.reg(f'{prefix}_RB_RPTR_ADDR_HI')))
    wptr_addr = self._read_pair((self.reg(f'{prefix}_RB_WPTR_POLL_ADDR_LO'), self.reg(f'{prefix}_RB_WPTR_POLL_ADDR_HI')))
    ring_bytes = 4 << self.decode(f'{prefix}_RB_CNTL')['rb_size']
    to_host = self.gpu.mmu.addr_to_host
    self.gpu.add_sdma_queue(to_host(ring_addr), ring_bytes, to_host(rptr_addr), to_host(wptr_addr))
class MockGFX(MockIPBlock):
  """Mock of the GFX block: fixed status reads, PM4 queue activation and page-table invalidation requests."""
  def __init__(self, gpu, mmio):
    super().__init__(gpu, mmio, _build_ip_regs('gc', am.GC_HWIP))
    # (low, high) register pairs for the context-0 page table base and start address
    self._pt_base = (self.reg('regGCVM_CONTEXT0_PAGE_TABLE_BASE_ADDR_LO32'), self.reg('regGCVM_CONTEXT0_PAGE_TABLE_BASE_ADDR_HI32'))
    self._pt_start = (self.reg('regGCVM_CONTEXT0_PAGE_TABLE_START_ADDR_LO32'), self.reg('regGCVM_CONTEXT0_PAGE_TABLE_START_ADDR_HI32'))
    self._gc_inv_ack = self.reg('regGCVM_INVALIDATE_ENG17_ACK')
    self._gc_inv_req = self.reg('regGCVM_INVALIDATE_ENG17_REQ')
    self._hqd_active = self.reg('regCP_HQD_ACTIVE')

  def read(self, reg:int) -> int:
    if reg in (self.reg('regCP_STAT'), self.reg('regRLC_SAFE_MODE')): return 0
    if reg == self.reg('regRLC_RLCS_BOOTLOAD_STATUS'): return 0x2
    # invalidations are always reported as acknowledged
    if reg == self._gc_inv_ack: return 0x1
    return super().read(reg)

  def write(self, reg:int, val:int):
    super().write(reg, val)
    # a dequeue request clears the stored HQD_ACTIVE value
    if reg == self.reg('regCP_HQD_DEQUEUE_REQUEST') and self._hqd_active is not None:
      self.mmio.regs[self._hqd_active] = 0
    if reg == self._hqd_active and val == 1:
      self._activate_pm4_queue()
    if reg == self._gc_inv_req:
      self.gpu.mmu.invalidate(self.get_pt_base(), self.get_va_base())

  def _activate_pm4_queue(self):
    """Read the HQD ring base, read-pointer and write-pointer addresses and hand the queue to the GPU model."""
    ring_addr = self._read_pair((self.reg('regCP_HQD_PQ_BASE'), self.reg('regCP_HQD_PQ_BASE_HI'))) << 8
    rptr_addr = self._read_pair((self.reg('regCP_HQD_PQ_RPTR_REPORT_ADDR'), self.reg('regCP_HQD_PQ_RPTR_REPORT_ADDR_HI')))
    wptr_addr = self._read_pair((self.reg('regCP_HQD_PQ_WPTR_POLL_ADDR'), self.reg('regCP_HQD_PQ_WPTR_POLL_ADDR_HI')))
    ring_bytes = 4 << (self.decode('regCP_HQD_PQ_CONTROL')['queue_size'] + 1)
    to_host = self.gpu.mmu.addr_to_host
    self.gpu.add_pm4_queue(to_host(ring_addr), ring_bytes, to_host(rptr_addr), to_host(wptr_addr))

  def get_pt_base(self) -> int:
    # keep only the page-aligned address bits of the base register pair
    return self._read_pair(self._pt_base) & 0x0000FFFFFFFFF000

  def get_va_base(self) -> int:
    # the start register pair is shifted left by 12 bits to form the address
    return self._read_pair(self._pt_start) << 12
class MockGMC(MockIPBlock):
  """Mock of the memory-hub block: acknowledges invalidations and reports the top of the framebuffer."""
  def __init__(self, gpu, mmio, gfx:MockGFX):
    super().__init__(gpu, mmio, _build_ip_regs('mmhub', am.MMHUB_HWIP))
    # page-table base and start address are read from the GFX block's registers
    self._gfx = gfx
    self._inv_ack = self.reg('regMMVM_INVALIDATE_ENG17_ACK')
    self._inv_sem = self.reg('regMMVM_INVALIDATE_ENG17_SEM')
    self._inv_req = self.reg('regMMVM_INVALIDATE_ENG17_REQ')
    self._fb_loc_top = self.reg('regMMMC_VM_FB_LOCATION_TOP')

  def read(self, reg:int) -> int:
    # both the acknowledge and the semaphore register always read 1
    if reg in (self._inv_ack, self._inv_sem): return 0x1
    if reg == self._fb_loc_top: return VRAM_SIZE >> 24
    return super().read(reg)

  def write(self, reg:int, val:int):
    super().write(reg, val)
    if reg != self._inv_req: return
    self.gpu.mmu.invalidate(self._gfx.get_pt_base(), self._gfx.get_va_base())
class MockNBIO(MockIPBlock):
  """Mock of the NBIO block; it also owns the HDP registers."""
  def __init__(self, gpu, mmio):
    table = _build_ip_regs('nbif', am.NBIO_HWIP)
    table.update(_build_ip_regs('hdp', am.HDP_HWIP))
    super().__init__(gpu, mmio, table)
    self._remap_hdp = self.reg('regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL')
    self._hdp_flush = self.reg('regHDP_MEM_FLUSH_CNTL')

  def read(self, reg:int) -> int:
    if reg == self._remap_hdp and self._hdp_flush is not None:
      # the remap register reads back as the HDP flush register's index times 4
      return self._hdp_flush * 4
    return super().read(reg)
class MockMMIOInterface:
def __init__(self, gpu:'MockAMGPU'):
self.gpu = gpu
self.regs: dict[int, int] = {}
gfx = MockGFX(gpu, self)
self.gmc = MockGMC(gpu, self, gfx)
self.blocks = [MockPSP(gpu, self), MockSMU(gpu, self), MockSDMA(gpu, self), gfx, self.gmc, MockNBIO(gpu, self)]
self._addr_block: dict[int, MockIPBlock] = {}
for block in self.blocks:
for addr in block.addrs: self._addr_block.setdefault(addr, block)
def __getitem__(self, index:int|slice) -> int|list[int]:
if isinstance(index, slice): return [self[i] for i in range(index.start or 0, index.stop or 0, index.step or 1)] # type: ignore[misc]
if index == 0xde3: return VRAM_SIZE >> 20
if block := self._addr_block.get(index): return block.read(index)
return self.regs.get(index, 0)
def __setitem__(self, index:int|slice, val:int|list[int]|tuple[int, ...]):
if isinstance(index, slice):
vals = val if isinstance(val, (list, tuple)) else [val] * ((index.stop - index.start) // (index.step or 1)) # type: ignore[operator]
for i, v in zip(range(index.start or 0, index.stop or 0, index.step or 1), vals): self[i] = v
return
assert isinstance(val, int)
self.regs[index] = val
if block := self._addr_block.get(index): block.write(index, val)
def __len__(self): return 0x10000000
class MockAMGPU(AMDGPU):
  """Mock AMD GPU with memfd-backed VRAM and doorbell memory, a register model (mmio) and an address translator (mmu)."""
  def __init__(self, gpuid:int=0):
    super().__init__(gpuid)
    # VRAM lives in a memfd so the same pages can be mapped again elsewhere (see map_vram_at)
    self.vram_fd = os.memfd_create("vram")
    os.ftruncate(self.vram_fd, VRAM_SIZE)
    self.vram_addr = libc.mmap(0, VRAM_SIZE, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, self.vram_fd, 0)
    # byte-array view over the mapping, used for direct reads and writes of VRAM contents
    self.vram = (ctypes.c_ubyte * VRAM_SIZE).from_address(self.vram_addr)
    self.doorbell_fd = os.memfd_create("doorbell")
    os.ftruncate(self.doorbell_fd, 0x2000)
    self.arch = "rdna4"
    # fake physical page address -> host virtual address, filled in by reads of the emulated pagemap file
    self._sysmem_map:dict[int,int] = {}
    # next fake physical address to hand out for system memory
    self._next_sysmem_paddr = 0x100000000
    self.mmu = MockMMU(self)
    self.mmio = MockMMIOInterface(self)
    self._preboot()
  def map_vram_at(self, va:int, paddr:int, size:int):
    # map `size` bytes of VRAM starting at offset `paddr` at address `va`
    # NOTE(review): 0x10 is presumably MAP_FIXED on Linux -- confirm
    libc.mmap(va, size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | 0x10, self.vram_fd, paddr)
  def _preboot(self):
    """Build the discovery binary (IP discovery table plus GC info table) and store it 64 KiB below the end of VRAM."""
    # per-IP records: the ip_v4 header, one extra zero byte, then the base addresses as 32-bit little-endian values
    ip_data = bytearray()
    for hwip, (major, minor, rev) in IP_VERSIONS.items():
      ip = am.struct_ip_v4(hw_id=IP_HWIDS[hwip], num_base_address=len(IP_BASES[hwip]), major=major, minor=minor, revision=rev)
      ip_data += bytes(ip) + b'\x00'
      for b in IP_BASES[hwip]: ip_data += struct.pack('<I', b)
    dhdr = am.struct_die_header(num_ips=len(IP_VERSIONS))
    ihdr = am.struct_ip_discovery_header(signature=am.DISCOVERY_TABLE_SIGNATURE, version=4, num_dies=1)
    # offsets are relative to the start of the binary header: the discovery header follows it directly
    ip_disc_off = ctypes.sizeof(am.struct_binary_header)
    ihdr.die_info[0].die_offset = ip_disc_off + ctypes.sizeof(am.struct_ip_discovery_header)
    gc = am.struct_gc_info_v2_1()
    gc.header.table_id, gc.header.version_major, gc.header.version_minor = am.GC, 2, 1
    gc.header.size = ctypes.sizeof(am.struct_gc_info_v2_1)
    for field, val in GC_INFO.items(): setattr(gc, field, val)
    # the GC info table is placed right after the die header and the per-IP records
    gc_off = ip_disc_off + ctypes.sizeof(am.struct_ip_discovery_header) + ctypes.sizeof(am.struct_die_header) + len(ip_data)
    bhdr = am.struct_binary_header(binary_signature=am.BINARY_SIGNATURE)
    bhdr.table_list[am.IP_DISCOVERY].offset = ip_disc_off
    bhdr.table_list[am.GC].offset = gc_off
    tbl = bytes(bhdr) + bytes(ihdr) + bytes(dhdr) + ip_data + bytes(gc)
    tbl_offset = VRAM_SIZE - (64 << 10)
    self.vram[tbl_offset:tbl_offset + len(tbl)] = list(tbl)
  @property
  def mc_base(self) -> int:
    # low 24 bits of the stored FB_LOCATION_BASE register value, shifted left by 24; 0 while the register is unwritten
    fb_loc_base = self.mmio.gmc.reg('regMMMC_VM_FB_LOCATION_BASE') or 0
    return (self.mmio.regs.get(fb_loc_base, 0) & 0xFFFFFF) << 24

View file

@ -2,6 +2,7 @@ import ctypes, ctypes.util, time, os, builtins, fcntl
from tinygrad.runtime.support.hcq import FileIOInterface
from test.mockgpu.nv.nvdriver import NVDriver
from test.mockgpu.amd.amddriver import AMDDriver
from test.mockgpu.am.amdriver import AMDriver
start = time.perf_counter()
# *** ioctl lib ***
@ -9,7 +10,7 @@ libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
libc.mmap.restype = ctypes.c_void_p
drivers = [AMDDriver(), NVDriver()]
drivers = [NVDriver(), AMDriver() if os.environ.get("AMD_IFACE") == "PCI" else AMDDriver()]
tracked_fds = {}
original_memoryview = builtins.memoryview
@ -77,9 +78,10 @@ class MockFileIOInterface(FileIOInterface):
return libc.mmap(start, sz, prot, flags, self.fd, offset)
def read(self, size=None, binary=False, offset=None):
if binary: raise NotImplementedError()
if self.fd in tracked_fds:
if offset is not None: tracked_fds[self.fd].seek(offset)
return tracked_fds[self.fd].read_contents(size)
if binary: raise NotImplementedError()
with open(self.fd, "rb" if binary else "r", closefd=False) as file:
if file.tell() >= os.fstat(self.fd).st_size: file.seek(0)
return file.read(size)
@ -89,13 +91,20 @@ class MockFileIOInterface(FileIOInterface):
return tracked_fds[self.fd].list_contents()
return os.listdir(self.path)
def write(self, content, binary=False, offset=None): raise NotImplementedError()
def write(self, content, binary=False, offset=None):
if self.fd in tracked_fds:
if offset is not None: tracked_fds[self.fd].seek(offset)
return tracked_fds[self.fd].write_contents(content)
raise NotImplementedError()
def seek(self, offset):
  # Tracked (virtual) descriptors treat `offset` as an absolute position, so real descriptors are positioned
  # absolutely too (SEEK_SET); a relative seek would make the two paths disagree on what `offset` means.
  if self.fd in tracked_fds:
    tracked_fds[self.fd].seek(offset)
  else:
    os.lseek(self.fd, offset, os.SEEK_SET)
@staticmethod
def anon_mmap(start, sz, prot, flags, offset):
return FileIOInterface._mmap(start, sz, prot, flags & ~0x4a000, -1, offset) # strip MAP_LOCKED|MAP_POPULATE|MAP_HUGETLB
@staticmethod
def exists(path): return _open(path, os.O_RDONLY) is not None
@staticmethod
def readlink(path): raise NotImplementedError()

View file

@ -9,7 +9,9 @@ def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = t.schedule()
asts = [s for s in schedule if s.ast.op is Ops.SINK]
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
len(asts)
# NOT SUPPORTED ANYMORE
#assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
class TestUnaryOpsConstFolding(unittest.TestCase):
def test_all_consts_ops(self):

101
test/null/test_gpudims.py Normal file
View file

@ -0,0 +1,101 @@
import unittest, math
import z3
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.validate import uops_to_z3
from tinygrad.dtype import dtypes
from tinygrad.helpers import flatten, dedup
class TestGroupedDims(unittest.TestCase):
def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True):
  """Group `dims` under `max_sizes`, then check the SPECIAL sizes against `expected_sizes` and verify the indexing with z3."""
  idxs = get_grouped_dims(prefix, dims, max_sizes, reverse)
  # collect every distinct SPECIAL uop reachable from the returned index expressions, ordered by its arg
  specials_per_idx = [[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]
  loop_idxs = sorted(dedup(flatten(specials_per_idx)), key=lambda uop: uop.arg)
  sizes = []
  for special in loop_idxs: sizes.append(special.src[0].arg)
  assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
  if assert_same_length:
    assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
  assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
  self._verify_indices_z3(idxs, dims)
def _verify_indices_z3(self, idxs, dims):
  """Prove with z3 that the row-major flat index built from `idxs` stays within [0, prod(dims)) and is injective in the SPECIALs."""
  total = math.prod(dims)
  per_idx = [[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]
  specials = sorted(dedup(flatten(per_idx)), key=lambda u: u.arg)
  # row-major flat index: every index is weighted by the product of the dims to its right
  flat = UOp.const(dtypes.index, 0)
  for i, idx in enumerate(idxs):
    stride = int(math.prod(dims[i+1:]))
    flat = flat + idx * stride
  # second copy of the same expression in which every SPECIAL is renamed with a "_p" suffix
  renamed = {s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials}
  flat_p = flat.substitute(renamed)
  solver = z3.Solver()
  [z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
  # bounds: neither an underflow nor an overflow may be satisfiable
  self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
  self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
  # injectivity: equal flat indices with at least one differing input must be unsatisfiable
  differing = [z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials]
  inputs_differ = z3.Or(*differing)
  self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
def test_grouped_dims(self):
# no-op
self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2])
self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims
self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
# prefer group_dim strategy when possible
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 2
self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# 2 -> 3
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# collapse on onto the left most axis
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# collapse on left-most available axis (the left most is too small)
self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
def test_grouped_direct_dims_are_special(self):
# when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
def test_max_sizes_none(self):
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
self._check_grouped_dims("gidx", (100,), None, False, [100])
if __name__ == '__main__':
unittest.main()

View file

@ -1,11 +1,13 @@
import unittest
import gc, unittest
from tinygrad import Tensor, GlobalCounters, dtypes
class TestMultiRamUsage(unittest.TestCase):
def setUp(self):
gc.collect()
self.baseline = GlobalCounters.mem_used
self.N = 100
def assertUsed(self, amt, strict=True):
gc.collect()
used = GlobalCounters.mem_used - self.baseline
print(f"used {used} bytes")
if strict: self.assertEqual(used, amt)
@ -20,20 +22,17 @@ class TestMultiRamUsage(unittest.TestCase):
del _
self.assertUsed(0)
@unittest.skip("flaky")
def test_zeros_copy(self):
devices_2 = ("NULL:1", "NULL:2")
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
# NOTE: the first one on the DEFAULT device should be freed
self.assertUsed(self.N*self.N*4*2)
@unittest.skip("flaky")
def test_zeros_shard(self, devices=("NULL:1", "NULL:2")):
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices, axis=0).realize()
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
def test_zeros_shard_self(self): self.test_zeros_shard(("NULL:0", "NULL:1"))
@unittest.skip("flaky")
def test_zeros_contiguous_shard(self):
devices_2 = ("NULL:1", "NULL:2")
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
@ -54,5 +53,26 @@ class TestMultiRamUsage(unittest.TestCase):
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
class TestMultiAxis(unittest.TestCase):
  """Checks the shard axis reported by `uop.axis` after reshape / empty_like on a tensor sharded on axis 0."""

  def test_reshape_shard_invalid(self):
    sharded = Tensor.ones(4, 3).shard(("NULL:0", "NULL:1"), axis=0)
    # NOTE: `msg` is only unittest's failure message (shown if nothing is raised);
    # it is not matched against the text of the raised RuntimeError
    with self.assertRaises(RuntimeError, msg="reshape cannot move items between shards"):
      t_reshaped = sharded.reshape(3, 4)
      t_reshaped.uop.axis

  def test_reshape_shard_valid(self):
    sharded = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
    # both target shapes are expected to keep the shard on axis 0
    for new_shape in ((2, 16), (2, 2, 8)):
      self.assertEqual(sharded.reshape(*new_shape).uop.axis, 0)

  def test_empty_like_sharded(self):
    src = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
    out = src.empty_like()
    # shape and device must match the source tensor
    self.assertEqual(out.shape, src.shape)
    self.assertEqual(out.device, src.device)
    # the result keeps shard axis 0 and reports buffer identity
    self.assertEqual(out.uop.axis, 0)
    self.assertTrue(out.uop.has_buffer_identity())
if __name__ == '__main__':
unittest.main()

View file

@ -1,21 +0,0 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import CPU_LLVM, CPU_LVP, CPU_X86
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import get_program
class TestOpts(unittest.TestCase):
  """Checks that opts passed through `contiguous(arg=...)` end up on the scheduled kernel."""

  def test_opt_upcast(self):
    upcast_opts = (Opt(OptOps.UPCAST, 0, 4),)
    lhs = Tensor.empty(16)
    rhs = Tensor.empty(16)
    sched = (lhs+rhs).contiguous(arg=upcast_opts).schedule()
    last_ast = sched[-1].ast
    # the last schedule item must carry exactly the opts we asked for
    self.assertEqual(last_ast.arg.opts_to_apply, upcast_opts)
    # the rendered-source check only runs on CPU/CL/METAL with none of the CPU_LLVM/CPU_LVP/CPU_X86 flags set
    if Device.DEFAULT not in {"CPU", "CL", "METAL"} or CPU_LLVM or CPU_LVP or CPU_X86: return
    program = get_program(last_ast, renderer=Device[Device.DEFAULT].renderer)
    self.assertIn('float4', program.src)
if __name__ == '__main__':
unittest.main()

View file

@ -1,202 +0,0 @@
import unittest
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
from tinygrad.uop.symbolic import symbolic
class TestRewriteMap(unittest.TestCase):
  """Checks that graph_rewrite_map maps each original node to the node that replaced it."""

  def test_substitute(self):
    a, b, c, e = (UOp.variable(name, 0, 10) for name in "abce")
    ret = (a+b)*c
    sub = {a+b: e}
    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
    self.assertIs(sub_map[a+b], e)
    self.assertIs(sub_map[(a+b)*c], e*c)

  def test_substitute_depth_2(self):
    a, b, c, d, e, f = (UOp.variable(name, 0, 10) for name in "abcdef")
    ret = (a+b)*c+d
    # both the inner sum and the product that contains it have their own substitution
    sub = {a+b: e, (a+b)*c: f}
    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
    self.assertIs(sub_map[a+b], e)
    self.assertIs(sub_map[(a+b)*c], f)

  def test_multistage_substitute(self):
    a, b, c, d = (UOp.variable(name, 0, 10) for name in "abcd")
    sub1 = {a+b:c}
    start = (a+b)*c
    # stage 1: (a+b)*c -> c*c
    sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
    self.assertIs(sub_map1[(a+b)*c], c*c)
    # stage 2: c*c -> d, chained onto stage 1 through input_map
    sub2 = {c*c:d}
    sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
    # (a+b)*c -> c*c -> d
    self.assertIs(sub_map2[(a+b)*c], d)

  def test_add_zero(self):
    """add(0, add(0, 5)) -> add(0, 5) -> 5 using symbolic."""
    zero_node = UOp.const(dtypes.index, 0)
    five_node = UOp.const(dtypes.index, 5)
    inner_add = zero_node + five_node
    root_add = zero_node + inner_add

    node_map = graph_rewrite_map(root_add, symbolic)

    # both adds collapse to the constant 5
    self.assertEqual(node_map[root_add], five_node)
    self.assertEqual(node_map[inner_add], five_node)
    # zero_node and five_node map to themselves
    for node in (zero_node, five_node):
      self.assertEqual(node_map[node], node)

  def test_double_neg(self):
    """neg(neg(5)) -> 5 using symbolic."""
    five_node = UOp.const(dtypes.index, 5)
    neg_five = -five_node
    double_neg_five = -neg_five

    node_map = graph_rewrite_map(double_neg_five, symbolic)

    # node_map should map double_neg_five -> five_node
    self.assertEqual(node_map[double_neg_five], five_node)
    # five_node maps to itself
    self.assertEqual(node_map[five_node], five_node)

  def test_add_zero_and_double_neg(self):
    """Combine both rewrites: add(0, neg(neg(5))) -> add(0, 5) -> 5."""
    zero_node = UOp.const(dtypes.index, 0)
    five_node = UOp.const(dtypes.index, 5)
    neg_five = -five_node
    double_neg_five = -neg_five
    root_add = zero_node + double_neg_five

    node_map = graph_rewrite_map(root_add, symbolic)

    # node_map: root_add -> five_node, double_neg_five -> five_node
    self.assertEqual(node_map[root_add], five_node)
    self.assertEqual(node_map[double_neg_five], five_node)
    # zero_node, five_node map to themselves
    for node in (zero_node, five_node):
      self.assertEqual(node_map[node], node)

  def test_multi_var_rewrites(self):
    x_var = UOp.variable('x', 0, 10)
    y_var = UOp.variable('y', -5, 5)
    zero_node = UOp.const(dtypes.index, 0)

    sum_with_zero = y_var + zero_node      # (y + 0)
    combined = x_var + sum_with_zero       # x + (y + 0)
    double_neg = -(-combined)              # neg(neg(x + (y + 0)))
    final_expr = zero_node + double_neg    # 0 + neg(neg(x + (y + 0)))

    node_map = graph_rewrite_map(final_expr, symbolic)

    # the final root should be (x_var + y_var); each sub-expression has its own "final" result
    expected = x_var + y_var
    # (y + 0) -> y
    self.assertEqual(node_map[sum_with_zero], y_var)
    # x + (y + 0), its double negation, and 0 + that all -> (x + y)
    for node in (combined, double_neg, final_expr):
      self.assertEqual(node_map[node], expected)
    # x_var, y_var, zero_node remain unchanged
    for node in (x_var, y_var, zero_node):
      self.assertEqual(node_map[node], node)

  def test_complex_multi_var_edges(self):
    """
    Multi-variable expression with several intermediates, all built on dtypes.index constants:
      yz_sum       = y + z
      yz_sum_zero  = yz_sum + 0          -> y+z
      yz_neg       = -yz_sum_zero        -> -(y+z)
      yz_dneg      = -yz_neg             -> y+z
      x_plus_yz    = x + yz_dneg         -> x+(y+z)
      double_neg_x = -(-x_plus_yz)       -> x+(y+z)
      final_expr   = double_neg_x * 1    -> x+(y+z)
    Each original node should map to the final node that replaces it.
    """
    x_var = UOp.variable('x', 1, 10)
    y_var = UOp.variable('y', -5, 5)
    z_var = UOp.variable('z', 0, 5)
    zero_node = UOp.const(dtypes.index, 0)
    one_node = UOp.const(dtypes.index, 1)

    # build sub-expressions
    yz_sum = y_var + z_var
    yz_sum_zero = yz_sum + zero_node
    yz_neg = -yz_sum_zero
    yz_dneg = -yz_neg
    x_plus_yz = x_var + yz_dneg
    double_neg_x = -(-x_plus_yz)
    final_expr = double_neg_x * one_node

    node_map = graph_rewrite_map(final_expr, symbolic)

    # (y + z) is unchanged
    self.assertEqual(node_map[yz_sum], yz_sum)
    # (y+z) + 0 => (y+z)
    self.assertEqual(node_map[yz_sum_zero], yz_sum)
    # -((y+z) + 0) => -(y+z)
    self.assertEqual(node_map[yz_neg], -yz_sum)
    # -(-(y+z)) => (y+z)
    self.assertEqual(node_map[yz_dneg], yz_sum)
    # x + (y+z), its double negation, and that times 1 all => x + (y+z)
    for node in (x_plus_yz, double_neg_x, final_expr):
      self.assertEqual(node_map[node], x_var + yz_sum)
    # unchanged atomic nodes map to themselves
    for node in (x_var, y_var, z_var, zero_node, one_node):
      self.assertEqual(node_map[node], node)
if __name__ == "__main__":
unittest.main()

View file

@ -1,5 +1,5 @@
# schedule tests that pass on NULL backend (no copyout needed)
import unittest, time
import gc, unittest, time
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
@ -59,7 +59,7 @@ class TestBufferUOp(unittest.TestCase):
def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"):
with self.assertRaises(RuntimeError):
permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW
def test_buffer_only_after_realize(self):
@ -74,7 +74,7 @@ class TestBufferUOp(unittest.TestCase):
self.assertIsNotNone(a.uop.buffer)
def test_const_does_not_realize(self):
a = Tensor(1)+Tensor(2)
a = Tensor(1)
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.uop.base.realized)
@ -191,6 +191,13 @@ class TestSchedule(unittest.TestCase):
a, _ = Tensor.empty(1022).cummax(axis=0)
check_schedule(a, 3)
@unittest.skip("should this pass?")
def test_contiguous_assign(self):
a = Tensor.ones(10) * 2
b = Tensor.empty(10)
c = b.assign(a.contiguous())
check_schedule(c, 1)
def test_basic_binop_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@ -198,6 +205,14 @@ class TestSchedule(unittest.TestCase):
d = a+b+c
check_schedule(d, 1)
def test_basic_binop_fusion_assign(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = Tensor.empty(10)
d = a+b+c
e = Tensor.empty(10).assign(d)
check_schedule(e, 1)
def test_basic_binop_fusion_deep(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@ -212,6 +227,13 @@ class TestSchedule(unittest.TestCase):
c = (a*b).sum()
check_schedule(c, 1)
def test_mulacc_fusion_assign(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = (a*b).sum()
d = Tensor.empty(1).assign(c)
check_schedule(d, 1)
def test_mulacc_relu_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@ -390,20 +412,20 @@ class TestSchedule(unittest.TestCase):
out = bn(c1(img)).relu()
check_schedule(out, 4, [c1.weight, c1.bias])
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 27), (nn.optim.SGD, 7)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
check_schedule(opt.schedule_step(), cnt)
def test_fold_conv_batchnorm_optim(self, adam=False):
# 2 is too low?
optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
check_schedule(opt.schedule_step(), cnt)
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self):
with Tensor.train():
@ -627,6 +649,7 @@ class TestSchedule(unittest.TestCase):
t = Tensor([1.0, 2.0, 3.0]) ** 8
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
@unittest.skip("const folding is removed")
def test_pow_const_tensor_to_zero(self):
x = Tensor([1,2,3,4])
out = x ** Tensor(0.0)
@ -751,7 +774,7 @@ class TestSchedule(unittest.TestCase):
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 19)
check_schedule(opt.schedule_step(), 13)
def test_adam_conv_fuse(self):
with Tensor.train():
@ -761,7 +784,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 19)
check_schedule(opt.schedule_step(), 13)
def test_adam_2convs_fuse(self):
with Tensor.train():
@ -772,7 +795,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 21)
check_schedule(opt.schedule_step(), 15)
def test_sgd_conv_fuse(self):
with Tensor.train():
@ -804,7 +827,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 13)
check_schedule(opt.schedule_step(), 11)
def test_sgd_4convs_fuse(self):
with Tensor.train():
@ -880,9 +903,11 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2)
def test_schedule_mem_used(self):
gc.collect()
base = GlobalCounters.mem_used
Tensor.ones(256).contiguous().realize()
Tensor.ones(5, 5).contiguous().schedule()
gc.collect()
self.assertEqual(GlobalCounters.mem_used-base, 0)
def test_const_schedule(self):
@ -986,6 +1011,7 @@ class TestUOpBecome(unittest.TestCase):
# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
@unittest.skip("no longer supported")
def test_reorder_expand(self):
a = Tensor.empty(4, 1)
b = a.expand(4, 4).reciprocal()
@ -1021,6 +1047,7 @@ class TestUOpBecome(unittest.TestCase):
late_add = noop+2
late_add.realize()
@unittest.skip("const folding is removed")
def test_become_const_in_base(self):
a = Tensor.empty(4)
b = a*0
@ -1028,6 +1055,7 @@ class TestUOpBecome(unittest.TestCase):
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)
@unittest.skip("const folding is removed")
def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.uop, {})
@ -1116,7 +1144,7 @@ class TestFusionOp(unittest.TestCase):
a = Tensor(val)
for _ in range(24): a = Tensor.stack(a, a)[0]
sched = a.schedule()
self.assertEqual(len(sched), 0)
self.assertLessEqual(len(sched), 1)
self.assertLess(time.perf_counter()-st, 2.0)
def test_recursive_reshape(self):

View file

@ -4,6 +4,7 @@ from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing
from tinygrad.helpers import Context
@unittest.skip("tensor metadata is no longer supported")
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None:
_METADATA.set(None)

View file

@ -8,6 +8,7 @@ def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pa
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
class TestTensorMutates(unittest.TestCase):
@unittest.skip("this doesn't mutate anymore")
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])

View file

@ -655,6 +655,24 @@ class TestSymbolic(unittest.TestCase):
with self.assertRaises(AssertionError):
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
def test_div_mod_recombine_3level(self):
gidx = Variable("gidx", 0, 150527)
self.helper_test_variable(gidx//3%224*3 + gidx%3 + gidx//672*672, 0, 150527, "gidx")
# different shapes
x = Variable("x", 0, 5*7*11-1)
self.helper_test_variable(x//11%7*11 + x%11 + x//77*77, 0, 5*7*11-1, "x")
# result is x//a*c2 not just x
x2 = Variable("x2", 0, 5*6*7-1)
self.helper_test_variable(x2//7%6*14 + x2//42*84, 0, (5*6*7-1)//7*14, "(x2//7*14)")
# negative variable range
xn = Variable("x", -1000, 1000)
self.helper_test_variable(xn//3%224*3 + xn%3 + xn//672*672, -1000, 1000, "x")
self.helper_test_variable(xn//3%7*3 + xn//21*21, -999, 999, "(x//3*3)")
# should NOT simplify: a*c1 != b (3*224 != 600)
self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)")
# should NOT simplify: c1*c2 != c3 (224*3 != 700)
self.helper_test_variable(gidx//3%224*3 + gidx//672*700, 0, 156769, "(gidx//672*700+gidx//3%224*3)")
def test_div_mod_recombine_with_gcd(self):
b = Variable("b", 0, 100)
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
@ -835,34 +853,33 @@ class TestSymbolicNumeric(unittest.TestCase):
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
class TestSymbolicVars(unittest.TestCase):
class TestSymbolicVariables(unittest.TestCase):
def test_simple(self):
z = uconst(0)
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert z.vars() == z.vars() == set()
print(a.vars())
assert a.vars() == a.vars() == {a}
assert z.variables() == []
assert a.variables() == [a]
m = a * 3
assert m.vars() == {a}
assert m.variables() == [a]
s = usum([a, b, c])
assert s.vars() == {a, b, c}
assert s.variables() == [a, b, c]
def test_compound(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert (a + b * c).vars() == {a, b, c}
assert (a % 3 + b // 5).vars() == {a, b}
assert (a + b * c).variables() == [a, b, c]
assert (a % 3 + b // 5).variables() == [a, b]
# TODO: fix me
with self.assertRaises(AssertionError):
assert (a + b + c - a).vars() == {b, c}
assert (a + b + c - a).variables() == [b, c]
def test_dedup(self):
a = Variable("a", 0, 10)
assert (a * a).vars() == {a}
assert (a//4 + a//6).vars() == {a}
assert (a * a).variables() == [a]
assert (a//4 + a//6).variables() == [a]
class TestSymInfer(unittest.TestCase):
def test_sym_infer(self):

View file

@ -94,12 +94,6 @@ class TestExecALU(unittest.TestCase):
# test no truncate
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
si = t.schedule()
assert len(si) == 0
class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self):
gmem = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)

View file

@ -45,6 +45,7 @@ class TestMemoryCount(unittest.TestCase):
_, mem = get_stats(a+b)
self.assertEqual(mem, 1024*1024*2 + 1024) # 1 full read + 1 lil read + 1 write
@unittest.skip("no longer supported")
def test_both_expanded(self):
# TODO: this probably should be a full write
a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)

View file

@ -286,6 +286,20 @@ class TestVizIntegration(BaseTestViz):
self.assertEqual(lst[0]["name"], "Schedule 1 Kernel n1")
self.assertEqual(lst[1]["name"], prg.name)
# schedule graph CALL nodes have a link to jump to codegen
def test_link_sched_codegen(self):
c1 = Tensor.empty(4).add(1)
c2 = Tensor.empty(8).add(1)
sched = Tensor.schedule(c1, c2)
prgs = [si.lower().prg.p.name for si in sched]
lst = get_viz_list()
viz_kernel = next(i for i,s in enumerate(lst[0]["steps"]) if s["name"] == "View Kernel Graph")
graph = next(get_viz_details(0, viz_kernel))["graph"]
call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
for i,n in enumerate(call_nodes):
assert n["ref"] is not None
self.assertEqual(lst[n["ref"]]["name"], prgs[i])
def test_metadata_tracing(self):
with Context(TRACEMETA=2):
a = Tensor.empty(1)

View file

@ -45,7 +45,7 @@ class TestWinograd(unittest.TestCase):
# TODO: what's optimal on this?
self.assertLess(ops_ratio, 4.3)
self.assertLess(mem_ratio, 3)
self.assertLess(mem_ratio, 4)
def test_dtype(self):
IC, OC, X, Y = 4,4,9,9

View file

@ -30,7 +30,7 @@ class TestCfg(unittest.TestCase):
def setUp(self):
self.arch = Device["AMD"].arch
if not any(self.arch.startswith(a) for a in {"gfx11", "gfx12"}):
self.skipTest(f"tests written for RDNA, got arch {arch}")
self.skipTest(f"tests written for RDNA, got arch {self.arch}")
def test_simple(self):
k = Kernel(arch=Device["AMD"].arch)

View file

@ -35,8 +35,16 @@ class TestAssign(unittest.TestCase):
a.realize()
np.testing.assert_allclose(b.numpy(), 0)
def test_assign_copy(self):
a = Tensor([1.,2,3], device="PYTHON")
c = Tensor.empty(3).assign(a.to(None))
# it should copy into the empty buffer
GlobalCounters.reset()
c.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
def test_assign_add(self):
for T in (1, 2, 10, 100):
for T in (1, 2, 10):#, 100): # this crashes in CI, not sure why
x = Tensor([0]).realize()
buf = x.uop.base.realized
for _ in range(T):
@ -120,6 +128,7 @@ class TestAssign(unittest.TestCase):
new = a + old_a
np.testing.assert_allclose(new.numpy(), 4)
@unittest.skip("TODO: this is broken")
def test_assign_changes_alt(self, realize=False):
a = Tensor(1).contiguous()
if realize: a.realize()
@ -629,6 +638,7 @@ class TestAssignOrdering(unittest.TestCase):
self.assertEqual(r1.item(), 4)
self.assertEqual(r2.item(), 8)
@unittest.skip("TODO: this is broken")
def test_write_read_write_chain(self):
"""Write, read, write chain - middle read must complete before second write."""
buf = Tensor.zeros(4).contiguous().realize()

View file

@ -92,5 +92,13 @@ class TestCall(unittest.TestCase):
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
def test_call_plus_sharded(self):
devs = ("CPU:0", "CPU:1")
a = Tensor.ones(10, 10).shard(devs, axis=0)
b = Tensor.ones(10, 10).shard(devs, axis=0)
Tensor.realize(a, b)
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
if __name__ == '__main__':
unittest.main()

View file

@ -85,7 +85,7 @@ class TestRawDiskBuffer(unittest.TestCase):
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
# doesn't suport normal cast
with self.assertRaises(NotImplementedError):
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).realize()
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).to(None).realize()
# Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
with self.assertRaises(RuntimeError):
@ -264,18 +264,20 @@ class TestDiskTensor(TempDirTestCase):
def test_strided_read(self):
# test non-contiguous (strided) read - should read elements at indices 0, 2, 4
dt = Tensor([0, 1, 2, 3, 4, 5]).to(f"disk:{self.tmp('dt_strided_read')}")
result = dt[::2].tolist()
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
# self.assertEqual(result, [0, 2, 4])
self.assertEqual(result, [0, 1, 2]) # wrong!
with self.assertRaises(RuntimeError):
result = dt[::2].tolist()
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
# self.assertEqual(result, [0, 2, 4])
self.assertEqual(result, [0, 1, 2]) # wrong!
def test_permuted_read(self):
# test non-contiguous (permuted) read - should read transposed
dt = Tensor([[0, 1, 2], [3, 4, 5]]).to(f"disk:{self.tmp('dt_permuted_read')}")
result = dt.T.tolist()
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
with self.assertRaises(RuntimeError):
result = dt.T.tolist()
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
def test_write_ones(self):
out = Tensor.ones(10, 10, device="CPU").contiguous()
@ -303,10 +305,11 @@ class TestDiskTensor(TempDirTestCase):
def test_strided_setitem(self):
# test non-contiguous (strided) setitem - should set elements at indices 0, 2, 4
dt = Tensor([1, 2, 3, 4, 5, 6]).to(f"disk:{self.tmp('dt_strided_setitem')}")
dt[::2] = Tensor([10, 20, 30])
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
with self.assertRaises(RuntimeError):
dt[::2] = Tensor([10, 20, 30])
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
def test_advanced_setitem_not_supported(self):
dt = Tensor.arange(12).reshape(3, 4).to(f"disk:{self.tmp('dt_advanced_setitem')}")

View file

@ -1,62 +1,38 @@
import os, unittest, ctypes
import os, unittest
from tinygrad import dtypes, Tensor, fetch, Device
from tinygrad.nn.state import ggml_data_to_tensor, gguf_load
from tinygrad.device import is_dtype_supported
import numpy as np
import ggml
from gguf import GGUFReader, GGUFValueType, GGMLQuantizationType, GGML_QUANT_SIZES, dequantize, quantize
ggml_test_block_count = 4
ggml_type_to_np_dtype = {
ggml.GGML_TYPE_F16: np.float16, ggml.GGML_TYPE_F32:np.float32, ggml.GGML_TYPE_F64:np.float64,
ggml.GGML_TYPE_I8:np.int8, ggml.GGML_TYPE_I16: np.int16, ggml.GGML_TYPE_I32: np.int32, ggml.GGML_TYPE_I64: np.int64,
}
np_dtype_to_ctype = { np.float16: ctypes.c_uint16 }
gguf_val_getters = [
ggml.gguf_get_val_u8, ggml.gguf_get_val_i8, ggml.gguf_get_val_u16, ggml.gguf_get_val_i16,
ggml.gguf_get_val_u32, ggml.gguf_get_val_i32, ggml.gguf_get_val_f32, ggml.gguf_get_val_bool,
lambda *args: ggml.gguf_get_val_str(*args).decode("utf-8"), None,
ggml.gguf_get_val_u64, ggml.gguf_get_val_i64, ggml.gguf_get_val_f64,
]
def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
ctx: ggml.ggml_context_p | None = None
ggml_type, n_dims, n_els = tensor.contents.type, ggml.ggml_n_dims(tensor), ggml.ggml_nelements(tensor)
shape = tuple(reversed(tensor.contents.ne[:n_dims]))
if ggml_type not in ggml_type_to_np_dtype:
ctx = ggml.ggml_init(ggml.ggml_init_params(mem_size=n_els * 5 + 500, mem_buffer=None))
ntensor = ggml.ggml_new_tensor(ctx, ggml.GGML_TYPE_F32, n_dims, tensor.contents.ne)
type_traits = ggml.ggml_internal_get_type_traits(ggml_type)
type_traits.to_float(ggml.ggml_get_data(tensor), ggml.ggml_get_data_f32(ntensor), n_els)
tensor, ggml_type = ntensor, ggml.GGML_TYPE_F32
np_type = ggml_type_to_np_dtype[ggml_type]
ctypes_type = np_dtype_to_ctype.get(np_type, None) or np.ctypeslib.as_ctypes_type(np_type)
data = ggml.ggml_get_data(tensor)
if data is None: raise ValueError("tensor data is None")
arr = (ctypes_type * ggml.ggml_nelements(tensor)).from_address(data)
strides = tuple(reversed(tensor.contents.nb[:n_dims]))
output = np.ctypeslib.as_array(arr)
output.dtype = np_type
return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
class TestGGUF(unittest.TestCase):
def setUp(self) -> None:
params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)
self.ctx = ctypes.cast(ggml.ggml_init(params), ctypes.POINTER(ctypes.c_void_p))
def tearDown(self) -> None: ggml.ggml_free(self.ctx)
def test_load_tinyllama_q8_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
def test_load_tinyllama_q4_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
def test_load_gpt2_q4_1(self): self._test_gguf_load("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
def test_load_sample_q6_k(self): self._test_gguf_load("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")
def test_load_sample_mxfp4(self): self._test_gguf_load("https://huggingface.co/ngxson/boring-testing-tiny/resolve/main/stories260K-mxfp4.gguf?download=true")
def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
def test_dequantization_q4_k(self): self._test_dequantization(ggml.GGML_TYPE_Q4_K)
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
def test_dequantization_q8_0_hardcoded(self):
# Q8_0: 2 bytes float16 scale + 32 bytes int8 values, dequant = scale * values
block = np.frombuffer(np.float16(2.0).tobytes() + np.arange(1, 33, dtype=np.int8).tobytes(), dtype=np.uint8).copy()
expected = np.arange(1, 33, dtype=np.float32) * 2.0
np.testing.assert_equal(ggml_data_to_tensor(Tensor(block), 32, GGMLQuantizationType.Q8_0.value).numpy().flatten(), expected)
def test_dequantization_mxfp4_hardcoded(self):
# MXFP4: 1 byte shared exponent E + 16 packed bytes (32 x 4-bit values)
# nibble: bit3=sign, bit2:1=exp, bit0=mant; E=128 gives scale=1.0
# codes 0-7 = [0, 1, 2, 3, 4, 6, 8, 12], codes 8-15 are their negatives
block = np.array([0x80] + list(range(16)), dtype=np.uint8) # E=128, nibbles 0-15 in low, zeros in high
expected = np.array([0., 1., 2., 3., 4., 6., 8., 12., -0., -1., -2., -3., -4., -6., -8., -12.] + [0.]*16, dtype=np.float32)
np.testing.assert_equal(ggml_data_to_tensor(Tensor(block), 32, 39).numpy().flatten(), expected)
def test_dequantization_q4_0(self): self._test_dequantization(GGMLQuantizationType.Q4_0)
def test_dequantization_q4_1(self): self._test_dequantization(GGMLQuantizationType.Q4_1)
def test_dequantization_q8_0(self): self._test_dequantization(GGMLQuantizationType.Q8_0)
def test_dequantization_q4_k(self): self._test_dequantization(GGMLQuantizationType.Q4_K)
def test_dequantization_q6_k(self): self._test_dequantization(GGMLQuantizationType.Q6_K)
def test_dequantization_mxfp4(self):
MXFP4 = 39
@ -68,7 +44,7 @@ class TestGGUF(unittest.TestCase):
sign = -1.0 if (code & 0b1000) else 1.0
exp = (code >> 1) & 0b11
mant = code & 0b1
val = (1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant
val = 2 * ((1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant)
scale = np.exp2(E - 128) if E >= 2 else np.exp2(-127 if E == 1 else -128)
return sign * val * scale
@ -84,24 +60,44 @@ class TestGGUF(unittest.TestCase):
# TODO: should this be exact equal? somehow failed on CI
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
def test_dequantization_mxfp4_block(self):
MXFP4 = 39
# https://gist.github.com/Ananta-Ranganathan/3317b6ed51a3b033e9c2564fafb4e043
# used the above script to download the first block of blk.0.attn_k_b.weight from
# https://huggingface.co/unsloth/GLM-4.7-Flash-GGUF/blob/main/GLM-4.7-Flash-MXFP4_MOE.gguf
# and compute the canonical expected dequantized output with the GGUF PY implementation
block = np.array([0x7a, 0x29, 0xab, 0x61, 0x10, 0x21, 0x02, 0x4a,
0x15, 0xca, 0x05, 0x01, 0x9b, 0x39, 0x0b, 0x0b, 0x1c], dtype=np.uint8)
expected = np.array([-0.01562500, -0.04687500, 0.01562500, 0.00000000,
0.01562500, 0.03125000, -0.03125000, 0.09375000,
-0.03125000, 0.09375000, 0.01562500, -0.04687500,
-0.01562500, -0.04687500, -0.04687500, -0.06250000,
0.03125000, -0.03125000, 0.12500000, 0.01562500,
0.03125000, 0.00000000, 0.06250000, 0.01562500,
-0.06250000, 0.00000000, 0.00000000, -0.01562500,
0.04687500, 0.00000000, 0.00000000, 0.01562500], dtype=np.float32)
out = ggml_data_to_tensor(Tensor(block), 32, MXFP4)
# TODO: similar to previous test fails on Mac CI with assert_equal for unclear reason
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
def test_expected_failure_unknown_type(self):
with self.assertRaises(ValueError):
ggml_data_to_tensor(Tensor.empty(512, dtype=dtypes.uint8), 256, 1337)
def _test_dequantization(self, ttype: int):
type_traits = ggml.ggml_internal_get_type_traits(ttype)
n_el, n_bytes = ggml_test_block_count * type_traits.blck_size, ggml_test_block_count * type_traits.type_size
def _test_dequantization(self, qtype: GGMLQuantizationType):
block_size, type_size = GGML_QUANT_SIZES[qtype]
n_el, n_bytes = ggml_test_block_count * block_size, ggml_test_block_count * type_size
data_in = (np.random.random((n_el,)).astype(np.float32) * 100 - 50).ctypes.data_as(ctypes.POINTER(ctypes.c_float))
try:
q_data = quantize((np.random.random((n_el,)).astype(np.float32) * 100 - 50), qtype)
except NotImplementedError:
q_data = np.random.default_rng(42).integers(0, 256, size=n_bytes, dtype=np.uint8)
ref = dequantize(q_data, qtype)
c_q_data, c_dq_data = (ctypes.c_char * n_bytes)(0), (ctypes.c_float * n_el)(0)
type_traits.from_float(data_in, c_q_data, n_el)
type_traits.to_float(c_q_data, c_dq_data, n_el)
q_tensor = Tensor(q_data)
dq_tensor = ggml_data_to_tensor(q_tensor, n_el, qtype.value).reshape(n_el)
q_tensor = Tensor(np.frombuffer(c_q_data, dtype=np.uint8, count=n_bytes))
dq_tensor = ggml_data_to_tensor(q_tensor, n_el, ttype).reshape(n_el)
np.testing.assert_equal(dq_tensor.numpy(), np.frombuffer(c_dq_data, dtype=np.float32))
np.testing.assert_equal(dq_tensor.numpy(), ref)
def _test_gguf_load(self, url: str):
fp = fetch(url)
@ -109,24 +105,20 @@ class TestGGUF(unittest.TestCase):
gguf_tensor = Tensor.empty(model_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
kv_data, tensors = gguf_load(gguf_tensor)
gguf_params = ggml.gguf_init_params(ctx=self.ctx, no_alloc=False)
gguf_ctx = ggml.gguf_init_from_file(str(fp).encode("utf8"), gguf_params)
param_ctx = gguf_params.ctx.contents.value
reader = GGUFReader(fp)
for ggml_tensor_idx in range(ggml.gguf_get_n_tensors(gguf_ctx)):
tensor_name = ggml.gguf_get_tensor_name(gguf_ctx, ggml_tensor_idx)
ggml_tensor = ggml.ggml_get_tensor(param_ctx, tensor_name)
ggml_tensor_numpy, temp_ctx = ggml_tensor_to_numpy(ggml_tensor)
tensor = tensors.get(tensor_name.decode("utf-8"))
np.testing.assert_equal(tensor.numpy(), ggml_tensor_numpy)
if temp_ctx is not None: ggml.ggml_free(temp_ctx)
for rt in reader.tensors:
ref = dequantize(rt.data, rt.tensor_type)
np.testing.assert_equal(tensors[rt.name].numpy(), ref.reshape(tensors[rt.name].shape))
for gguf_key_id in range(ggml.gguf_get_n_kv(gguf_ctx)):
v = kv_data[ggml.gguf_get_key(gguf_ctx, gguf_key_id).decode("utf-8")]
v_type = ggml.gguf_get_kv_type(gguf_ctx, gguf_key_id)
if (get_fn := gguf_val_getters[v_type]) is not None: self.assertEqual(get_fn(gguf_ctx, gguf_key_id), v)
ggml.gguf_free(gguf_ctx)
for k, f in reader.fields.items():
if k.startswith("GGUF."): continue # skip file header keys (version, tensor_count, kv_count)
def read_val(i, parts=f.parts, is_str=(f.types[-1] == GGUFValueType.STRING)):
return bytes(parts[i]).decode("utf-8") if is_str else parts[i][0].item()
if f.types[0] == GGUFValueType.ARRAY:
self.assertEqual(kv_data[k], [read_val(i) for i in f.data])
else:
self.assertEqual(kv_data[k], read_val(-1))
if __name__ == '__main__':
unittest.main()

View file

@ -28,16 +28,7 @@ class TestRealizeIsRealized(unittest.TestCase):
t = Tensor.ones(8).contiguous().shard((d, d), axis=0).realize()
assert all(u.is_realized for u in t.uop.src)
# TODO: these are not realized after .realize() because they stay as consts / don't allocate buffers
def test_const_not_realized(self):
t = Tensor(3.14).realize()
assert not t.uop.is_realized
def test_ones_not_realized(self):
t = Tensor.ones(4, 4).realize()
assert not t.uop.is_realized
def test_empty_not_realized(self):
def test_empty(self):
t = Tensor.empty(4, 4).realize()
assert not t.uop.is_realized
@ -48,6 +39,22 @@ class TestRealizeIsRealized(unittest.TestCase):
t = Tensor.empty(4, dtype=dtypes.float32, device=f"disk:{f.name}").realize()
assert not t.uop.is_realized
def test_assign(self):
t = Tensor([1, 2, 3])
t += 1
t.realize()
assert t.uop.is_realized
# TODO: these are not realized after .realize()
def test_const_not_realized(self):
t = Tensor(3.14).realize()
assert not t.uop.is_realized
def test_ones_not_realized(self):
t = Tensor.ones(4, 4).realize()
assert not t.uop.is_realized
def test_none_not_realized(self):
t = Tensor(None).realize()
assert not t.uop.is_realized

View file

@ -0,0 +1,28 @@
import sys
import pytest
@pytest.mark.skipif(sys.platform != "linux", reason="uses linux sysfs layout")
def test_pci_scan_bus_filters_vendor(monkeypatch):
  """System.pci_scan_bus must only report devices whose vendor id matches the requested one.

  Two fake PCI devices share device id 0x1111 but have different vendor ids. Scanning for
  vendor 0x1234 is expected to return only the first device's bus address.
  """
  import tinygrad.runtime.support.system as system

  # fake sysfs tree: file path -> file contents
  sysfs_files = {
    "/sys/bus/pci/devices/0000:00:01.0/vendor": "0x1234",
    "/sys/bus/pci/devices/0000:00:01.0/device": "0x1111",
    "/sys/bus/pci/devices/0000:00:02.0/vendor": "0xabcd",
    "/sys/bus/pci/devices/0000:00:02.0/device": "0x1111",
  }

  class StubFileIO:
    """Stand-in for FileIOInterface that serves listdir/read from the fake tree above."""
    def __init__(self, path, *args, **kwargs):
      self.path = path

    def listdir(self):
      # only the PCI device directory is expected to be listed
      assert self.path == "/sys/bus/pci/devices"
      return ["0000:00:01.0", "0000:00:02.0"]

    def read(self, *args, **kwargs):
      return sysfs_files[self.path]

  monkeypatch.setattr(system, "FileIOInterface", StubFileIO)

  found = system.System.pci_scan_bus(0x1234, devices=[(0xffff, [0x1111])])
  assert found == ["0000:00:01.0"]

View file

@ -340,6 +340,10 @@ if __name__ == "__main__":
# do benchmark
if args.benchmark:
param_bytes = sum(x.nbytes() for x in nn.state.get_parameters(model))
for b in model.blk:
if hasattr(b, 'ffn_gate_exps'):
expert_bytes = b.ffn_gate_exps.weight.nbytes() + b.ffn_up_exps.weight.nbytes() + b.ffn_down_exps.weight.nbytes()
param_bytes -= int(expert_bytes * (1 - b.num_experts_per_tok / b.ffn_gate_exps.weight.shape[0]))
gen = model.generate([0], 0)
for _ in range(args.benchmark):
GlobalCounters.reset()

View file

@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
elif (a:=len(limited)) > (b:=len(dims)):
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
elif limited != dims:
if limited != dims:
# Convert to 1D
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
# Get back original indices from 1D
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
return raw_idxs

View file

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, CPU_X86, prod, ceildiv, IMAGE
from tinygrad.helpers import getenv, flatten, AMX, CPU_X86, prod, IMAGE
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
@ -190,9 +190,9 @@ def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
buf = idx.src[0]
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
h, w = dt.shape[0], dt.shape[1]
if IMAGE == 1 and valid is not None and (tp:=dt.size // 4) // 64:
h, w = max(([(1, tp)] * (tp < 16384)) + [(tp//64//k, 64*k) for k in range(ceildiv(tp//64, 16384), min(tp//64, 256)+1) if (tp//64) % k == 0],
key=lambda hw: len(_drop_valid_stmts(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1])), *hw)))
if IMAGE == 1 and valid is not None:
h, w = max(ImageDType.valid_dims(dt),
key=lambda hw: len(_drop_valid_stmts(valid, uop_given_valid(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1]))), *hw)))
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize))
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w))))
return x, idx.replace(src=(buf, oidx.valid(valid))), w, h

View file

@ -7,7 +7,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import ALLOW_TF32, count, Context, ceildiv
from tinygrad.helpers import ALLOW_TF32, count, Context
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@ -353,28 +353,21 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
k = hand_coded_optimizations(k)
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
# max image width (pixels): 16384. max image size: 4 * 16384 ** 2
def _image_shape(dt):
if dt.base not in (dtypes.half, dtypes.float) or isinstance(dt, ImageDType) or dt.size > 4*16384*16384 or dt.nbytes()%64 != 0: return None
if dt.size <= 4 * 16384: return (1, dt.size // 4, 4)
if (pxls:=dt.size // 4) % 64: return None
# verify that a valid format exists
try: return next((pxls // 64 // k, 64 * k, 4) for k in range(ceildiv(pxls // 64, 16384), min(pxls // 64, 256)+1))
except StopIteration: return None
def make_image(pa, off, idx):
if (idx.tag is None or idx.tag) and (shape:=_image_shape(dt:=pa.dtype)):
new_idx = idx.replace(src=(pa.replace(dtype=(dtypes.imageh if dt.base==dtypes.half else dtypes.imagef)(shape, shape[1] * 4 * dt.itemsize)), off),
dtype=dtypes.float if dt.base == dtypes.half else idx.dtype)
if not isinstance(dt:=pa.dtype, ImageDType) and (idx.tag is None or idx.tag) and (shapes:=ImageDType.valid_dims(dt)):
new_pa = pa.replace(dtype=(dtypes.imageh if dt.base==dtypes.half else dtypes.imagef)(shapes[0] + (4,), shapes[0][1] * 4 * dt.itemsize))
new_idx = idx.replace(src=(new_pa, off), dtype=dtypes.float if dt.base == dtypes.half else idx.dtype)
return new_idx if idx.tag or dt.base == dtypes.float else new_idx.cast(dtypes.half)
pm_make_images = PatternMatcher([
# ensure we dont create an unfoldable image store
(UPat(Ops.STORE, src=(UPat.var("idx"),), allow_any_len=True, name="st"), lambda idx,st:
st.replace(src=(idx.rtag(is_image:=any(c.op is Ops.RANGE and (c.vmax+1)%4 == 0 for c in idx.src[1].get_idx().split_uop(Ops.ADD))),
st.src[1].cast(dtypes.float if is_image and _image_shape(idx.src[0].dtype) else idx.dtype.base)))),
st.src[1].cast(dtypes.float if is_image and ImageDType.valid_dims(idx.src[0].dtype) else idx.dtype.base)))),
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="pa"), UPat.var("off")), name="idx"), make_image),
# remove double cast from image loads
# remove double cast from image loads / stores
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="pa"),), allow_any_len=True, name="idx").cast(dtypes.half).cast(dtypes.float), lambda idx,pa:
idx if isinstance(pa.dtype, ImageDType) else None),
(UPat(Ops.STORE, src=(UPat(Ops.PARAM, name="pa").index(UPat()), UPat.var("val").cast(dtypes.half).cast(dtypes.float)), name="st"), lambda st,pa,val:
st.replace(src=(st.src[0], val)) if isinstance(pa.dtype, ImageDType) else None),
])

View file

@ -283,10 +283,10 @@ class CompilerSet: cset:list[tuple[type[Renderer]|functools.partial, ContextVar|
class Compiled:
profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
def __init__(self, device:str, allocator:Allocator, compilers:CompilerSet|None, runtime, graph=None, group_id=None):
def __init__(self, device:str, allocator:Allocator, compilers:CompilerSet|None, runtime, graph=None):
from tinygrad.renderer import Renderer
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
self.device, self.allocator, self.runtime, self.graph = device, allocator, runtime, graph
self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None
self.comp_sets:dict[str, tuple[ContextVar|None, type[Renderer]|functools.partial]] = {}

View file

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Final, ClassVar, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod, round_up, next_power2, OSX
from tinygrad.helpers import ceildiv, getenv, prod, round_up, next_power2, OSX
from enum import Enum, auto
class ConstFloat(float):
@ -121,13 +121,25 @@ class ImageDType(PtrDType):
if self._pitch != -1: return self._pitch
imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
if OSX: return round_up(imgw, 256) * 4 * self.itemsize
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
# needs to be IMAGE_PITCH_ALIGN=256 for AMD
min_pitchalign = int(math.log2(v)) if (v := getenv("IMAGE_PITCH_ALIGN", 0)) > 0 else 6
pitchalign = max(min_pitchalign, 11 - int(math.log2(imgh))) if imgh > 1 else min_pitchalign
align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
granularity = 128 if self.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add
# get list of (height, width) that do not require pitch padding
@staticmethod
def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
  """Enumerate the (height, width) image layouts, in pixels, usable for `ptr`.

  Returns an empty list when the base dtype is not half/float, when the buffer has more than
  4*16384*16384 elements, when it fails the alignment checks, or when it holds fewer than
  ALIGN pixels. ALIGN comes from IMAGE_PITCH_ALIGN (default 256 on OSX, 64 otherwise).
  """
  MAXW = 16384
  ALIGN = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64)
  if ptr.base not in (dtypes.half, dtypes.float): return []
  if ptr.size > 4*MAXW*MAXW: return []
  if (ptr.size if OSX else ptr.nbytes()) % ALIGN != 0: return []
  if OSX and (ptr.size // 4) % ALIGN: return []  # OSX has stricter requirements for height=1 images
  pxls: int = ptr.size // 4  # 4 elements per pixel
  groups = pxls // ALIGN
  if not groups: return []
  # a single row is only allowed while it stays below the max width
  dims: list[tuple[int,int]] = [(1, pxls)] if pxls < MAXW else []
  # widths are multiples of ALIGN that divide the pixel count evenly
  for k in range(ceildiv(groups, MAXW), min(groups, MAXW//ALIGN)+1):
    if groups % k == 0: dims.append((groups//k, ALIGN*k))
  return dims
class dtypes:
@staticmethod
@functools.cache

View file

@ -0,0 +1,129 @@
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
def tag_uop(ctx:tuple[list[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
  """Number an untagged UOp.

  Appends `x` to the list in ctx[0] and returns a copy of `x` whose tag is a one-element
  tuple holding its index in that list. Returns None if `x` already has a tag.
  """
  if x.tag is not None: return None
  numbered = ctx[0]
  idx = len(numbered)
  numbered.append(x)
  return x.replace(tag=(idx,))
def disk_copy_is_buffer(ctx, u):
  """Handle a COPY during the tagging pass.

  A copy whose own device is a DISK device gets a fresh buffer recorded for it in ctx[1].
  A copy whose source lives on an NPY/DISK/PYTHON device is tagged via tag_uop so it gets
  realized. Returns the tagged UOp in that second case, otherwise None.
  """
  # copies to disk are replaced with the disk buffer
  dst_dev = u._device
  if isinstance(dst_dev, str) and dst_dev.startswith("DISK"):
    ctx[1][u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
  # all copies from disk/numpy are realized into a real buffer
  src_dev = u.src[0]._device
  if isinstance(src_dev, str) and src_dev.startswith(("NPY", "DISK", "PYTHON")):
    return tag_uop(ctx, u)
  return None
def apply_after(ctx, u):
  """Map `u` to its first source in the dict held at ctx[1]; returns None (no rewrite)."""
  mapping = ctx[1]
  mapping[u] = u.src[0]
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
# Rules for the "number the uops" pass. ctx is the 3-tuple passed by allocate_global_buffers:
#   ctx[0] list of tagged UOps (a tag is an index into it), ctx[1] buffer map, ctx[2] set of UOps to tag.
add_tags = PatternMatcher([
  # COPY: a copy to DISK gets a buffer map entry, a copy from NPY/DISK/PYTHON gets tagged
  (UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
  # no tag on copies that are assigned
  # (the COPY's tag is merged into the ASSIGN's tag and the COPY is left with an empty tuple tag)
  (UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
   lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
  # AFTER is recorded in the buffer map as its first source; it is not tagged
  (UPat(Ops.AFTER, name="u"), apply_after),
  # every CONTIGUOUS and ASSIGN gets a tag
  (UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
  # any other UOp is tagged only if it is in ctx[2]
  (UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx[2] else None),
])
def replace_contig_with_assign(u:UOp):
  """Turn a CONTIGUOUS into an ASSIGN onto a newly created buffer.

  Special cases: a size-0 CONTIGUOUS is replaced by its source, and a CONTIGUOUS on a DISK
  device is kept but loses its tag. An ImageDType whose element count does not match the
  shard shape, or whose last non-1 dimension is not a multiple of 4, is demoted to its base dtype.
  """
  # if size is 0, remove the contig
  if u.size == 0: return u.src[0]
  # no real contig for DISK tensors, they are left alone
  dev = u._device
  if isinstance(dev, str) and dev.startswith("DISK"): return u.rtag(None)

  dtype = u.dtype
  if isinstance(dtype, ImageDType):
    demote = prod(dtype.shape) != prod(u.max_shard_shape)
    if not demote:
      # last dimension that isn't 1 (or 1 if every dimension is 1) must be a multiple of 4
      non_unit = [d for d in u.max_shard_shape if d != 1]
      demote = (non_unit[-1] if non_unit else 1) % 4 != 0
    if demote:
      if DEBUG >= 1: print(f"demoting Image {dtype} with shape {u.max_shard_shape}")
      dtype = dtype.base

  buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape)
  # a tuple device with a shard axis gets a multi buffer
  if isinstance(u.device, tuple) and u.axis is not None:
    buffer = buffer.multi(u.axis)
  return buffer.assign(u.src[0]).rtag(u.tag)
def replace_assign_with_contig(u:UOp):
  """Replace an ASSIGN whose ultimate target is not a BUFFER with a CONTIGUOUS of its value.

  The target is found by following src[0].base through any chain of ASSIGN/BITCAST nodes.
  Returns None (no rewrite) when that target is a BUFFER.
  """
  target = u
  while target.op is Ops.ASSIGN or target.op is Ops.BITCAST:
    target = target.src[0].base
  if target.op is Ops.BUFFER: return None
  return u.src[1].contiguous(tag=u.tag)
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  """Record, for the base of `src`, an equivalent expression built from `contig`.

  Walks from `src` down to `src.base`, undoing each PERMUTE (inverse permutation) and RESHAPE
  (back to the input shape) on `contig`. Gives up without recording anything if any other op
  is on the way. Always returns None, so the matched CONTIGUOUS itself is not rewritten.
  """
  base = src.base
  undone, node = contig, src
  while node is not base:
    if node.op is Ops.PERMUTE:
      undone = undone.permute(argsort(node.marg))
    elif node.op is Ops.RESHAPE:
      undone = undone.reshape(node.src[0].shape)
    else:
      return None
    node = node.src[0]
  ctx[base] = undone
# Rules for the "early transform tensor graph" pass run by allocate_global_buffers.
# ctx is a dict: found_contiguous stores {base of a movement chain: the CONTIGUOUS mapped back onto that base},
# and the ALU rule reads it to swap sources.
pm_early_transform_tensor_graph = PatternMatcher([
  # CONTIGUOUS replacement hack for openpilot
  (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
  # replace ALU sources with contiguous versions found above
  (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
  # add CONTIGUOUS to tagged UOps
  # (the tag moves onto the new CONTIGUOUS; a falsy tag, i.e. None or an empty tuple, is reset to None)
  (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
  # remove extra CONTIGUOUS on ASSIGN
  # (the ASSIGN keeps both tags concatenated)
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"), lambda a,c: a.replace(tag=a.tag+c.tag)),
  # replace ASSIGN with CONTIGUOUS
  # (only when the assign target is not a BUFFER, see replace_assign_with_contig)
  (UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
  # replace CONTIGUOUS with ASSIGNs
  (UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
  # remove DETACH/CONTIGUOUS_BACKWARD
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # handle size 0
  # (any non-SINK UOp that has a shape and size 0 becomes a zero const with the same tag)
  (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
  # early fixup const copy (TODO: is this wrong if there's a pad?)
  (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
])
def untag_and_append(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
  """Strip the tag from a tagged ASSIGN and register it.

  For every index in the tag, the UOp at that index in the list is mapped in the buffer map
  to the assign target, shrunk to that UOp's shape. The untagged ASSIGN is appended to the
  assigns list and returned. Returns None if `x` has no tag.
  """
  if x.tag is None: return None
  uop_list, buffer_map, assigns = ctx
  ret = x.replace(tag=None)
  # follow nested ASSIGNs down to what is ultimately assigned to; this is the same for every tag entry
  target = ret
  while target.op is Ops.ASSIGN: target = target.src[0]
  for idx in x.tag:
    original_uop: UOp = uop_list[idx]
    buffer_map[original_uop] = target.shrink_to(original_uop.shape)
  assigns.append(ret)
  return ret
def append_after(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
  """Append `x` unchanged to the list at ctx[2]; returns None (no rewrite)."""
  collected = ctx[2]
  collected.append(x)
# Rules for the "finalize call" pass run by allocate_global_buffers.
# ctx is (uop_list, buffer_map, assigns); the rules fill buffer_map and collect UOps into assigns.
pm_finalize_call = PatternMatcher([
  # tagged ASSIGNs are untagged, recorded in buffer_map and collected
  (UPat(Ops.ASSIGN, name="x"), untag_and_append),
  # AFTERs are collected as-is
  (UPat(Ops.AFTER, name="x"), append_after),
  # COPYs whose device is a DISK device are collected as-is
  (UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
  # drop the UNIQUE source from CONST, leaving only DEVICE, for CONST cache key normalization
  (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])
def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
  """Run the three allocation passes over `big_sink`.

  Returns a SINK of everything collected by the finalize pass, plus the buffer map
  (original UOp -> the UOp that replaces it) built up along the way.
  """
  # uop_list holds UOps of the original graph; tags are indices into it and are mapped back later
  uop_list: list[UOp] = []
  buffer_map: dict[UOp, UOp] = {}
  dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
  bases = {x.multibase for x in big_sink.src if x.base.op not in dont_realize}

  # pass 1, bottom up and "read-only": adds simple things to buffer_map and may sink things on big_sink.
  # this is the only pass where we have to be careful to not break the tensor graph
  big_sink = graph_rewrite(big_sink, add_tags, ctx=(uop_list, buffer_map, bases), bottom_up=True, name="number the uops")

  # pass 2: here the tensor graph can be broken. this is the only place numbered tags must be maintained
  big_sink = graph_rewrite(big_sink, pm_early_transform_tensor_graph, ctx={}, name="early transform tensor graph")

  # pass 3: construct the final buffer_map, which is everything that will go into the tensor map.
  # the rewritten graph is discarded, only buffer_map and assigns are kept
  assigns: list[UOp] = []
  graph_rewrite(big_sink, pm_finalize_call, ctx=(uop_list, buffer_map, assigns), name="finalize call")

  ret = UOp.sink(*assigns)
  # no-op rewrite with an empty matcher, only run when VIZ is set
  if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
  return ret, buffer_map

View file

@ -348,6 +348,8 @@ class TinyJit(Generic[ReturnType]):
update_depends(depends, jit_cache)
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
# sync before re-executing onetime kernels
for dev in set(Device[b.device] for ei in onetime for b in ei.bufs if b is not None): dev.synchronize()
# run the onetime kernels here
for ei in onetime:
for b in ei.bufs: cast(Buffer, b).ensure_allocated()

View file

@ -1,17 +1,15 @@
import time
import time, sys
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.allocations import allocate_global_buffers
# **** schedule linearizer
# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges]
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]]
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
def _unwrap_src(s: UOp) -> UOp:
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
@ -23,9 +21,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort(gate_kernel_sink):
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
if u.op is not Ops.AFTER: continue
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
k = u.src[1]
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
@ -50,55 +47,24 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps
sched_item: dict[UOp, ScheduleItem] = {}
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.CALL:
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
schedule.append(k)
if rk.op is Ops.END: schedule.append(rk)
rk = queue.popleft()
k = rk.src[0] if rk.op is Ops.END else rk
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
buf_uops_list.append(UOp.sink(*buf_uops))
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
with cpu_profile(TracingKey("unroll outer ranges")):
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item)
return pre_schedule, UOp.sink(*buf_uops_list)
def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]:
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
while sched_ptr < len(schedule):
si = schedule[sched_ptr]
if si.op is Ops.RANGE:
in_ranges[si] = 0
range_ptrs[si] = sched_ptr + 1
elif si.op is Ops.END:
if in_ranges[si.src[1]] < si.src[1].vmax:
in_ranges[si.src[1]] += 1
sched_ptr = range_ptrs[si.src[1]]
continue
else:
assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
ast, buf_uops, metadata, bound_ranges = sched_item[si]
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
buf_uops_list.append(UOp.sink(*buf_uops))
sched_ptr += 1
return pre_schedule, buf_uops_list
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
from tinygrad.schedule.rangeify import get_rangeify
from tinygrad.schedule.multi import multi_pm
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
if (ret:=ctx[0].get(b, None)) is None:
@ -107,13 +73,6 @@ def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], li
ctx[2][0] += 1
return ret
def replace_input_const(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
if (ret:=ctx[0].get(b, None)) is None:
# replace UNIQUE with LUNIQUE for CONST cache key normalization
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=ctx[3][0]), b.src[1]))
ctx[3][0] += 1
return ret
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
var, val = b.src[0], b.src[1].arg
assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
@ -123,8 +82,6 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]],
pm_pre_sched_cache = PatternMatcher([
# replace BUFFER with PARAM for cache key normalization
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# replace UNIQUE with LUNIQUE for CONST cache key normalization
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_const),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
])
@ -136,8 +93,6 @@ def create_new_buffer(ctx:dict[UOp, UOp], b:UOp):
pm_post_sched_cache = PatternMatcher([
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
# restore CONST back to original CONST
(UPat(Ops.CONST, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
# restore PARAM back to original BUFFER
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
# restore BIND value stripped in pm_pre_sched_cache
@ -150,6 +105,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# big_sink srcs are all the Tensors
st = time.perf_counter()
big_sink, buffer_map = allocate_global_buffers(big_sink)
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
input_buffers: dict[UOp, UOp] = {}
var_vals: dict[str, int] = {}
@ -159,39 +116,17 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
# hack to preserve metadata
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}, [0], [0]), name="preserve metadata")
# tensor map is what we return
tensor_map: dict[UOp, UOp] = {}
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
tensor_map |= get_multi_map(big_sink_cache)
big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
tensor_map |= get_rangeify_map(big_sink_cache)
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
pre_schedule, buf_uops_sink = create_schedule(big_sink)
# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
else:
# schedule cache hit
del big_sink_cache
pre_schedule, combined_sink = sc_ret
pre_schedule, buf_uops_sink = sc_ret
del big_sink, big_sink_cache
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
tensor_map_sink, buf_uops_sink = combined.src
tm_src = tensor_map_sink.src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
# add bufs to pre_schedule
schedule: list[ExecItem] = []
@ -205,7 +140,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
ubufs = tuple(b.buffer for b in buf_uops)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
else:
@ -214,9 +149,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
i = 6
while (frm:=sys._getframe(i)) and frm.f_code.co_filename.startswith(str(BASEDIR)): i += 1
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
f" | {len(UOpMetaClass.ucache):7d} uops in cache | {frm.f_code.co_filename}:{frm.f_lineno}")
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}

View file

@ -39,7 +39,6 @@ pm_gradient = PatternMatcher([
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
(UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
(UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),

View file

@ -13,6 +13,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
CI = os.getenv("CI", "") != ""
ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))
BASEDIR = pathlib.Path(__file__).parent
# fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
if WIN: os.system("")

View file

@ -8,7 +8,7 @@ class Optimizer:
"""
Base class for all optimizers.
"""
def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
# if requires_grad is None, but being put into an optimizer, set it to True
for x in params:
if x.requires_grad is None: x.requires_grad_(True)
@ -16,19 +16,19 @@ class Optimizer:
self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
assert len(self.params) != 0, "optimizer must have at least one param"
self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
self.device = device or self.params[0].device
self.fused = fused
# store lr in at least float32 precision
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
@property
def device(self): return self.params[0].device
def _new_optim_param(self) -> list[Tensor]:
param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False)]
if self.device is not None:
return [Tensor.zeros(t.shape, dtype=param_dtype, device=self.device, requires_grad=False) for t in self.params]
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False) for t in self.params]
def zero_grad(self):
"""
@ -54,13 +54,14 @@ class Optimizer:
# NOTE: contiguous is for speed
out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
[Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
updates = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
else:
updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
for i, tt in enumerate(self.params): tt.assign(updated_params[i])
updates, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i]))
return extra+self.params+self.buffers
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
# Combine a parameter `t` with its computed update `up`: the update is moved to the parameter's device
# and subtracted from the detached parameter. The result is what the step code assigns back into the param.
def _apply_update(self, t:Tensor, up:Tensor) -> Tensor: return t.detach() - up.to(t.device)
class OptimizerGroup(Optimizer):
"""
@ -74,17 +75,17 @@ class OptimizerGroup(Optimizer):
def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, device=None, fused=FUSE_OPTIM):
"""
Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
`classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
"""
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, device=device, fused=fused)
# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_steps=5, ns_coefficients=(3.4445, -4.775, 2.0315),
nesterov=True, fused=FUSE_OPTIM):
nesterov=True, device=None, fused=FUSE_OPTIM):
"""
SGD with newton-schulz iteration and post momentum weight decay.
@ -92,7 +93,8 @@ def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_ste
- Paper: https://arxiv.org/pdf/2502.16982
"""
assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov,
classic=False, pre_wd=False, tcoef=0.0, device=None, fused=fused)
class LARS(Optimizer):
"""
@ -101,8 +103,8 @@ class LARS(Optimizer):
- Paper: https://arxiv.org/abs/1708.03888v3
"""
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_coefficients=None,
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
super().__init__(params, lr, fused)
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, device=None, fused=FUSE_OPTIM):
super().__init__(params, lr, device, fused)
self.momentum, self.wd, self.ns_steps, self.ns_coefficients = momentum, weight_decay, ns_steps, ns_coefficients
self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
self.b = self._new_optim_param() if self.momentum else []
@ -126,24 +128,24 @@ class LARS(Optimizer):
if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
# popular momentum does pre learning rate update
if not self.classic: g = g * r * self.lr
ret.append((t.detach() - g).cast(t.dtype))
ret.append(g.cast(t.dtype))
return ret, self.b
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, device=None, fused=FUSE_OPTIM):
"""
AdamW optimizer with optional weight decay.
- Paper: https://arxiv.org/abs/1711.05101v3
"""
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, device=device, fused=fused)
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, device=None, fused=FUSE_OPTIM):
"""
Adam optimizer.
- Paper: https://arxiv.org/abs/1412.6980
"""
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused)
class LAMB(Optimizer):
"""
@ -151,10 +153,10 @@ class LAMB(Optimizer):
- Paper: https://arxiv.org/abs/1904.00962
"""
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
super().__init__(params, lr, fused)
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
super().__init__(params, lr, device, fused)
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2])
self.m = self._new_optim_param()
self.v = self._new_optim_param()
@ -175,5 +177,5 @@ class LAMB(Optimizer):
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
ret.append((self.lr * r * up).cast(t.dtype))
return ret, [self.b1_t, self.b2_t] + self.m + self.v

View file

@ -339,8 +339,8 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
codes = q_to_uint8(blocks[:, 1:17], 4)
sign = 1.0 - codes.rshift(3).cast(dtypes.float32) * 2.0
exp, mant = codes.rshift(1).bitwise_and(0x3).cast(dtypes.float32), codes.bitwise_and(0x1).cast(dtypes.float32)
fp4_val = sign * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
(exp == 0).cast(dtypes.float32) * 0.5 * mant)
fp4_val = sign * 2.0 * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
(exp == 0).cast(dtypes.float32) * 0.5 * mant)
return (fp4_val * d).flatten(-2)[:n]
raise ValueError(f"GGML type '{ggml_type}' is not supported!")

View file

@ -80,19 +80,16 @@ class ProgramSpec:
ins:list[int]=field(default_factory=list)
@property
def estimates(self) -> Estimates: return self.ast.arg.estimates
def estimates(self) -> Estimates: return self.ast.arg.estimates if self.ast.arg is not None and self.ast.arg.estimates is not None else Estimates()
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
@functools.cached_property
def runtimevars(self) -> dict[str, int]: return {v.arg[0]: i for i, v in enumerate(self.vars) if v.arg[0] == 'core_id'}
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
@property
def applied_opts(self) -> tuple[Opt, ...]|None:
if self.uops is None: return None
assert self.uops[-1].op is Ops.SINK, self.uops[-1].op
return self.uops[-1].arg.applied_opts
def applied_opts(self) -> tuple[Opt, ...]|None: return self.ast.arg.applied_opts if self.ast.arg is not None else None
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size]

View file

@ -9,8 +9,7 @@ from dataclasses import dataclass
from typing import Iterator
from enum import Enum
from tinygrad.renderer.amd.dsl import BitField, FixedBitField, Inst, bits
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, s_endpgm
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm # same encoding as RDNA4
# ═══════════════════════════════════════════════════════════════════════════════
# FIELD ENUMS
@ -102,16 +101,17 @@ class InstOpRDNA4(Enum):
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
# TODO: we need to do discovery of all of these from instructions
SALU = 0x0
SMEM = 0x1
UNK_02 = 0x2
JUMP_NO = 0x4
UNK_06 = 0x6
JUMP = 0x1
NEXT = 0x2
MESSAGE = 0x4
VALU_64 = 0x6
VALU_WMMA = 0x46
VMEM = 0x10
UNK_11 = 0x11
VINTERP = 0x12
UNK_14 = 0x14
VMEM_128 = 0x11
VMEM_STORE = 0x12
VMEM_STORE_128 = 0x14
OTHER_VMEM = 0x5e
UNK_60 = 0x60
OTHER_VMEM_STORE = 0x60
# ═══════════════════════════════════════════════════════════════════════════════
# PACKET TYPE BASE CLASS
@ -343,8 +343,12 @@ class INST_RDNA4(PacketType): # Layout 4: different delta position and InstOp e
delta = bits[5:3]
flag1 = bits[6:6]
flag2 = bits[7:7]
wave = bits[12:8]
wave_pair = bits[11:8]
flag3 = bits[12:12]
op = bits[19:13].enum(InstOpRDNA4)
# INST_RDNA4 wave_pair field (4 bits) addresses wave pairs, flag2 selects even/odd wave
@property
def wave(self): return self.wave_pair * 2 + self.flag2
class UTILCTR(PacketType):
encoding = bits[6:0] == 0b0110001
@ -586,7 +590,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
for p in decode(data):
if not simd_select(p): continue
if isinstance(p, WAVESTART):
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
assert p.wave not in wave_pc, "only one inflight wave per unit"
wave_pc[p.wave] = next(iter(pc_map))
continue
@ -595,33 +599,35 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
yield (p, InstructionInfo(pc, p.wave, s_endpgm()))
continue
# skip OTHER_ instructions, they don't belong to this unit
if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue
if isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): continue
if isinstance(p, IMMEDIATE_MASK):
# immediate mask may yield multiple times per packet
for wave in range(16):
if p.mask & (1 << wave):
inst = pc_map[pc:=wave_pc[wave]]
# can this assert be more strict?
assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
wave_pc[wave] += inst.size()
yield (p, InstructionInfo(pc, wave, inst))
continue
if isinstance(p, (VALUINST, INST, IMMEDIATE)):
if isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
inst = pc_map[pc:=wave_pc[p.wave]]
# s_delay_alu doesn't get a packet?
if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}:
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
wave_pc[p.wave] += inst.size()
inst = pc_map[pc:=wave_pc[p.wave]]
# identify a branch instruction, only used for asserts
is_branch = isinstance(inst, SOPP) and "BRANCH" in inst.op_name
if is_branch: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by jump packets, got {p}"
branch_inst = inst if "BRANCH" in inst_op else None
if branch_inst is not None:
assert isinstance(p, (INST, INST_RDNA4)) and p.op.name in {"JUMP_NO", "JUMP", "NEXT"}, f"branch can only be folowed by JUMP, got {p}"
# JUMP handling
if isinstance(p, INST) and p.op is InstOp.JUMP:
assert is_branch, f"JUMP packet must map to a branch instruction, got {inst}"
x = inst.simm16 & 0xffff
wave_pc[p.wave] += inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
if (isinstance(p, INST) and p.op is InstOp.JUMP) or (isinstance(p, INST_RDNA4) and branch_inst is not None and p.flag3):
simm16 = getattr(branch_inst, 'simm16')
assert branch_inst is not None and simm16 is not None, f"JUMP packet must map to a branch instruction, got {inst}"
x = simm16 & 0xffff
wave_pc[p.wave] += branch_inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
else:
if is_branch: assert inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}"
if branch_inst is not None: assert inst_op != "S_BRANCH", f"S_BRANCH must have a JUMP packet, got {p}"
wave_pc[p.wave] += inst.size()
yield (p, InstructionInfo(pc, p.wave, inst))
continue

View file

@ -136,7 +136,7 @@ class LLVMRenderer(Renderer):
supports_float4 = True
abi: str | None
string_rewrite: PatternMatcher
code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
code_for_op = {k:lambda:None for v in lop.values() for k in v.keys()}
if AMX: tensor_cores = tc.amx
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
@ -169,7 +169,7 @@ class LLVMRenderer(Renderer):
if u.arg is not None: name = u.arg.function_name
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.arg[0]}"
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
args.append((r[u], u.dtype))
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"

View file

@ -28,7 +28,8 @@ asm_for_op: dict[Ops, Callable] = {
Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
Ops.CMPNE: lambda d,a,b,dt,name: f"setp.{'neu' if dtypes.is_float(dt) else 'ne'}.{name} {d}, {a}, {b};",
Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
@ -134,7 +135,7 @@ string_rewrite = PatternMatcher([
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.expr}+0];"),
])
class PTXRenderer(Renderer):
@ -220,7 +221,7 @@ class PTXRenderer(Renderer):
continue
if u.op is Ops.INDEX: continue # other index we can skip
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)

View file

@ -61,8 +61,9 @@ class CLProgram:
if isinstance(dt, ImageDType):
fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], image_row_pitch=dt.pitch, buffer=b)
b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
img = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(img), ctypes.byref(img)))
else: check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
for i,v in enumerate(vals,start=i+1): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
event = cl.cl_event() if wait else None

View file

@ -315,6 +315,9 @@ class AMDev(PCIDevImplBase):
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
self.reserved_vram_size = (384 << 20) if self.ip_ver[am.GC_HWIP][:2] in {(9,4), (9,5)} else (64 << 20)
# Reverse lookup from hardware-ID value to its name, computed once per instance (cached_property).
# Scans the `am` module for int attributes whose name ends in `_HWID` and strips that suffix from the name.
# If two constants share a value, the one encountered last in vars(am) wins.
@functools.cached_property
def hwid_names(self) -> dict[int, str]: return {v:k.removesuffix('_HWID') for k,v in vars(am).items() if k.endswith('_HWID') and isinstance(v, int)}
def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): return import_module(prefix, self.ip_ver[hwip], prever_prefix)
def _build_regs(self):

View file

@ -217,6 +217,16 @@ class AM_SMU(AM_IP):
with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), timeout=20)
if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
# Read one register of an ACA/MCA bank through the SMU and return it as a single integer.
# `ue` selects PPSMC_MSG_McaBankDumpDW, otherwise PPSMC_MSG_McaBankCeDumpDW is used
# (presumably uncorrectable vs. correctable error banks -- confirm against the SMU interface).
def _aca_read_reg(self, bank_idx:int, reg_idx:int, ue=True) -> int:
msg = self.smu_mod.PPSMC_MSG_McaBankDumpDW if ue else self.smu_mod.PPSMC_MSG_McaBankCeDumpDW
# two SMU reads: the value read at offset reg_idx*8+4 is shifted into bits 32+ and OR-ed with the value read at reg_idx*8
# the bank index is placed at bit 16 and up of the message argument
return (self._send_msg(msg, (bank_idx << 16) | (reg_idx * 8 + 4), read_back_arg=True) << 32) | \
self._send_msg(msg, (bank_idx << 16) | (reg_idx * 8), read_back_arg=True)
# Dump every bank the SMU reports as valid; each bank is returned as a list of 16 register values (reg_idx 0..15).
# `ue` picks which count query is sent and is forwarded to _aca_read_reg for the per-register reads.
def _aca_read_banks(self, ue=True) -> list[list[int]]:
# SMU modules without the MCA count message yield no banks instead of raising
if not hasattr(self.smu_mod, 'PPSMC_MSG_QueryValidMcaCount'): return []
count_msg = self.smu_mod.PPSMC_MSG_QueryValidMcaCount if ue else self.smu_mod.PPSMC_MSG_QueryValidMcaCeCount
# the read-back argument of the count query is used as the number of banks to iterate
return [[self._aca_read_reg(idx, reg_idx, ue=ue) for reg_idx in range(16)] for idx in range(self._send_msg(count_msg, 0, read_back_arg=True))]
def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):
(self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).write(0) # resp reg
(self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).write(param)
@ -462,6 +472,20 @@ class AM_IH(AM_IP):
self.adev.regIH_RB_RPTR.write(wptr['offset'] % (self.ring_size // 4))
bif_intr = self.adev.regBIF_BX0_BIF_DOORBELL_INT_CNTL.read_bitfields()
athub_err, cntlr_err = bif_intr['ras_athub_err_event_interrupt_status'], bif_intr['ras_cntlr_interrupt_status']
if athub_err or cntlr_err:
print(f"am {self.adev.devfmt}: fatal hardware error detected: {'RAS_ATHUB_ERR_EVENT ' if athub_err else ''}{'RAS_CNTLR' if cntlr_err else ''}")
acas = self.adev.smu._aca_read_banks(ue=True) + self.adev.smu._aca_read_banks(ue=False)
for regs in acas:
acatyp = 'Uncorrectable' if (regs[1] >> 61) & 1 and (regs[1] >> 57) & 1 else 'Correctable'
hwname = f'{self.adev.hwid_names.get((regs[5] >> 32) & 0xFFF, "")} ({(regs[5] >> 32) & 0xFFF:#03x})'
print(f"am {self.adev.devfmt}: {acatyp} ACA: {hwname} mcatype={(regs[5] >> 48) & 0xFFFF:#06x} regs=[{', '.join(f'{r:#x}' for r in regs)}]")
self.adev.regBIF_BX0_BIF_DOORBELL_INT_CNTL.write(ras_cntlr_interrupt_clear=cntlr_err, ras_athub_err_event_interrupt_clear=athub_err)
self.adev.is_err_state = True
class AM_SDMA(AM_IP):
def init_sw(self): self.sdma_reginst, self.sdma_name = [], "F32" if self.adev.ip_ver[am.SDMA0_HWIP] < (7,0,0) else "MCU"
def init_hw(self):

View file

@ -81,7 +81,7 @@ class TLSFAllocator:
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
for l1 in range(self.lv1(size), len(self.storage)):
if self.lv1_entries[l1] == 0: continue
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
@ -105,7 +105,7 @@ class TLSFAllocator:
def free(self, start:int):
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
# Memory Managment
# Memory Management
class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702
@ -221,7 +221,7 @@ class MemoryManager:
@classmethod
def alloc_vaddr(cls, size:int, align=0x1000) -> int:
assert cls.va_allocator is not None, "must be set it"
assert cls.va_allocator is not None, "must be set"
return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
@ -248,7 +248,7 @@ class MemoryManager:
return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached)
def vfree(self, vm:VirtMapping):
assert self.va_allocator is not None, "must be set it"
assert self.va_allocator is not None, "must be set"
self.unmap_range(vm.va_addr, vm.size)
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

View file

@ -77,7 +77,7 @@ class NVDev(PCIDevImplBase):
self._early_ip_init()
self._early_mmu_init()
# Turn the booting early, gsp client is loaded from the clean.
# No booting state, gsp client is reinited every run.
self.is_booting = False
for ip in [self.flcn, self.gsp]: ip.init_sw()

View file

@ -7,7 +7,8 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuff
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000
MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
@dataclasses.dataclass(frozen=True)
class PCIBarInfo: addr:int; size:int # noqa: E702
@ -69,7 +70,7 @@ class _System:
all_devs.append((int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16),
int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16), pcibus))
return sorted([val for vendor, device, val in all_devs if vendor == vendor and any((device & mask) in devlist for mask, devlist in devices)])
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, PCIBarInfo]:
for bus in range(gpu_bus):
@ -219,7 +220,7 @@ class LNXPCIIfaceBase:
cls.gpus = hcq_filter_visible_devices(System.pci_scan_bus(vendor, devices, base_class))
# Acquire va range to avoid collisions.
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED_NOREPLACE, 0)
self.pci_dev, self.dev, self.vram_bar = PCIDevice(dev.__class__.__name__[:2], cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
self.p2p_base_addr = self.pci_dev.bar_info[vram_bar].addr

View file

@ -3,7 +3,7 @@ import functools, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
@ -17,15 +17,19 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
if buf.base in x.backward_slice_with_self: ctx[x] = None
pm_generate_realize_map = PatternMatcher([
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
# always realize REDUCE on outer ranges
(UPat(Ops.REDUCE, name="r"), lambda ctx,r: realize(ctx, r) if any(tr.arg[-1] == AxisType.OUTER for tr in r.src[1:]) else None),
# realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ASSIGN, Ops.ENCDEC), name="rb"), realize_srcs),
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
# sometimes realize src of assign
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
])
@dataclass(frozen=True)
@ -68,7 +72,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)
# NOTE: do we need this?
@ -84,7 +88,7 @@ def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
# input ranges
new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0])
ctx.range_map[ret] = ctx.range_map[x]
return ret
@ -161,6 +165,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
for u in tsink.toposort():
if u.op is Ops.ASSIGN and u.src[1].op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and u.src[1] in rctx.realize_map \
and not u.src[0].op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
del rctx.realize_map[u.src[1]]
# get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"):

View file

@ -1,85 +1,56 @@
from typing import cast
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.dtype import dtypes
from tinygrad.device import Device
# *** allreduce implementation ***
def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
# Group buffers
groups: dict[int|None, list[UOp]] = {}
for i,dev in enumerate(buf.device):
groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
# Put reduce leader of each group first
reduce_leaders = set(getenv("REDUCE_LEADERS", "").split(","))
groups = {gid: sorted(bufs, key=lambda x: (x.device not in reduce_leaders, x.device)) for gid,bufs in groups.items()}
# Skip if only one group or if every group has only one buffer
if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
# Reduce inside each group
inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
# Allreduce across groups
outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
# Broadcast back to all devices in the group
gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
# contiguous before we copy it
buf = buf.contiguous()
# naive: copy to all devices. if you shrink later, that'll be handled
if not use_ring and not use_all2all:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
# chunk data into n_lbs pieces
# chunk data into ndev pieces
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
base, left = divmod(numel // factor, ndev)
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
# reduce-scatter
reduced_chunks = []
reduced_chunks:list[UOp] = []
for i,(s,e) in enumerate(chunks):
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
else:
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
for step in range(n_lbs-1):
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
for step in range(ndev-1):
src, dest = (i+step)%ndev, (i+step+1)%ndev
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced)
# allgather
copied_chunks = []
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
this_chunk: list[UOp|None] = [None] * n_lbs
this_chunk[(i+n_lbs-1)%n_lbs] = rc
for step in range(n_lbs-1):
this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
chain:list[UOp] = [rc]
for step in range(ndev-1):
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
@ -90,7 +61,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
ret:list[UOp] = []
def apply_shrink(s:UOp, i:int) -> UOp:
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.marg]
(dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg]
return s.shrink(tuple(new_arg))
for i, x in enumerate(ms.src):
if x.op is Ops.COPY:
@ -100,7 +71,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
return ms.replace(src=tuple(ret))
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
@ -125,20 +95,20 @@ def alu_multi(root:UOp):
axis = root.axis
assert axis is not None
srcs = []
srcs:list[UOp] = []
for mlb in msrcs:
if mlb.axis == axis:
# same axis, just copy through
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0])
elif mlb.axis is None:
if mlb.axis is None:
# no axis, shard it
assert mlb.op is not Ops.MULTI
srcs.append(mlb._shard(axis))
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
if mlb.axis == axis:
# same axis, just copy through
srcs.append(mlb.src[0])
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
def reduce_multi(root:UOp, multi:UOp):
@ -149,21 +119,15 @@ def reduce_multi(root:UOp, multi:UOp):
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return multi.src[0].r(op, axis).multi(axis=multi.axis)
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
def reshape_multi(root:UOp, multi:UOp):
arg = root.marg
if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {arg=}"
new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
if prod(multi.shape) != prod(new_shape:=root.marg): raise RuntimeError("reshape must maintain prod(shape)")
if (new_axis:=root.axis) is not None: new_shape = tuple(s//len(multi.device) if a==new_axis else s for a,s in enumerate(new_shape))
return multi.src[0].reshape(new_shape).multi(new_axis)
def expand_multi(root:UOp, multi:UOp):
# NOTE: this assert isn't needed, sharded axis can have dim 1
assert multi.axis is None or root.marg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.marg=}"
return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.marg, multi.src[0])).multi(multi.axis)
if multi.axis is None: new_shape = root.marg
else: new_shape = tuple(multi.src[0].shape[multi.axis] if a == multi.axis else s for a,s in enumerate(root.marg))
return multi.src[0].expand(new_shape).multi(multi.axis)
def pad_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.marg[multi.axis] == (0,0), f"padding not supported for {root.marg=}"
@ -177,11 +141,10 @@ def shrink_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]) or root.marg[multi.axis] in multi.bounds, \
f"shrinking not supported for {root.marg=}"
if multi.axis is not None and root.marg[multi.axis] in multi.bounds and root.marg[multi.axis] != (0, multi.shape[multi.axis]):
assert all(root.marg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
"cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
# we just copy it to all the devices, no real. this will be optimized out later
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.marg[multi.axis]))
non_shard_shrink = tuple((0, multi.src[0].shape[i]) if i == multi.axis else s for i, s in enumerate(root.marg))
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.marg[multi.axis])).shrink(non_shard_shrink)
return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))).multi(multi.axis)
def flip_multi(root:UOp, multi:UOp):
@ -224,9 +187,3 @@ multi_pm = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
])+replace_allreduce
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
return ret

View file

@ -1,10 +1,10 @@
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
@ -26,27 +26,13 @@ pm_mops = PatternMatcher([
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),
# move movement ops after AFTER
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:], tag=None),)+r.src[1:], r.arg, tag=a.tag)),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
if (t := target.base).op is Ops.PARAM or (t.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in t.src)): return None
# partial view of unrealized graph: insert CONTIGUOUS at base to realize it
if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK):
if t.op is Ops.CONTIGUOUS: return None
mops: list[UOp] = []
while target.op in GroupOp.Movement:
mops.append(target)
target = target.src[0]
new_target = t.f(Ops.CONTIGUOUS, tag=t.tag)
for m in reversed(mops): new_target = m.replace(src=(new_target,)+m.src[1:])
return assign.replace(src=(new_target, src))
return src.f(Ops.CONTIGUOUS, tag=assign.tag)
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
@ -83,19 +69,21 @@ def split_reduceop(reduce:UOp, x:UOp):
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
# reduce original axes, then split
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape).replace(tag=reduce.tag)
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
mop_cleanup = PatternMatcher([
# merge adjacent RESHAPES, safe because they are not tagged
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"),
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
# merge adjacent RESHAPES
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
params = sorted(params, key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
@ -103,49 +91,31 @@ def resolve_call(c:UOp) -> UOp|None:
for i, (p, a) in enumerate(zip(params, args)):
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict(zip(params, args))).rtag(c.tag)
return c.src[0].substitute(dict(zip(params, args)))
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# remove CONTIGUOUS if the source is already contiguous
(UPat(Ops.RESHAPE, src=(UPat((Ops.PARAM, Ops.CONTIGUOUS)), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# preserve tags?
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# handle size 0
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
# remove contiguous on movement ops before a copy on disk
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# push copy past movement ops to disk
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.COPY, allow_any_len=True, name="copy"),
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:], tag=None),)+x.src[1:], tag=copy.tag) \
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:]),)+x.src[1:]) \
if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# ** copy rules **
# early fixup const copy
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
# COPY and source size need to match
# TODO: expand after copy creates issues with tagging
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
# copy only to different device
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
# ** assign rules **
@ -153,23 +123,20 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
])
# *****************
# 3.5 cleanups
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC}
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
@ -190,8 +157,7 @@ def cleanup_dead_axes(b:UOp):
reshape.append(s)
new_rng.append(rng)
if hit:
# move the tag to the expand. NOTE: this expand tag might not survive
return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag)
return b.replace(src=b.src[0:1]+tuple(new_rng)).reshape(tuple(reshape)).expand(b.shape)
def gate_substitute(ctx, b:UOp) -> None:
if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()
@ -264,8 +230,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
def remove_noop_bufferize(idx,b2):
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None
return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag)
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
@ -275,13 +240,13 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
# remove noop buffers. if we look at the next index we can remove even more of these
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
# hack if a noop turned to a const
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c.rtag(noop.tag)),
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
# mstack on CONST is CONST
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
@ -307,7 +272,7 @@ def late_buffer_view(t:UOp, b:UOp):
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset)), b.src[1]))
to_bufferview = PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
@ -346,7 +311,6 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
# NOTE: this has been fixed up a bit
def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
#assert isinstance(x.tag, Flat), "bufferize must be flat"
size = prod(x.shape)
rngs = sorted(idx.ranges, key=lambda x: x.arg)
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
@ -359,26 +323,14 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
# skip self-assign from same-device copy, otherwise create the store
# in assign, this is the buffer size, not the bufferize size
if assign_src is assign_target: ret = assign_target.src[0]
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*rngs))
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
return ret
# lower outerworld reduce here
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
assert sdtype.addrspace == AddrSpace.GLOBAL
outer_range = x.src[0].src[1]
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
# NOTE: this has the same number as the outer range, we need string ranges!
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
bufi = buf.index(idx, dtype=sdtype)
do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
return buf.after(do_store)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store)
if allow_locals:
@ -387,16 +339,16 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store.barrier())
# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape
# collapse any BUFFERIZE to single input BUFFERIZE
def flatten_bufferize(x:UOp):
if x.tag is None and len(x.src) == 2: return None
ret = x.replace(tag=None, src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
if len(x.src) == 2: return None
ret = x.replace(src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
rngs = x.src[1:]
ret = ret.forced_reshape(x.shape)
ret = ret.reshape(x.shape)
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
return ret.rtag(x.tag)
return ret
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@ -404,7 +356,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
# move RESHAPEs through MSELECT/MSTACK
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src])).reshape(m.shape)),
# remove any RESHAPEs on KERNEL
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
@ -423,7 +375,6 @@ class LocalAddBufferContext:
map:dict = field(default_factory=dict)
vars:dict = field(default_factory=dict)
range:int = 0
parent_tags:list = field(default_factory=list)
opts:tuple|None = None
def debuf(ctx:LocalAddBufferContext, buf:UOp):
@ -447,9 +398,6 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.tag != (): return None
if r.arg[-1] == AxisType.OUTER:
# for outer range, we replace with a bound variable
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
ctx.range += 1
return ret
@ -488,12 +436,6 @@ rangeify_codegen = PatternMatcher([
# TODO: this can be moved into codegen?
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
# .f(Ops.INDEX, name="idx", allow_any_len=True),
# lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# fix broadcast dtype
(UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
@ -505,32 +447,17 @@ rangeify_codegen = PatternMatcher([
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
])
def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
if x.tag is None or x.tag == (): return None
if isinstance(x.tag, tuple): ctx.parent_tags += list(x.tag)
return x.replace(tag=None)
pm_remove_tags = PatternMatcher([
(UPat(GroupOp.All, name="x"), remove_metadata_tags),
])
pm_add_range_tags = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
# if we have any non-outer ranges open here, we don't split
if any(r.arg[-1] != AxisType.OUTER for r in x.ranges): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
def split_store(x:UOp) -> UOp|None:
# if we have any open ranges here, we don't split
if x.ranges: return None
# local kernel rewrite
lctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
# gather the metadata
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
if ret.op is Ops.STORE: stored = ret.src[1]
@ -539,8 +466,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel
@ -549,42 +475,9 @@ split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
if x.tag is not None or x in ctx[1]: return None
if x.tag is None and x.op is Ops.CALL:
# don't tag anything in a CALL
for u in x.src[0].toposort(): ctx[1].add(u)
if x.dtype.scalar() == dtypes.index: return None
ctx[0].append(x)
return x.replace(tag=(len(ctx[0])-1,))
add_tags = pm_gate_kernel_sink+PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.PARAM, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.PARAM for s in x.src) else tag_uop(ctx, x)),
])
# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
x = src
while x is not src.base:
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
else: return None
x = x.src[0]
ctx[src.base] = contig
replace_contiguous = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
def get_rangeify(sink:UOp) -> UOp:
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites")
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
@ -592,19 +485,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
# if it's not tagged by here, it's out
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.PARAM, Ops.AFTER} and \
x.tag is not None and len(x.tag)])
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify")
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
name="bufferize to store")
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
@ -619,17 +505,5 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
# TODO: we can probably get this earlier
sink_tags = [s.tag for s in tsink.src]
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
becomes_map: dict[UOp, UOp] = {}
for tag, s in zip(sink_tags, tsink.src):
assert tag is not None
for a in tag:
if a is None: continue
becomes_map[uop_list[int(a)]] = s
return becomes_map
return tsink

View file

@ -336,7 +336,7 @@ class Tensor(OpMixin):
raise JitError("cannot access tensor data during JIT capture, the value will be baked in")
x = self.cast(self.dtype.base).contiguous()
if isinstance(self.device, tuple): x = x.to("CPU")
return cast(Buffer, x.realize().uop.base.buffer).ensure_allocated()
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview:
@ -404,7 +404,7 @@ class Tensor(OpMixin):
"""
Creates a clone of this tensor allocating a separate buffer for the data.
"""
ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
ret = self.empty_like()
if self.grad is not None: ret.grad = self.grad.clone()
return ret.assign(self)
@ -537,12 +537,15 @@ class Tensor(OpMixin):
device = canonicalize_device(device)
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
def empty_like(self, **kwargs) -> Tensor:
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return Tensor.empty(self.shape, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
@ -628,7 +631,7 @@ class Tensor(OpMixin):
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
# increment rng counter for devices
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
@ -1075,6 +1078,34 @@ class Tensor(OpMixin):
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _pad_constant(self, pX:tuple[tuple[sint, sint], ...], value:float) -> Tensor:
  """
  Constant-mode padding with per-dimension `(before, after)` pads in `pX`.

  Negative pads are applied as a shrink first, so the pad op itself only ever receives non-negative amounts.
  A non-zero `value` is produced by adding a padded ones-mask mapped through `.where(0, value)`.
  """
  if all(resolve(p >= 0) for p in flatten(pX)):
    # nothing to crop: pad the tensor as-is with the pads as given
    src, pads = self, pX
  else:
    # crop by the negative amounts first, then clamp every pad at zero
    src = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape)))
    pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  padded = src._apply_uop(UOp.pad, arg=pads)
  if value == 0: return padded
  # the ones tensor is padded with the same pads, then mapped through where(0, value) and added on top
  return padded + Tensor.ones_like(src)._apply_uop(UOp.pad, arg=pads).where(0, value)
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Tensor:
  """
  Circular (wrap-around) padding with per-dimension `(before, after)` pads in `pX`.

  Raises ValueError if any pad is larger than its dimension, and NotImplementedError for negative pads.
  """
  for (pB,pA),sh in zip(pX, self.shape):
    if pB>sh or pA>sh: raise ValueError('Padding value causes wrapping around more than once.')
  for pB,pA in pX:
    if pB<0 or pA<0: raise NotImplementedError("Negative pads with circular pads is not supported")
  orig_shape = self.shape
  # repeat each dimension once per padded side, plus one for the original data
  tiled = self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
  bounds = []
  for (pB,pA),osh,xsh in zip(pX, orig_shape, tiled.shape):
    start = 0 if pB == 0 else osh-pB
    stop = xsh if pA == 0 else xsh-osh+pA
    bounds.append((start, stop))
  # cut the window out of the repeated tensor
  return tiled.shrink(tuple(bounds))
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Tensor:
  """
  Reflect / replicate padding with per-dimension `(before, after)` pads in `pX`.

  Dimensions are processed in order, concatenating the generated border pieces around the running tensor.
  Any `mode` other than "reflect" takes the replicate branch. Negative pads are clamped to zero while
  building and applied as a final shrink. Raises ValueError in reflect mode when a pad is not smaller
  than the dimension it applies to.
  """
  nonneg = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  out = self
  for d,(pB,pA) in enumerate(nonneg):
    if mode == "reflect":
      s = out.shape[d]
      if pB >= s or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
      # reversed slices along dim d that skip the edge element itself
      slcB = slice(pB,0,-1)
      slcA = slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
      before = out[[slcB if i == d else slice(None) for i in range(out.ndim)]] if pB > 0 else None
      after = out[[slcA if i == d else slice(None) for i in range(out.ndim)]] if pA > 0 else None
    else:
      # take the first / last element along dim d and expand it to the pad width
      shrB = tuple((0,1) if i==d else None for i in range(out.ndim))
      shrA = tuple((out.shape[i]-1,out.shape[i]) if i==d else None for i in range(out.ndim))
      before = out.shrink(shrB).expand(tuple(pB if i==d else None for i in range(out.ndim))) if pB > 0 else None
      after = out.shrink(shrA).expand(tuple(pA if i==d else None for i in range(out.ndim))) if pA > 0 else None
    out = Tensor.cat(*(piece for piece in (before, out, after) if piece is not None), dim=d)
  # shrink after for negative pads (reflection/replication must see full data first)
  return out.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, out.shape)))
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.
@ -1107,36 +1138,18 @@ class Tensor(OpMixin):
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
```
"""
if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
# flat padding
# normalize to grouped format
if all(isinstance(p, (int,UOp)) for p in padding):
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
# group padding
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x:Tensor,px,v) -> Tensor:
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
# dispatch
if mode == "constant": return self._pad_constant(pX, value)
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if mode == "circular":
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
for d,(pB,pA) in enumerate(pads):
if mode == "reflect":
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
if mode == "replicate":
shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
if mode == "circular": return self._pad_circular(pX)
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
# convenience
def pad_to(self, shape, *args):

View file

@ -8,7 +8,7 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDT
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
from tinygrad.helpers import strip_parens, colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.renderer import Estimates
@ -16,16 +16,15 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
AxisType.OUTER: "green"}
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1}
@ -304,7 +303,24 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
@property
def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
def max_shape(self) -> tuple[int, ...]:
  """Shape with every UOp dimension replaced by `int(dim.vmax)`; other dimensions are kept as they are."""
  return tuple(int(dim.vmax) if isinstance(dim, UOp) else dim for dim in self.shape)
@property
def shard_shape(self) -> tuple[sint, ...]:
  """
  Shape with the dimension at `axis` floor-divided by the number of devices.
  Returns `shape` unchanged when `device` is not a tuple or `axis` is None.
  """
  is_sharded = isinstance(self.device, tuple) and self.axis is not None
  if not is_sharded: return self.shape
  ndev = len(self.device)
  return tuple(dim//ndev if i == self.axis else dim for i,dim in enumerate(self.shape))
@property
def max_shard_shape(self) -> tuple[int, ...]:
  """
  `max_shape` with the dimension at `axis` floor-divided by the number of devices.
  Returns `max_shape` unchanged when `device` is not a tuple or `axis` is None.
  """
  is_sharded = isinstance(self.device, tuple) and self.axis is not None
  if not is_sharded: return self.max_shape
  ndev = len(self.device)
  return tuple(dim//ndev if i == self.axis else dim for i,dim in enumerate(self.max_shape))
@property
def size(self) -> int:
  """Product of the dimensions in `max_shape`."""
  dims = self.max_shape
  return prod(dims)
@property
def shard_size(self) -> int:
  """Product of the dimensions in `max_shard_shape`."""
  dims = self.max_shard_shape
  return prod(dims)
@functools.cached_property
def ended_ranges(self):
@ -352,7 +368,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
def __float__(self): return float(self._eval(dtypes.floats, float))
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
dvars = {k:v for k,v in dvars.items() if k is not v}
if len(dvars) == 0: return self
@ -491,18 +507,23 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@functools.cached_property
def axis(self) -> int|None:
# COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
if self.op is Ops.COPY: return None
if self.op is Ops.MULTI: return self.arg
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
src_axis = self.src[0].axis
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (0, self.src[0].shape[src_axis]):
return None # SHRINK will remove the sharding if it's on axis
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
if self.op is Ops.RESHAPE:
if src_axis is None: return None
arg_acc:list[sint] = list(itertools.accumulate(self.marg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
if self.shape[new_axis] % len(self.device) != 0: raise RuntimeError(f"reshape {self.src[0].shape} -> {self.shape} moved items between shards")
return new_axis
if self.op is Ops.PERMUTE: return self.marg.index(src_axis) if src_axis is not None else None
return src_axis
@ -537,6 +558,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
return self
@property
def multibase(self) -> UOp:
  """
  For movement ops and DETACH, returns the `base` of the first source; every other op returns itself.
  """
  # DETACH can't change base, so it is looked through the same way as movement ops
  # NOTE(review): this delegates to `.base`, not `.multibase`, on the source -- confirm that is intended
  if self.op in GroupOp.Movement or self.op is Ops.DETACH: return self.src[0].base
  return self
# like gep, but might return an integer
def sgep(self, i:int) -> sint:
match self.op:
@ -554,6 +581,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case _: raise RuntimeError(f"{self.op} is not a MovementOp")
def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
# early NOOP
if op in {Ops.SHRINK, Ops.PAD, Ops.EXPAND} and len(arg) == 0:
assert len(self.shape) == 0, "0 len arg only valid on zero length shape"
return self
match op:
case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
@ -567,7 +598,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
# in these four, if the shape doesn't change we can return self
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
@ -622,9 +652,26 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@property
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE}: return self.src[0].buffer
# this buffer can process disk tensors and simple movement ops
if self is not self.base:
assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
return self.src[0].buffer
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.uop.symbolic import symbolic
out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops+symbolic)
buf = out.src[0].buffer
assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
assert out.op is Ops.INDEX, "couldn't collapse to a single INDEX"
if out.src[1].op is Ops.CONST:
return buf.view(1, out.dtype, out.src[1].arg*out.dtype.itemsize)
if out.src[1].op is Ops.RANGE:
return buf.view(self.size, out.dtype, 0)
if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
return buf.view(self.size, out.dtype, out.src[1].src[1].arg*out.dtype.itemsize)
raise RuntimeError(f"cannot collapse INDEX {out.pyrender()} to a single size/offset")
if self.op is Ops.BITCAST:
buf = self.src[0].buffer
assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
return buf.view(self.size, self.dtype, 0)
if self.op is Ops.MSELECT:
ret = self.src[0].buffer
assert isinstance(ret, MultiBuffer)
@ -676,12 +723,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return graph_rewrite(self, pm_unbind, ctx=ret), ret
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> set[UOp]:
topo = self.toposort()
bound = {x.src[0]: x for x in topo if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR}
return {bound.get(x, x) for x in topo if x.op is Ops.DEFINE_VAR}
def variables(self) -> list[Variable]:
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
# *** uop symbolic stuff ***
@ -775,7 +818,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@functools.cached_property
def _sym_fxn(self):
sself = self.simplify()
varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
@ -794,11 +837,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop high level syntactic sugar ***
@property
def shard_shape(self):
if self.axis is None: return self.shape
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
@ -807,7 +845,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.shard_shape, self.dtype, slot)
return UOp.placeholder(self.max_shard_shape, self.dtype, slot)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
@ -1192,12 +1230,13 @@ if TRACK_MATCH_STATS or PROFILE:
SENTINEL: Final[UOp] = cast(UOp, object())
class BottomUpGate(Exception): pass
class RewriteContext:
def __init__(self, pm, bpm, ctx=None):
def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
self.pm: PatternMatcher|None = pm
self.bpm: PatternMatcher|None = bpm
self.bpm_cache: dict[UOp, UOp|None] = {}
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
self.rewrite_into_calls = rewrite_into_calls
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
@ -1232,6 +1271,10 @@ class RewriteContext:
if n in waitlist: stack.extend(waitlist.pop(n))
continue
stack.append((n, 1, new_n))
# NOTE: CALL is handled as a special case.
# The function that is called is not included in the graph_rewrite.
# If you want to graph_rewrite into a call, pass rewrite_into_calls=True.
if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
for x in reversed(new_n.src):
if x in on_stack: continue
stack.append((x, 0, x))
@ -1270,22 +1313,10 @@ class RewriteContext:
return self.replace[root]
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
return rewrite_ctx.unified_rewrite(sink)
@profile_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
new_map: dict[UOp, UOp] = {}
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
if input_map is not None:
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp:
  """Turn a `sint` into a UOp of `dtype`: ints become a const, anything else is cast."""
  if isinstance(x, int): return UOp.const(dtype, x)
  return x.cast(dtype)
def select_dtype(u):
  """Return `dtypes.long` if `u` overflows int32, else `dtypes.int`, vectorized to `u.dtype.count`."""
  scalar = dtypes.long if u.overflows(dtypes.int32) else dtypes.int
  return scalar.vec(u.dtype.count)
@ -1318,7 +1349,6 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def gate_kernel_sink(x:UOp) -> bool:
  """False only for a SINK whose arg is a KernelInfo; True for everything else."""
  return x.op is not Ops.SINK or not isinstance(x.arg, KernelInfo)
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
def do_unbind(ctx:dict[Variable, int], x:UOp):
v,i = x.unbind()
@ -1347,7 +1377,7 @@ def bitcast(x, in_dtype:DType, out_dtype:DType):
return ret[0] if out_count == 1 else ret
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]),
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
@ -1409,8 +1439,6 @@ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
# TODO: fix forced_reshape
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# NOTE: CMPNE doesn't work cause there's no __rne__
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering

View file

@ -123,16 +123,9 @@ _tensor_spec = PatternMatcher([
# REDUCE_AXIS is the reduce in the tensor graph
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
# REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
# Tensor range bind / store
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),

View file

@ -50,6 +50,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//(a*c))*c = x//a. Note if a = 1 it degenerates to the one above
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
@ -58,6 +60,10 @@ symbolic_simple = propagate_invalid + PatternMatcher([
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"),
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),

View file

@ -76,7 +76,7 @@
pointer-events: none;
}
label {
display: inline-flex;
display: flex;
align-items: center;
gap: 4px;
line-height: 1;

View file

@ -192,7 +192,9 @@ const waveColor = (op) => {
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
GPC:new Map([["NONE","#1a7a2e"],["MEMORY_DEPENDENCY","#8b1a00"],["EXEC_DEPENDENCY","#006b6b"],["INST_FETCH","#7a7a00"],["SYNC","#6b006b"],
["PIPE_BUSY","#7a4a00"],["MEMORY_THROTTLE","#5c0000"],["CONSTANT_MEMORY","#1a3d7a"],["NOT_SELECTED","#2e2e3a"],["OTHER","#4a4a55"],
["SLEEPING","#1a1a2a"],["DEFAULT","#3a3a45"]]), WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
const cycleColors = (lst, i) => lst[i%lst.length];
const rescaleTrack = (source, tid, k) => {
@ -235,7 +237,11 @@ function selectShape(key) {
const Modes = {0:'read', 1:'write', 2:'write+read'};
function getMetadata(key) {
function setFocus(key) {
if (key !== focusedShape) {
saveToHistory({ shape:focusedShape });
focusedShape = key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
}
const { eventType, e } = selectShape(key);
const html = d3.create("div").classed("info", true);
if (eventType === EventTypes.EXEC) {
@ -247,14 +253,14 @@ function getMetadata(key) {
for (const b of e.arg.bufs.sort((a, b) => a.num - b.num)) {
group.append("p").text(`${Modes[b.mode]}@data${b.num} ${formatUnit(b.nbytes, 'B')}`).style("cursor", "pointer").on("click", () => {
const row = document.getElementById(b.k); if (!isExpanded(row)) { row.click(); }
focusShape(b.key);
setFocus(b.key);
});
}
if (e.arg.ctx != null) {
const i = e.arg.ctx; s = e.arg.step;
html.append("a").text(ctxs[i+1].steps[s].name).on("click", () => switchCtx(i, s));
const prgSrc = ctxs[i+1].steps.findIndex(s => s.name === "View Program");
if (prgSrc !== -1) html.append("a").text("View program").on("click", () => switchCtx(i, prgSrc));
const prgSrc = ctxs[i+1].steps.findIndex(s => s.name === "View Source");
if (prgSrc !== -1) html.append("a").text("View Source").on("click", () => switchCtx(i, prgSrc));
}
}
if (eventType === EventTypes.BUF) {
@ -268,16 +274,10 @@ function getMetadata(key) {
const p = kernels.append("p").append(() => colored(`[${u}] ${repr} ${Modes[mode]}@data${num}`));
const shapeInfo = selectShape(shape).e?.arg?.tooltipText?.split("\n");
if (shapeInfo?.length > 5) p.append("span").text(" "+shapeInfo[5]);
if (shape != null) p.style("cursor", "pointer").on("click", () => focusShape(shape));
if (shape != null) p.style("cursor", "pointer").on("click", () => setFocus(shape));
}
}
return html.node();
}
function focusShape(shape) {
saveToHistory({ shape:focusedShape });
focusedShape = shape; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
return metadata.replaceChildren(getMetadata(focusedShape));
return metadata.replaceChildren(html.node());
}
const EventTypes = { EXEC:0, BUF:1 };
@ -287,7 +287,7 @@ async function renderProfiler(path, unit, opts) {
// support non realtime x axis units
formatTime = unit === "realtime" ? formatMicroseconds : formatCycles;
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
metadata.replaceChildren(getMetadata(focusedShape));
setFocus(focusedShape);
// layout once!
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
const profiler = d3.select("#profiler").html("");
@ -375,7 +375,7 @@ async function renderProfiler(path, unit, opts) {
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
const labelHTML = label.map(l=>`<span style="color:${l.color}">${l.st}</span>`).join("");
const arg = { tooltipText:labelHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key,
const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key,
ctx:shapeRef?.ctx, step:shapeRef?.step };
if (e.key != null) shapeMap.set(e.key, key);
// offset y by depth
@ -606,7 +606,7 @@ async function renderProfiler(path, unit, opts) {
e.preventDefault();
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.step != null && (foundRect?.key == null || e.type == "dblclick")) { return switchCtx(foundRect.ctx, foundRect.step); }
if (foundRect?.key != focusedShape) { focusShape(foundRect?.key); }
if (foundRect?.key != focusedShape) { setFocus(foundRect?.key); }
}
canvas.addEventListener("click", clickShape);
canvas.addEventListener("dblclick", clickShape);
@ -739,12 +739,12 @@ function saveToHistory(ns) {
const switchCtx = (newCtx, step) => setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step ?? 0, currentRewrite:0 });
window.addEventListener("popstate", (e) => {
if (e.state?.shape != null) return focusShape(e.state?.shape);
if (e.state?.shape != null) return setFocus(e.state?.shape);
if (e.state != null) setState(e.state);
});
const createToggle = (id, text) => {
const label = d3.create("label").style("display", "block").text(text).node();
const label = d3.create("label").text(text).node();
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
label.prepend(toggle);
return { toggle, label };
@ -826,7 +826,7 @@ async function main() {
}
// timeline with cycles on the x axis
if (ret instanceof ArrayBuffer) {
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), colorByName:step.name.includes("PKTS")};
opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")};
return renderProfiler(ckey, "clk", opts);
}
metadata.innerHTML = "";
@ -872,7 +872,7 @@ async function main() {
}
if (ret.ref != null) {
const disasmIdx = ctxs[ret.ref+1].steps.findIndex(s => s.name === "View Disassembly")
metadata.appendChild(d3.create("a").text("View Program Graph").on("click", () => switchCtx(ret.ref, disasmIdx)).node());
metadata.appendChild(d3.create("a").text("View Disassembly").on("click", () => switchCtx(ret.ref, disasmIdx)).node());
}
if (ret.cols != null) renderTable(root, ret);
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));

View file

@ -67,7 +67,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
const disconnected = new Set();
for (const n of g.nodes()) {
const node = g.node(n);
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
if (node.label.startsWith("CALL\n")) {
for (const pred of (g.predecessors(n) || [])) {
const edge = g.edge(pred, n);
if (edge?.label?.text === 0) {

View file

@ -8,6 +8,8 @@ from http.server import BaseHTTPRequestHandler
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, Context
from tinygrad.renderer.amd.dsl import Inst
from tinygrad.renderer.amd import detect_format
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
@ -93,6 +95,8 @@ def pystr(u:UOp) -> str:
try: return pyrender(u)
except Exception: return str(u)
# all the trace points, initialized after the trace loads
ctxs:list[dict] = []
def uop_to_json(x:UOp) -> dict[int, dict]:
assert isinstance(x, UOp)
graph: dict[int, dict] = {}
@ -133,7 +137,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
# limit SOURCE labels line count
@ -184,8 +188,6 @@ def rel_ts(ts:int|Decimal, start_ts:int) -> int:
device_ts_diffs:dict[str, Decimal] = {}
def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
amdgpu_targets:dict[str, str] = {}
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
for e in profile:
@ -205,7 +207,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
if dur == 0: continue
name, fmt, key = e.name, [], None
if (ref:=ref_map.get(name)) is not None:
if (ref:=ref_map.get(name)) is not None and ctxs:
name = ctxs[ref]["name"]
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
@ -299,16 +301,18 @@ def unpack_pmc(e) -> dict:
# ** on startup, list all the performance counter traces
def load_counters(profile:list[ProfileEvent]) -> None:
def load_amd_counters(profile:list[ProfileEvent]) -> None:
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
counter_events:dict[tuple[int, int], dict] = {}
durations:dict[str, list[float]] = {}
prg_events:dict[int, ProfileProgramEvent] = {}
arch = ""
for e in profile:
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}"
if len(counter_events) == 0: return None
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
run_number = {n:0 for n,_ in counter_events}
@ -323,13 +327,13 @@ def load_counters(profile:list[ProfileEvent]) -> None:
# to decode a SQTT trace, we need the raw stream, program binary and device properties
if (sqtt:=v.get(ProfileSQTTEvent)):
for e in sqtt:
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)),
data=(e.blob, prg_events[k].lib, amdgpu_targets[e.device])))
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4
ret:list[ProfileEvent] = []
rows:dict[str, None] = {}
trace:dict[str, set[int]] = {}
@ -340,12 +344,12 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
if isinstance(p, INST):
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
if isinstance(p, (INST, INST_RDNA4)):
op_name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
add(name, p, width=width, idx=int("OTHER" in name), info=info)
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info) # type: ignore[union-attr]
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
if isinstance(p, (VMEMEXEC, ALUEXEC)):
name = str(p.src).split('.')[1]
if name == "VALU_SALU":
@ -359,12 +363,13 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
# ** SQTT OCC only unpacks wave start, end time and SIMD location
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent,
target:str) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
# * init decoder
from extra.sqtt.roc import decode
base = unwrap(p.base)
addr_table = amd_decode(unwrap(p.lib), amdgpu_targets[p.device])
disasm:dict[int, tuple[str, int]] = {addr+base:(str(inst), inst.size()) for addr, inst in addr_table.items()}
addr_table = amd_decode(unwrap(p.lib), target)
disasm:dict[int, Inst] = {addr+base:inst for addr, inst in addr_table.items()}
rctx = decode(data, {p.tag:disasm})
cu_events:dict[str, list[ProfileEvent]] = {}
# * INST waves
@ -401,9 +406,8 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
for ev in profile:
if isinstance(ev, ProfileDeviceEvent):
device_ts_diffs[ev.device] = ev.tdiff
if (d:=ev.device.split(":")[0]) == "AMD":
device_decoders[d] = load_counters
amdgpu_targets[d] = f"gfx{unwrap(ev.props)['gfx_target_version']//1000}"
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters
if d == "NV": device_decoders[d] = load_nv_counters
# load device specific counters
for fxn in device_decoders.values(): fxn(profile)
# map events per device
@ -431,6 +435,35 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":rel_ts(e.ts, start_ts), **e.arg} for e in markers]}).encode()
return struct.pack("<IQII", rel_ts(unwrap(end_ts), start_ts), max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
# ** PMA counters
def load_nv_counters(profile:list) -> None:
  """Collect every ProfilePMAEvent in `profile` into a single "All Counters" context.

  Each PMA event becomes one step whose data is `(blob, sm_version)` and whose query path is "/prg-pma-pkts".
  Nothing is appended to `ctxs` when the profile contains no PMA events.
  """
  # sm_version per device, read from device events that carry props; 0x800 is used when props has no "sm_version" key
  sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
  steps:list[dict] = []
  run_number:dict[str, int] = {}
  for e in profile:
    # events are matched by class name rather than isinstance (presumably to avoid importing the NV runtime here - TODO confirm)
    if type(e).__name__ != "ProfilePMAEvent": continue
    run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
    # separate the run suffix with a space so repeated runs read "PMA <kern> n2" instead of "PMA <kern>n2"
    name = f"PMA {e.kern}"+(f" n{run_num}" if run_num > 1 else "")
    # fall back to the same 0x800 default when this device had no device event with props, instead of raising KeyError
    steps.append(create_step(name, ("/prg-pma-pkts", len(ctxs), len(steps)), data=(e.blob, sm_version.get(e.device, 0x800))))
  if steps: ctxs.append({"name":"All Counters", "steps":steps})
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
  """Turn a raw PMA blob into profile events laid out on a cycle axis.

  Every decoded sample becomes a ProfileRangeEvent on the row "GPC:<g> TPC:<t> SM:<s> WAVE:<w>", named after the sample's
  stall reason. Samples are placed back to back per tpc_id, each one `cycles_per_sample` wide. The returned list starts
  with one "start" ProfilePointEvent per row (in first-seen order), followed by the range events.
  Decoding stops once the number of range events exceeds MAX_SQTT_PKTS (default 50_000).
  """
  from extra.nv_pma.decode import decode, decode_tpc_id
  # assume every sample is 32 cycles
  cycles_per_sample = 32
  events:list[ProfileEvent] = []
  # dict used as an insertion-ordered set of row names
  row_names:dict[str, None] = {}
  # number of samples already placed for each tpc_id
  samples_seen:dict[int, int] = {}
  for sample, tpc_id in decode(blob, sm_version):
    if len(events) > getenv("MAX_SQTT_PKTS", 50_000): break
    gpc, tpc, sm = decode_tpc_id(tpc_id)
    idx = samples_seen.get(tpc_id, 0)
    samples_seen[tpc_id] = idx+1
    row = f"GPC:{gpc} TPC:{tpc} SM:{sm} WAVE:{sample.wave_id}"
    row_names.setdefault(row)
    key = TracingKey(sample.stall_reason.name, ret=f"pc=0x{sample.pc_offset:06x} active={sample.active}")
    events.append(ProfileRangeEvent(row, key, Decimal(idx*cycles_per_sample), Decimal((idx+1)*cycles_per_sample)))
  row_starts = [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in row_names]
  return row_starts+events
# ** Assembly static analyzers
def get_stdout(f: Callable) -> str:
@ -440,23 +473,12 @@ def get_stdout(f: Callable) -> str:
except Exception: traceback.print_exc(file=buf)
return buf.getvalue()
def amd_readelf(lib:bytes) -> list[dict]:
from tinygrad.runtime.autogen import amdgpu_kd
def get_elf_section(lib:bytes, name:str):
from tinygrad.runtime.support.elf import elf_loader
image, sections, __ = elf_loader(lib)
rodata = next((s for s in sections if s.name == ".rodata")).content
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(rodata))
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
return [{"label":f"{resource} Alloc", "value":val} for resource,val in [("VGPR", (vgpr_gran+1)*8-7), ("LDS",kd.group_segment_fixed_size),
("Scratch", kd.private_segment_fixed_size)] if val > 0]
return next((sh for sh in elf_loader(lib)[1] if sh.name == name))
def amd_decode(lib:bytes, target:str) -> dict[int, Any]: # Any is the Inst class from tinygrad.renderer.amd.dsl
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.renderer.amd import detect_format
from tinygrad.renderer.amd.dsl import Inst
image, sections, _ = elf_loader(lib)
text = next((sh for sh in sections if sh.name == ".text"), None)
assert text is not None, "no .text section found in ELF"
def amd_decode(lib:bytes, target:str) -> dict[int, Inst]:
text = get_elf_section(lib, ".text")
off, buf = text.header.sh_addr, text.content
arch = "rdna3" if target.startswith("gfx11") else "rdna4" if target.startswith("gfx12") else "cdna"
addr_table:dict[int, Inst] = {}
@ -511,7 +533,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1})
elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0})
elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
return {"data":{"blocks":blocks, "paths":paths, "pc_tokens":pc_tokens}, "src":"\n".join(lines)}
from tinygrad.runtime.autogen import amdgpu_kd
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(get_elf_section(lib, ".rodata").content))
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
return {"data":{"blocks":blocks, "paths":paths, "pc_tokens":pc_tokens}, "src":"\n".join(lines),
"metadata":[[{"label":f"{r} Alloc", "value":v} for r,v in [("VGPR", (vgpr_gran+1)*8-7), ("LDS", kd.group_segment_fixed_size),
("Scratch", kd.private_segment_fixed_size)] if v>0]]}
# ** Main render function to get the complete details about a trace event
@ -523,11 +550,10 @@ def get_render(query:str) -> dict:
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data)), "lang":"txt"}
if fmt == "code": return {"src":data, "lang":"cpp"}
if fmt == "asm":
ret:dict = {"metadata":[]}
ret:dict = {}
renderer, lib = data
if renderer.device.startswith("AMD"):
with soft_err(lambda err: ret.update(err)): ret.update(amdgpu_cfg(lib, renderer.arch))
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
else: ret["src"] = get_stdout(lambda: renderer.compiler.disassemble(lib))
return ret
if fmt == "all-pmc":
@ -572,9 +598,9 @@ def get_render(query:str) -> dict:
pc_to_inst = data["disasm"]
start_pc = None
rows:dict[int, dict] = {}
for pc, (inst,_) in pc_to_inst.items():
for pc, inst in pc_to_inst.items():
if start_pc is None: start_pc = pc
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
rows[pc] = {"pc":pc-start_pc, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
for e in w.unpack_insts():
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
inst["hit_count"] += 1
@ -585,6 +611,12 @@ def get_render(query:str) -> dict:
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
if fmt == "prg-pma-pkts":
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No PMA samples found."}
return ret
return data
# ** HTTP server
@ -647,7 +679,7 @@ if __name__ == "__main__":
st = time.perf_counter()
print("*** viz is starting")
ctxs:list[dict] = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
profile_ret = get_profile(load_pickle(args.profile, default=[]))
server = TCPServerWithReuse(('', PORT), Handler)