Merge branch 'master' into clean_load

This commit is contained in:
George Hotz 2026-06-19 16:56:57 -07:00 committed by GitHub
commit 7a214c4499
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 293 additions and 350 deletions

View file

@ -80,7 +80,7 @@ runs:
- name: Cache Python packages (PR)
if: github.event_name == 'pull_request'
id: restore-venv-pr
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: /tmp/.uv-cache
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
@ -96,7 +96,7 @@ runs:
- name: Cache downloads (PR)
if: inputs.key != '' && github.event_name == 'pull_request'
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
@ -203,7 +203,7 @@ runs:
- name: Cache apt (PR)
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: /var/cache/apt/archives/
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}

View file

@ -99,7 +99,6 @@ jobs:
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
@ -134,32 +133,10 @@ jobs:
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
- name: Fuzz Padded Tensor Core GEMM
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run quantized LLaMA
run: |
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
- name: Run quantized LLaMA3
run: |
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run OLMoE
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
- name: Run llama3.2
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run olmoe
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
@ -235,9 +212,6 @@ jobs:
- name: Symlink models and datasets
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -279,36 +253,16 @@ jobs:
# TODO: too slow
# - name: Run SDXL
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run quantized LLaMA3
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
# - name: Run LLaMA-3 8B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-2 70B
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- uses: actions/upload-artifact@v7
with:
name: Speed (NVIDIA)
@ -402,10 +356,7 @@ jobs:
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -458,18 +409,10 @@ jobs:
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA 7B
run: |
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA 7B with BEAM
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-3 8B on 6 GPUs
@ -478,16 +421,6 @@ jobs:
# run: sudo modprobe amdgpu
# - name: Run LLaMA-2 70B
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run process replay tests
uses: ./.github/actions/process-replay

View file

@ -469,7 +469,7 @@ jobs:
# ****** Models Tests ******
testmodels:
name: Models (llvm+cpu+gpu)
name: Models
runs-on: *linux
timeout-minutes: 15
steps:
@ -480,34 +480,12 @@ jobs:
with:
key: models
deps: testing
opencl: 'true'
llvm: 'true'
- name: Test models (llvm)
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
- name: Test models (opencl)
run: DEV=CL python -m pytest -n=auto test/models --durations=20
- name: Test models (cpu)
run: DEV=CPU python -m pytest -n=auto test/models --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
testmetalmodels:
name: Models (metal)
runs-on: &macos macos-26
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: metal
deps: testing
- name: Test models (Metal)
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
- name: Test LLaMA compile speed
run: DEV=METAL python test/external/external_test_speed_llama.py
# ****** Feature Tests ******
testdsp:
@ -716,7 +694,7 @@ jobs:
unittestmacos:
name: MacOS (unit)
runs-on: *macos
runs-on: &macos macos-26
timeout-minutes: 20
steps:
- name: Checkout Code

View file

@ -458,7 +458,8 @@ def test_matmul():
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
P_lds = QP_lds[:, :BLOCK_N]
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
rw1 = UOp.range(TM, 296, AxisType.LOOP)
rw2 = UOp.range(TN, 297, AxisType.LOOP)
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)

View file

@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0")
insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),

View file

@ -219,7 +219,8 @@ def test_matmul():
def asm_kernel(A, B, C):
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(THREADS, "lidx0")]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
int sa_idx = block_row, sb_idx = block_col;
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");

Binary file not shown.

View file

@ -227,6 +227,7 @@ Ternary & $(P, A, B)$
\midrule
\op{Barrier} & (deps\ldots) & --- & Synchronize threads within a workgroup. \\
\op{Ins} & \ldots & \ldots & A single machine instruction (e.g.\ AMD ISA). \\
\op{GetAddr} & (buf, dev) & --- & Lower buf to its address on device dev. \\
\op{Special} & (bound,) & name & GPU thread/workgroup index (e.g.\ \texttt{gidx0}, \texttt{lidx1}). \\
\op{If} & (gate,) & --- & Begin conditional execution block. \\
\op{Endif} & (if,) & --- & End conditional execution block. \\

View file

@ -70,7 +70,7 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
num_threads = A.shape[0]
threads = UOp.special(num_threads, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
isa = r4 if arch == "rdna4" else r3
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
@ -103,7 +103,7 @@ def custom_handwritten(A:UOp) -> UOp:
A = A.flatten()
threads = UOp.special(128, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
k = Kernel()
# wrap in loop to filter out icache misses

View file

@ -338,11 +338,15 @@ class TestGemmMXFP8(unittest.TestCase):
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
def test_general_n_fw(self):
for N in (256, 1792, 2048, 8192): run_mxfp8_gemm(8192, N, 4096)
# backward needs all dims tile-aligned (dgrad reduces N, wgrad reduces M)
def test_bw_simple(self): run_mx_gemm_bw(256, 256, 256)
def test_bw_rect(self): run_mx_gemm_bw(512, 256, 512)
def test_bw_w_post(self): run_mx_gemm_bw(256, 256, 256, w_post=True)
def test_bw_llama_qkv(self): run_mx_gemm_bw(8192, 4096, 4096)
def test_general_n_bw(self):
for N in (2048, 8192, 14336): run_mx_gemm_bw(8192, N, 4096)
# MP sharding: col-parallel (w on out axis), row-parallel (x,w on in axis)
@needs_second_gpu
def test_multi_col_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=None, w_shard=0, g_shard=1)

View file

@ -14,49 +14,51 @@ class TestEncodingsX86(unittest.TestCase):
# displacement of 0 isn't emitted
def test_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F"))
# rsp/r12 require a sib byte when used as base memory address
def test_rsp_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RSP)
# mov esp, dword ptr [rsp]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24"))
# rbp/r13 require a displacement when used as base memory address
def test_rbp_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RBP)
# mov ebp, dword ptr [rbp + 0]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00"))
# test [base + index*scale]
def test_base_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax + rdx*4]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90"))
# rsp as index means no index
def test_rsp_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00"))
# however r12 is a valid index
def test_r12_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32,
(def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax + r12*4]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0"))
# test [base + index*scale + 8bit disp]
def test_complex_address_8bit_disp(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi + rsi*4 + 0xa]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A"))
# test [base + index*scale + 32bit disp]
def test_complex_address_32bit_disp(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI)
load = ins(X86Ops.MOV, dtypes.int32,
(def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi + rsi*4 + 0x2710]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00"))
@ -114,28 +116,28 @@ class TestEncodingsX86(unittest.TestCase):
# when writting to mem the uop takes the store form where dtype is void and there's no definition
def test_write_mem(self):
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
xmm0 = def_reg(dtypes.float32, XMM[0])
extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
extr = ins(X86Ops.VPEXTRD, dtypes.void, address + (xmm0, imm(dtypes.uint8, 0)))
# vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00"))
# test two address instruction with fused load works
def test_two_address_load(self):
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
cmove = ins(X86Ops.CMOVE, dtypes.int32, address, RAX)
# cmove eax, dword ptr [rdi + rsi*4 + 0xa]
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A"))
# test instruction where displacement and imm have the same value
def test_disp_imm_same_value(self):
base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 1))
mov = ins(X86Ops.MOVi, dtypes.void, address + (imm(dtypes.int8, 10),))
# mov byte ptr [rdi + rsi + 0xa], 0xa
self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A"))
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10), imm(dtypes.uint8, 4))
imul = ins(X86Ops.IMULi, dtypes.int32, address + (imm(dtypes.int32, 10),), RDI)
# imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00"))

View file

@ -6,6 +6,9 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
from tinygrad.renderer.isa import IselContext
# INDEX on a register value with a constant index extracts a single element (the old GEP)
def lane(y:UOp, i:int) -> UOp: return y.index(UOp.const(dtypes.int, i), dtype=y.dtype.scalar())
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
class TestIselX86(unittest.TestCase):
def isel_rewrite(self, x:UOp):
@ -57,9 +60,9 @@ class TestIselX86(unittest.TestCase):
# need to move src from gpr to xmm before broadcasting
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
# if we can fuse a load we can skip the move and access memory directly
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
n = self.isel_rewrite(load.broadcast(4))
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 4)
def test_vbroadcastss(self):
a = UOp.variable("a", 0, 0, dtypes.float32)
@ -73,17 +76,17 @@ class TestIselX86(unittest.TestCase):
d = UOp.variable("d", 0, 0, dtypes.float32)
valid = [UOp.vectorize(c, c, d, d),
UOp.vectorize(a.gep(0), a.gep(1), c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
UOp.vectorize(lane(a, 0), lane(a, 1), c, c),
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 1), lane(a, 2), lane(a, 3), lane(a, 0)),
UOp.vectorize(lane(a, 3), lane(a, 2), lane(a, 1), lane(a, 0), lane(a, 7), lane(a, 6), lane(a, 5), lane(a, 4)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 1), lane(b, 1), lane(a, 4), lane(a, 4), lane(b, 5), lane(b, 5))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
invalid = [UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 4), lane(b, 5)),
UOp.vectorize(lane(a, 0), lane(a, 5), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 4), lane(a, 4), lane(a, 4), lane(a, 5)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 0), lane(b, 0), lane(a, 4), lane(a, 4), lane(b, 4), lane(a, 4))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
def test_vshufpd(self):
@ -93,16 +96,16 @@ class TestIselX86(unittest.TestCase):
d = UOp.variable("d", 0, 0, dtypes.float64)
valid = [UOp.vectorize(c, d),
UOp.vectorize(a.gep(0), c),
UOp.vectorize(a.gep(1), b.gep(1)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
UOp.vectorize(lane(a, 0), c),
UOp.vectorize(lane(a, 1), lane(b, 1)),
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 1), lane(a, 1), lane(a, 3), lane(a, 3))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
invalid = [UOp.vectorize(c, c, c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 2), lane(b, 3), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 0), lane(b, 1))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
def test_vinsertps(self):
@ -111,31 +114,31 @@ class TestIselX86(unittest.TestCase):
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
d = UOp.variable("e", 0, 0, dtypes.float32)
# moving 0th element to position 0 does nothing so only 1 vinsertps is generated
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
n = self.isel_rewrite(UOp.vectorize(lane(a, 0), d))
self.assertIs(n.arg, X86Ops.VINSERTPS)
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
valid = [UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 3), lane(b, 2), lane(c, 1), d)]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
# complex address is [base + index*scale + displacement]
def test_complex_address(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(a + 1).load()
n = self.isel_rewrite(load)
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
def test_fold_load(self):
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
load1 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
load2 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 1)).load()
n = self.isel_rewrite(load1 + load2)
self.assertTrue(len(n.src) == 4)
self.assertTrue(len(n.src) == 5)
# don't fold when used multiple times
def test_dont_fold_load(self):
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
# used by multiple users
n = self.isel_rewrite(load + 1 + load)
self.assertTrue(len(n.src) == 2)

View file

@ -76,7 +76,7 @@ class TestLinearizer(unittest.TestCase):
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG])
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
for i,u in enumerate(ranges):
if skip and i in skip: continue
@ -161,7 +161,7 @@ class TestLinearizer(unittest.TestCase):
uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0],
[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).src[2].src)
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
accs = [u for u in uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
stores = [u for u in uops if u.op is Ops.STORE]
assert len(accs) == 0 # it's removed now
assert len(stores) == 1
@ -210,14 +210,14 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule_linear().src[-1].src[0]
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
assert local[0].dtype.base == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
realized_ast = c.schedule_linear().src[-1].src[0]
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
self.assertEqual(local[0].dtype.base, expected_dtype)
tests = (
@ -243,7 +243,7 @@ class TestLinearizer(unittest.TestCase):
r = (x@y).relu()
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
ast = helper_linearizer_opt(r, [opt])
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
# the uops graph is reg BUFFER -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
@ -361,7 +361,8 @@ class TestLinearizer(unittest.TestCase):
ast = helper_linearizer_opt(out, opts=[opt])
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
local_stores = [u for u in uops if u.op is Ops.STORE and any(
x.op is Ops.BUFFER and x.addrspace is AddrSpace.LOCAL for x in get_recursive(u.src[0]))]
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.PARAM for x in get_recursive(u.src[0]))]
barrier = [u for u in uops if u.op is Ops.BARRIER]
assert len(barrier) == 1

View file

@ -3,6 +3,7 @@ import numpy as np
import tempfile, unittest
from tinygrad import Tensor, Context, Device, dtypes, UOp
from tinygrad.uop.ops import Ops
from tinygrad.dtype import AddrSpace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
@ -80,7 +81,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
with Context(QUANTIZE=1):
linear = run_onnx({"input":inp})["output"].schedule_linear()
prg = to_program(linear.src[-2].src[0], renderer=Device[Device.DEFAULT].renderer)
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.DEFINE_REG]
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")

View file

@ -1402,7 +1402,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
acc_dt = dtypes.int32 if is_int_out else dtypes.float32
# Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
# The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))
tmp = UOp.placeholder((n_a_elems + n_b_elems,), dtypes.uint32, slot=0, addrspace=AddrSpace.LOCAL)
def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
if is_i8:
@ -1425,7 +1425,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
mant = h & UOp.const(dtypes.uint32, 0x3FF)
# Use bf16 path: shift left by 16 to create bf16 bits, then shift mantissa and adjust exponent in float domain
# bf16 bits = (sign << 15) | (exp_bf16 << 7) | mant_bf16 -- but f16 and bf16 have different formats
# Instead: construct f32 bits properly, use a DEFINE_LOCAL uint32 array to force materialization
# Instead: construct f32 bits properly, use a local uint32 array to force materialization
f32_bits = (sign << UOp.const(dtypes.uint32, 31)) | \
((exp + UOp.const(dtypes.uint32, 112)) << UOp.const(dtypes.uint32, 23)) | \
(mant << UOp.const(dtypes.uint32, 13))

View file

@ -51,7 +51,6 @@ def wer_helper(result: str, reference: str)->float:
wer, _, _ = metrics.word_error_rate([result], [reference])
return wer
@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
# TODO: WEBGPU GPU dispatch dimensions limit
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GPU dispatch dimensions limit")
class TestWhisper(unittest.TestCase):

View file

@ -529,6 +529,26 @@ class TestFunctionTuple(unittest.TestCase):
np.testing.assert_allclose(g(a).numpy(), 14.0)
def test_custom_kernel_inplace_output_is_implicit(self):
# a custom_kernel output the kernel also READS (in-place add) is not write-only, so it must be captured as an input
def inplace_add(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)
return C[i].store(C[i].load() + A[i]).end(i).sink(arg=KernelInfo(name="inplace_add"))
@function(precompile=True, allow_implicit=False)
def f(a:Tensor): return Tensor.custom_kernel(Tensor.empty(*a.shape, dtype=a.dtype, device=a.device), a, fxn=inplace_add)[0]
with self.assertRaisesRegex(RuntimeError, "implicit buffer"): f(Tensor([1., 2., 3., 4.]).contiguous().realize())
def test_custom_kernel_write_only_persistent_output_is_implicit(self):
# a write-only custom_kernel output that is a realized buffer must be captured
def write(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="write"))
state = Tensor([100., 200., 300., 400.], device="CPU").contiguous().realize()
@function(precompile=True, allow_implicit=True)
def f(a:Tensor): return Tensor.custom_kernel(state, a, fxn=write)[0]
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
devs = ("CPU:0", "CPU:1")
def my_kernel(C:UOp, A:UOp) -> UOp:

View file

@ -4,12 +4,11 @@ import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import ParamArg
from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
@ -36,15 +35,12 @@ pm_index_is_shrink = PatternMatcher([
pm_remove_vec_dtypes = PatternMatcher([
# rewrite PARAM to non pointer
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
(UPat((Ops.PARAM, Ops.BUFFER), name="buf"), lambda buf:
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
# remove all vec dtypes
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER}, name="x"),
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp):
@ -187,6 +183,8 @@ def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
# isa renderers need to allocate registers
if isinstance(ctx, ISARenderer):
if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
# register definitions (INS without srcs) move to the top so regalloc sees their live ranges span the whole program (callee saved regs)
lst = sorted(lst, key=lambda u: u.op is not Ops.INS or bool(u.src))
regalloc_ctx = LinearScanRegallocContext(lst, ctx)
lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)

View file

@ -245,11 +245,15 @@ def no_vectorized_alu(alu:UOp):
return UOp(Ops.STACK, alu.dtype, alus)
def no_vectorized_buf(buf:UOp):
if not isinstance(buf.dtype, PtrDType): return None
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
# TODO: this fails on regs
#assert buf.max_numel() == buf.ptrdtype.size
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.addrspace)).cast(buf.dtype)
sz = buf.ptrdtype.size*buf.ptrdtype.count
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(sz, buf.addrspace), src=(UOp.const(dtypes.int, sz),)).cast(buf.dtype)
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
cnt = cast.dtype.count
if bcast is not None and bcast.op is Ops.GEP:
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
@ -264,12 +268,10 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
(UPat(Ops.BUFFER, name="buf"), no_vectorized_buf),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), no_vectorized_index),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), no_vectorized_index),
])
devectorize_alu = PatternMatcher([

View file

@ -2,6 +2,7 @@ import heapq
from typing import Any
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
def linearize(sink:UOp) -> list[UOp]:
@ -23,9 +24,7 @@ def linearize(sink:UOp) -> list[UOp]:
match u.op:
# the order and placement of these defines is important
case Ops.PARAM: priority, extra = -20, u.arg.slot
case Ops.BUFFER: priority = -18
case Ops.DEFINE_REG: priority = -18
case Ops.DEFINE_LOCAL: priority = -17
case Ops.BUFFER: priority = -17 if u.addrspace == AddrSpace.LOCAL else -18
case Ops.LOAD: priority = -1 # place loads early
case Ops.STORE: priority = 1 # place stores late
case Ops.RANGE: priority = 5 # placing RANGE is good

View file

@ -2,7 +2,7 @@ import itertools
from tinygrad.helpers import dedup
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.isa import ISARenderer, Register
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.dtype import dtypes
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
@ -49,8 +49,9 @@ class LinearScanRegallocContext:
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
if v not in self.spills:
# the value of a BUFFER is its 64bit address
dt = self.vdef(v).dtype
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
sz = 8 if self.vdef(v).op is Ops.BUFFER else dt.scalar().itemsize * dt.count
offset = self.stack_size + (sz - self.stack_size % sz) % sz
self.spills[v] = UOp.const(dtypes.int32, offset)
self.stack_size = offset + sz
@ -83,9 +84,9 @@ class LinearScanRegallocContext:
self.reals.setdefault(i, {})[v] = live[v]
# allocate stack array
if u.op is Ops.DEFINE_LOCAL:
if u.op is Ops.BUFFER:
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
self.stack_size += u.dtype.nbytes()
self.stack_size += u.max_numel() * u.dtype.itemsize
# loop prologue, avoid loading inside the loop
if u.op is Ops.RANGE:
@ -116,7 +117,7 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
else: nsrc.append(s)
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
if x.op is Ops.BUFFER: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], tag=ndefs))
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
@ -132,6 +133,5 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
return nx, before + [nx] + after
pm_regalloc_rewrite = PatternMatcher([
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
regalloc_rewrite),
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.BUFFER, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"), regalloc_rewrite),
])

View file

@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.BUFFER}: continue
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, pathlib
from tinygrad import nn
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context, fetch, profile_marker
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context, fetch, profile_marker, getenv
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
from tinygrad.llm.model import Transformer
@ -214,9 +214,13 @@ def main():
for i in range(args.benchmark):
profile_marker(f"decode @ {i}")
GlobalCounters.reset()
if (log:=getenv("BENCHMARK_LOG", "")): from extra.bench_log import WallTimeEvent, BenchEvent
with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
f" {GlobalCounters.global_mem//1000000}/{GlobalCounters.mem_used//1000000} MB -- "+\
tok.decode(toks).replace("\n", "\\n")): next(gen)
tok.decode(toks).replace("\n", "\\n")):
if log:
with WallTimeEvent(BenchEvent.STEP): next(gen)
else: next(gen)
exit(0)
# interactive chat

View file

@ -39,7 +39,7 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
for u in sink.toposort():
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
elif u.op is Ops.PARAM: n_bufs += 1
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.BUFFER and u.addrspace is AddrSpace.LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
code_bytes = b"".join(inst.to_bytes() for inst in insts)
arch = next(v for k, v in _arch_map.items() if arch.startswith(k))

View file

@ -158,10 +158,9 @@ class CStyleLanguage(Renderer):
return f"({self[buf]}+{strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]})"
def render_buffer(self, x:UOp):
shp = x.src[0].as_shape
lanes = 1
prefix = f"{self.smem_align}{self.smem_prefix}" if x.addrspace == AddrSpace.LOCAL else ""
suffix = f"[{shp[0]}]" if len(shp) else ""
suffix = f"[{x.max_numel()}]"
return f"{prefix}{self._render_dtype(x.dtype, sz=lanes)} {self[x]}{suffix};"
def _render_dtype(self, dtype:DType, sz:int=1, addrspace=AddrSpace.ALU, mutable=True, override_ptr=False):

View file

@ -2,7 +2,7 @@
# allow semicolons to put multiple ops on one line
import sys, struct, functools
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, DType, truncate, AddrSpace
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
@ -12,8 +12,8 @@ from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
class X86Ops(FastEnum):
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
# these aren't real instructions
FRAME_INDEX = auto(); LABEL = auto()
# these aren't real instructions, DEFINE is a register placeholder that defines a register without emitting an instruction
FRAME_INDEX = auto(); LABEL = auto(); DEFINE = auto()
# index
LEA = auto()
# register / memory / immediate moves
@ -171,50 +171,32 @@ extra_matcher = PatternMatcher([
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
])
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
pm_x86_style = PatternMatcher([
# buffers are pointers, scalar PARAMs (variables) keep their shape src
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
# SHRINK is a vectorized INDEX
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
# restore vec dtypes from structure
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
])
# ***** X86 pre instruction selection *****
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
return ptr.cast(cast.dtype).load(dtype=x.dtype)
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
return UOp.placeholder((count,), elem_dt, slot, AddrSpace.LOCAL)
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
return ptr.cast(cast.dtype).store(val)
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64)
# the selected address is a 64bit value, the AFTER orders the load after the scratch store and carries the element dtype for the encoder
sel = gate.where(addr.replace(dtype=dtypes.uint64), local_idx)
ptr = UOp(Ops.AFTER, addr.dtype, (sel, (local_idx if x.dtype.count == 1 else local).store(alt)))
return ptr.load(dtype=x.dtype)
# these must be done in a separate matcher because they violate the spec
pre_isel_matcher = pm_x86_style + PatternMatcher([
def gated_store(addr:UOp, gate:UOp, val:UOp):
local = scratch_buffer(addr.src[0].dtype.scalar(), val.dtype.count, -1)
sel = gate.where(addr.replace(dtype=dtypes.uint64), local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64))
return UOp(Ops.AFTER, addr.dtype, (sel,)).store(val)
# legalize the new style graph for isel. NOTE: this runs after the spec is verified, some of these rewrites violate it
pre_isel_matcher = PatternMatcher([
# x86 registers are typed by their width, materialize the structural width of the graph into vec dtypes (this is still valid new style)
(UPat(Ops.SHRINK, src=(UPat(), UPat(), UPat.cvar("c"))).load(allow_any_len=True, name="x"), lambda x,c:
x.replace(dtype=x.dtype.scalar().vec(c.arg)) if c.arg > x.dtype.count else None),
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
if (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
# zero extending scalar 32bit int is a noop
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
# cast between signed and unsigned int is a noop
@ -229,11 +211,12 @@ pre_isel_matcher = pm_x86_style + PatternMatcher([
# noop of a noop is removed
(UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
# moving elements of a single register to another without shuffling is a noop
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
# gated load/store become a conditional move on the index, the load/store are unconditional
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
(UPat(Ops.STACK, src=(UPat.var("y").index(UPat()),), allow_any_len=True, name="x"),
lambda y,x: UOp(Ops.NOOP, x.dtype, (y,)) if all(s.op is Ops.INDEX and len(s.src) == 2 and s.src[0] is y \
and s.src[1].op is Ops.CONST and s.src[1].arg == i for i,s in enumerate(x.src)) else None),
# gated load/store become a conditional move on the address, the load/store are unconditional
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").store(UPat.var("val"), UPat.var("gate")), gated_store),
# TODO: remove this once we allow all flag producing ops in cmove
# if gate in scalar int cmove is not a comparison need to add one to set the flag
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
@ -264,10 +247,10 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
# ***** X86 instruction selection *****
# if s is used multiple times we don't fold
def is_foldable(ctx:IselContext, x:UOp, s:UOp) -> bool: return len(ctx.uses[s]) == x.src.count(s) == 1
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
def lane(x:UOp, i:int) -> int: return s.arg[0] if (s:=x.src[i]).op is Ops.GEP else 0
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.INDEX else s
def lane(x:UOp, i:int) -> int: return s.src[1].arg if (s:=x.src[i]).op is Ops.INDEX else 0
def to_int(dt:DType): return {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}[dt]
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.DEFINE_REG, dt, tag=None if reg is None else (reg,))
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, dt, arg=X86Ops.DEFINE, tag=None if reg is None else (reg,))
def imm(dt:DType, v:int) -> UOp: return UOp.const(dt, truncate[dt](v)).rtag()
def to_imm(c:UOp) -> UOp|None:
if c.op is not Ops.CONST: return None
@ -348,41 +331,52 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
# this move "cleanses" the register constraints (rax/rdx) of idiv as that only applies on definition and not on the uses of idiv
return x.ins(X86Ops.MOV, src=(idiv,))
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
# a memory address operand is (base, index, displacement, size). size is the element size, it scales the index and is the memory operand width.
# it is materialized as an immediate so the address stays correct if the base register is ever spilled and refilled
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp, UOp]:
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
base, idx = x.src
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale))
return (base, _cast(idx), _disp(0))
if x.op not in {Ops.INDEX, Ops.SHRINK}: return (x, UOp(Ops.NOOP), _disp(0), imm(dtypes.uint8, x.dtype.itemsize))
base, idx = x.src[0], x.src[1]
# buffers are indexed by element, everything else (the stack pointer) by byte
scale = base.dtype.itemsize if base.op in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else 1
sz = imm(dtypes.uint8, base.dtype.itemsize)
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * scale), sz)
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * scale), sz)
return (base, _cast(idx), _disp(0), sz)
def abi(ctx:IselContext, x:UOp) -> UOp|None:
if isinstance(x.tag, tuple): return None
i = ctx.func_args.index(x)
def _stack_arg(disp:int): return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp))
if sys.platform == "win32": src = (x.replace(tag=((RCX, RDX, GPR[8], GPR[9])[i],)),) if i < 4 else _stack_arg((i-3)*8+32)
else: src = (x.replace(tag=((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i],)),) if i < 6 else _stack_arg((i-5)*8)
# buffer params hold addresses, their value moves as a 64bit int
dt = dtypes.uint64 if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.GLOBAL else x.dtype
# the shape srcs of a PARAM are not values, tag them so they aren't materialized into registers
def _reg_arg(r:Register) -> tuple[UOp, ...]: return (x.replace(dtype=dt, src=tuple(s.rtag() for s in x.src), tag=(r,)),)
def _stack_arg(disp:int):
return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp), imm(dtypes.uint8, 8))
if sys.platform == "win32": src = _reg_arg((RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
else: src = _reg_arg((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
# this move "cleanses" the abi register constraint
return x.ins(X86Ops.MOV, src=src)
return x.ins(X86Ops.MOV, dtype=dt, src=src)
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
# real registers
if x.op is Ops.DEFINE_REG and x.tag is not None: return None
# register placeholders with real registers
if x.arg is X86Ops.DEFINE and x.tag is not None: return None
# this is an immediate
if x.arg is X86Ops.FRAME_INDEX: return None
# no register definition
if x.dtype is dtypes.void: return None
# already allocated vregs
if isinstance(x.tag, tuple) and x.tag[0]._cons: return None
# allocate vreg definitions
# allocate vreg definitions, the value of a BUFFER is its address so it lives in a gpr
defs = []
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
elif x.op is Ops.BUFFER or x.dtype in dtypes.ints+(dtypes.bool,): defs = [ctx.vreg(WGPR)]
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
# TODO: add this once the scheduler can track register pressure
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
# the size src of a BUFFER is not a value, tag it so it isn't materialized into a register
if x.op is Ops.BUFFER: return x.replace(src=tuple(s.rtag() for s in x.src), tag=tuple(defs))
return x.replace(tag=tuple(defs))
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
@ -393,11 +387,12 @@ dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize
isel_matcher = PatternMatcher([
# **** Op -> Op ****
# cast to pointer is a noop
(UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
# cast of void is a noop
(UPat.var("y").cast(name="x"), lambda y,x: y if y.dtype == dtypes.void else None),
# extracting the 0th float element is a noop as it just moves the 0th element from one xmm register to another
# this is done here to not interfere with shuffles
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
(UPat(dtype=dtypes.floats).index(UPat(Ops.CONST, arg=0), name="x"),
lambda x: x.replace(op=Ops.NOOP, src=x.src[:1]) if x.src[0].dtype.count > 1 else None),
# range is lowered to acc, cmp, jmp after regalloc
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
@ -409,9 +404,6 @@ isel_matcher = PatternMatcher([
if not x.src or x.src[0].arg is not X86Ops.RET else None),
# function abi constraints
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
# these are treated the same for now
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
# constants that can't be immediates, move them to registers
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
@ -479,12 +471,17 @@ isel_matcher = PatternMatcher([
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
# gep
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
# INDEX on a vector register value extracts a single element
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int16s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int32s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int64s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.floats).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, c.arg * x.dtype.itemsize))) if y.dtype.count > 1 else None),
# fused multiply add
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
@ -578,8 +575,9 @@ isel_matcher = PatternMatcher([
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
# index
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
# index on a buffer (or the stack pointer) computes an address, addresses are 64bit values
(UPat((Ops.INDEX, Ops.SHRINK), name="x"),
lambda x: x.ins(X86Ops.LEA, dtype=dtypes.uint64, src=fold_address(x)) if x.src[0].dtype.count == 1 else None),
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# copy, load, store
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
@ -608,7 +606,7 @@ isel_matcher = PatternMatcher([
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, src=(UPat(name="a"),), name="y")), allow_any_len=True, name="x"), lambda ctx,y,a,x:
x.replace(src=x.src[:2] + fold_address(a) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable(ctx, x, y) else None),
# allocate virtual registers
(UPat((Ops.INS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_vregs),
(UPat((Ops.INS, Ops.BUFFER), name="x"), alloc_vregs),
])
# ***** pre register allocation *****
@ -656,14 +654,16 @@ post_regalloc_matcher = PatternMatcher([
# ***** X86 instruction encoding *****
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, sz_uop:UOp|None=None,
vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
nonlocal reg, opc
# get the encoding values of the different fields
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
rm = cast(Register, rm_uop.reg).index
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
# for a memory operand the rm size is the element size from the address, otherwise it's the size of the value in the register
rm_sz = sz_uop.arg if sz_uop is not None else rm_uop.dtype.itemsize
reg_sz = reg_uop.dtype.itemsize if reg_uop is not None else 0
sz = reg_sz or rm_sz
# encode instruction
@ -723,19 +723,19 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
address:tuple[UOp|None, ...]
if x.arg in X86GroupOp.WriteMem:
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
else: address, rest = (x, None, None), x.src
if len(x.src) > 4: address, rest = x.src[:4], x.src[4:]
else: address, rest = (x, None, None, None), x.src
return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
if x.arg in X86GroupOp.Rm1st:
if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
else: address, rest = (x.src[0], None, None), x.src[1:]
if len(x.src) > 3: address, rest = x.src[:4], x.src[4:]
else: address, rest = (x.src[0], None, None, None), x.src[1:]
imm_uop = rest[:1] if rest and rest[0].op is Ops.CONST else (None,)
return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
if x.arg in X86GroupOp.Rm2nd:
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
if len(x.src) > 4: address, rest = x.src[1:5], x.src[:1] + x.src[5:]
else: address, rest = (x.src[1], None, None, None), x.src[:1] + x.src[2:]
# cmp/vucomiss reg, rm don't define a new register
return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
@ -770,8 +770,9 @@ encodings = {
X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
# the int src is the 2nd src (the rm field), if it was folded into a memory operand its width is the element size of the address
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
# int division
@ -871,43 +872,41 @@ class X86Renderer(ISARenderer):
self.compiler = X86Compiler()
def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress
def stack_pointer(self) -> UOp: return def_reg(dtypes.uint64, RSP)
# nasty hacks to deal with pointers TODO: rm pointers
# the value of a BUFFER is its address, it moves through registers and the stack as a 64bit int
def copy(self, x:UOp, reg:Register):
dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
ret = isel_matcher.rewrite(UOp(Ops.COPY, dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, (x,), tag=reg))
assert ret is not None
return ret.replace(dtype=x.dtype)
return ret
def spill(self, disp:UOp, x:UOp) -> UOp:
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
if x.op is Ops.BUFFER: x = x.replace(dtype=dtypes.uint64)
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(x))
assert ret is not None
return ret.replace(src=(s if s is not nx else x for s in ret.src))
return ret
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, tag=reg))
assert ret is not None
return ret.replace(dtype=x.dtype)
return ret
def asm_str(self, uops:list[UOp], function_name:str) -> str:
def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
def _format_operands(x:UOp) -> str:
def _format(src:tuple[UOp, ...]) -> list[str]:
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize, o) if \
(o:=str(s.reg)) in reg_strs else o for s in src if s.reg is not None]
def _mem_adress(base:UOp, idx:UOp, disp:UOp) -> list[str]:
return [f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
def _mem_adress(base:UOp, idx:UOp, disp:UOp, sz:UOp) -> list[str]:
return [f"[{base.reg}" + (f" + {idx.reg}*{sz.arg}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:3]) + _format(x.src[3:])
elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:3]) + _format(x.src[3:])
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:4]) + _format(x.src[4:])
if len(x.src) > 4 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:4]) + _format(x.src[4:])
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:4]) + _format(x.src[4:])
elif len(x.src) > 4 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:5]) + _format(x.src[5:])
else: ret = _format((x,) + x.src)
return ", ".join(ret)
asm = [f".{function_name}:"]
for u in uops:
if u.op is not Ops.INS: continue
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
elif u.arg is X86Ops.RET: asm.append(_format_op(u))
else: asm.append(_format_op(u) + " " + _format_operands(u))
@ -918,7 +917,7 @@ class X86Renderer(ISARenderer):
jumps: dict[UOp, int] = {}
binary = bytearray()
for u in uops:
if u.op is not Ops.INS: continue
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
if u.arg is X86Ops.LABEL:
targets[u.tag] = len(binary)
continue

View file

@ -23,8 +23,8 @@ dsp_pm_late = PatternMatcher([
(UPat.var("x")+UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")*UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")//UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat(Ops.DEFINE_REG, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
(UPat(Ops.BUFFER, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]) if d.addrspace is AddrSpace.REG else None),
])
# NOTE: this just increases readability of the generated code

View file

@ -8,8 +8,8 @@ from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_claus
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.LOAD, Ops.CALL, Ops.FUNCTION}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None

View file

@ -427,7 +427,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
ended_stores.append(store_target.replace(dtype=sdtype).store(store.src[1]).end(*end_rngs))
return buf.after(*ended_stores)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
# NOTE: the local BUFFER needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
if x.src[0].op is Ops.SLICE:
@ -561,11 +561,11 @@ rangeify_codegen = PatternMatcher([
# TODO: this can be moved into codegen?
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0] if len(x.src) else None),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
])

View file

@ -19,10 +19,7 @@ class Ops(FastEnum):
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# define LOCAL/REG allocate things
DEFINE_LOCAL = auto(); DEFINE_REG = auto()
# BUFFER is the new LOCAL/REG
# BUFFER allocates global/local/register storage depending on its addrspace
BUFFER = auto()
# ** 2 -- non op uops **
@ -125,7 +122,7 @@ class GroupOp:
# TODO: is BITCAST always Elementwise if it's shape changing?
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Defines = {Ops.PARAM, Ops.BUFFER}
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

View file

@ -508,7 +508,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
return PatternMatcher(pat)
pm_long_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
@ -531,7 +531,7 @@ pm_long_decomp = PatternMatcher([
# float decomposition patterns - ctx is (fr, to) tuple
pm_float_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda ctx,x:
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
# bitcasted load should just replace load

View file

@ -279,7 +279,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.GETADDR: return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.BINARY: return (len(self.arg),)
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
case Ops.BUFFER:
if isinstance(self.arg, ParamArg):
if len(self.src): return self.src[0].as_shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
return (self.dtype.count,) if self.dtype.count > 1 else ()
return (self.arg,)
case Ops.SLICE:
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
if self.src[0].op is Ops.INDEX: return ()
@ -288,13 +293,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.STAGE:
# STAGE adds the existing shape to the front, opposite of INDEX
return tuple([int(r.vmax+1) for r in self.src[1:]])+self.src[0].shape
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
if len(self.src) >= 1:
# NOTE: this is the same as PARAM
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
if isinstance(self.dtype, PtrDType):
return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
return (self.dtype.count,) if self.dtype.count > 1 else ()
case Ops.PARAM:
if isinstance(self.dtype, ImageDType): return self.dtype.shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
@ -794,8 +792,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def addrspace(self) -> AddrSpace|None:
if self.op is Ops.PARAM: return self.arg.addrspace
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
@ -1060,10 +1056,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
if len(shape) > 1: ret = ret.reshape(shape)
if addrspace is AddrSpace.GLOBAL:
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
else:
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
@ -1152,7 +1151,7 @@ class ProgramInfo:
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size

View file

@ -79,17 +79,15 @@ spec_shared = PatternMatcher([
# PARAM
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg)),
(UPat(Ops.BUFFER, src=(UPat(),), name="x"), lambda x:
isinstance(x.arg, ParamArg) and x.addrspace in (AddrSpace.REG, AddrSpace.LOCAL)),
# GROUP of stores (or groups, or NOOPs)
# TODO: remove UNROLL here, it's for SPEC=2
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
# TOOD: these should be buffer with different addrspace everywhere.
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
Ops.BITCAST, Ops.INS})),),
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.AFTER, Ops.MULTI, Ops.BITCAST, Ops.INS})),),
allow_any_len=True), lambda: True),
# CUSTOM (inline and non inline)
@ -176,7 +174,7 @@ spec_tensor = PatternMatcher([
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
lambda root,x: root.dtype == x.dtype),
# TODO: this should not be here. STAGE is transformed to DEFINE_LOCAL later
# TODO: this should not be here. STAGE is transformed to BUFFER later
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
@ -196,7 +194,7 @@ spec_tensor = PatternMatcher([
# these ops can exist in programs but not the tensor spec. example: LOAD
spec_program = PatternMatcher([
# no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
(UPat(Ops.GEP), lambda: False),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),

View file

@ -46,7 +46,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE
from tinygrad.dtype import dtypes, AddrSpace
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#D8F9E4", Ops.STACK: "#D8F9E4",
Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",