mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into clean_load
This commit is contained in:
commit
7a214c4499
36 changed files with 293 additions and 350 deletions
6
.github/actions/setup-tinygrad/action.yml
vendored
6
.github/actions/setup-tinygrad/action.yml
vendored
|
|
@ -80,7 +80,7 @@ runs:
|
|||
- name: Cache Python packages (PR)
|
||||
if: github.event_name == 'pull_request'
|
||||
id: restore-venv-pr
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /tmp/.uv-cache
|
||||
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -96,7 +96,7 @@ runs:
|
|||
|
||||
- name: Cache downloads (PR)
|
||||
if: inputs.key != '' && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
|
||||
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -203,7 +203,7 @@ runs:
|
|||
|
||||
- name: Cache apt (PR)
|
||||
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /var/cache/apt/archives/
|
||||
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}
|
||||
|
|
|
|||
91
.github/workflows/benchmark.yml
vendored
91
.github/workflows/benchmark.yml
vendored
|
|
@ -99,7 +99,6 @@ jobs:
|
|||
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
|
||||
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
|
|
@ -134,32 +133,10 @@ jobs:
|
|||
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run quantized LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
|
||||
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
|
||||
- name: Run quantized LLaMA3
|
||||
run: |
|
||||
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
|
||||
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
|
||||
#- name: Run LLaMA 7B on 4 (virtual) GPUs
|
||||
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run OLMoE
|
||||
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
|
||||
- name: Run llama3.2
|
||||
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run olmoe
|
||||
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
|
||||
|
||||
|
|
@ -235,9 +212,6 @@ jobs:
|
|||
- name: Symlink models and datasets
|
||||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -279,36 +253,16 @@ jobs:
|
|||
# TODO: too slow
|
||||
# - name: Run SDXL
|
||||
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run quantized LLaMA3
|
||||
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: Speed (NVIDIA)
|
||||
|
|
@ -402,10 +356,7 @@ jobs:
|
|||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -458,18 +409,10 @@ jobs:
|
|||
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
|
||||
- name: Run SDXL
|
||||
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA 7B
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA 7B with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
|
|
@ -478,16 +421,6 @@ jobs:
|
|||
# run: sudo modprobe amdgpu
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
|
|||
26
.github/workflows/test.yml
vendored
26
.github/workflows/test.yml
vendored
|
|
@ -469,7 +469,7 @@ jobs:
|
|||
# ****** Models Tests ******
|
||||
|
||||
testmodels:
|
||||
name: Models (llvm+cpu+gpu)
|
||||
name: Models
|
||||
runs-on: *linux
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
|
@ -480,34 +480,12 @@ jobs:
|
|||
with:
|
||||
key: models
|
||||
deps: testing
|
||||
opencl: 'true'
|
||||
llvm: 'true'
|
||||
- name: Test models (llvm)
|
||||
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (opencl)
|
||||
run: DEV=CL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (cpu)
|
||||
run: DEV=CPU python -m pytest -n=auto test/models --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testmetalmodels:
|
||||
name: Models (metal)
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: metal
|
||||
deps: testing
|
||||
- name: Test models (Metal)
|
||||
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test LLaMA compile speed
|
||||
run: DEV=METAL python test/external/external_test_speed_llama.py
|
||||
|
||||
# ****** Feature Tests ******
|
||||
|
||||
testdsp:
|
||||
|
|
@ -716,7 +694,7 @@ jobs:
|
|||
|
||||
unittestmacos:
|
||||
name: MacOS (unit)
|
||||
runs-on: *macos
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ def test_matmul():
|
|||
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
|||
P_lds = QP_lds[:, :BLOCK_N]
|
||||
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
|
||||
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
|
||||
rw1 = UOp.range(TM, 296, AxisType.LOOP)
|
||||
rw2 = UOp.range(TN, 297, AxisType.LOOP)
|
||||
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
|
||||
|
|
|
|||
|
|
@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
insts = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
|
||||
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ def test_matmul():
|
|||
def asm_kernel(A, B, C):
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(THREADS, "lidx0")]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
|
||||
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -134,11 +134,14 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int sa_idx = block_row, sb_idx = block_col;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
|
|
@ -194,8 +197,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -227,6 +227,7 @@ Ternary & $(P, A, B)$
|
|||
\midrule
|
||||
\op{Barrier} & (deps\ldots) & --- & Synchronize threads within a workgroup. \\
|
||||
\op{Ins} & \ldots & \ldots & A single machine instruction (e.g.\ AMD ISA). \\
|
||||
\op{GetAddr} & (buf, dev) & --- & Lower buf to its address on device dev. \\
|
||||
\op{Special} & (bound,) & name & GPU thread/workgroup index (e.g.\ \texttt{gidx0}, \texttt{lidx1}). \\
|
||||
\op{If} & (gate,) & --- & Begin conditional execution block. \\
|
||||
\op{Endif} & (if,) & --- & End conditional execution block. \\
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
|
|||
num_threads = A.shape[0]
|
||||
threads = UOp.special(num_threads, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
|
||||
isa = r4 if arch == "rdna4" else r3
|
||||
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
|
||||
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
|
||||
|
|
@ -103,7 +103,7 @@ def custom_handwritten(A:UOp) -> UOp:
|
|||
A = A.flatten()
|
||||
threads = UOp.special(128, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
|
||||
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
|
||||
k = Kernel()
|
||||
# wrap in loop to filter out icache misses
|
||||
|
|
|
|||
|
|
@ -338,11 +338,15 @@ class TestGemmMXFP8(unittest.TestCase):
|
|||
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
|
||||
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
|
||||
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
|
||||
def test_general_n_fw(self):
|
||||
for N in (256, 1792, 2048, 8192): run_mxfp8_gemm(8192, N, 4096)
|
||||
# backward needs all dims tile-aligned (dgrad reduces N, wgrad reduces M)
|
||||
def test_bw_simple(self): run_mx_gemm_bw(256, 256, 256)
|
||||
def test_bw_rect(self): run_mx_gemm_bw(512, 256, 512)
|
||||
def test_bw_w_post(self): run_mx_gemm_bw(256, 256, 256, w_post=True)
|
||||
def test_bw_llama_qkv(self): run_mx_gemm_bw(8192, 4096, 4096)
|
||||
def test_general_n_bw(self):
|
||||
for N in (2048, 8192, 14336): run_mx_gemm_bw(8192, N, 4096)
|
||||
# MP sharding: col-parallel (w on out axis), row-parallel (x,w on in axis)
|
||||
@needs_second_gpu
|
||||
def test_multi_col_parallel(self): run_mx_gemm_multi(512, 512, 512, x_shard=None, w_shard=0, g_shard=1)
|
||||
|
|
|
|||
|
|
@ -14,49 +14,51 @@ class TestEncodingsX86(unittest.TestCase):
|
|||
|
||||
# displacement of 0 isn't emitted
|
||||
def test_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F"))
|
||||
|
||||
# rsp/r12 require a sib byte when used as base memory address
|
||||
def test_rsp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RSP)
|
||||
# mov esp, dword ptr [rsp]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24"))
|
||||
|
||||
# rbp/r13 require a displacement when used as base memory address
|
||||
def test_rbp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RBP)
|
||||
# mov ebp, dword ptr [rbp + 0]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00"))
|
||||
|
||||
# test [base + index*scale]
|
||||
def test_base_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax + rdx*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90"))
|
||||
|
||||
# rsp as index means no index
|
||||
def test_rsp_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00"))
|
||||
|
||||
# however r12 is a valid index
|
||||
def test_r12_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32,
|
||||
(def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax + r12*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0"))
|
||||
|
||||
# test [base + index*scale + 8bit disp]
|
||||
def test_complex_address_8bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A"))
|
||||
|
||||
# test [base + index*scale + 32bit disp]
|
||||
def test_complex_address_32bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32,
|
||||
(def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0x2710]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00"))
|
||||
|
||||
|
|
@ -114,28 +116,28 @@ class TestEncodingsX86(unittest.TestCase):
|
|||
|
||||
# when writting to mem the uop takes the store form where dtype is void and there's no definition
|
||||
def test_write_mem(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
|
||||
xmm0 = def_reg(dtypes.float32, XMM[0])
|
||||
extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
|
||||
extr = ins(X86Ops.VPEXTRD, dtypes.void, address + (xmm0, imm(dtypes.uint8, 0)))
|
||||
# vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
|
||||
self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00"))
|
||||
|
||||
# test two address instruction with fused load works
|
||||
def test_two_address_load(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, address, RAX)
|
||||
# cmove eax, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A"))
|
||||
|
||||
# test instruction where displacement and imm have the same value
|
||||
def test_disp_imm_same_value(self):
|
||||
base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
|
||||
mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 1))
|
||||
mov = ins(X86Ops.MOVi, dtypes.void, address + (imm(dtypes.int8, 10),))
|
||||
# mov byte ptr [rdi + rsi + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A"))
|
||||
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
|
||||
imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10), imm(dtypes.uint8, 4))
|
||||
imul = ins(X86Ops.IMULi, dtypes.int32, address + (imm(dtypes.int32, 10),), RDI)
|
||||
# imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00"))
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
|
|||
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
|
||||
from tinygrad.renderer.isa import IselContext
|
||||
|
||||
# INDEX on a register value with a constant index extracts a single element (the old GEP)
|
||||
def lane(y:UOp, i:int) -> UOp: return y.index(UOp.const(dtypes.int, i), dtype=y.dtype.scalar())
|
||||
|
||||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
|
||||
class TestIselX86(unittest.TestCase):
|
||||
def isel_rewrite(self, x:UOp):
|
||||
|
|
@ -57,9 +60,9 @@ class TestIselX86(unittest.TestCase):
|
|||
# need to move src from gpr to xmm before broadcasting
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
|
||||
# if we can fuse a load we can skip the move and access memory directly
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
n = self.isel_rewrite(load.broadcast(4))
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 4)
|
||||
|
||||
def test_vbroadcastss(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32)
|
||||
|
|
@ -73,17 +76,17 @@ class TestIselX86(unittest.TestCase):
|
|||
d = UOp.variable("d", 0, 0, dtypes.float32)
|
||||
|
||||
valid = [UOp.vectorize(c, c, d, d),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
|
||||
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), c, c),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 1), lane(a, 2), lane(a, 3), lane(a, 0)),
|
||||
UOp.vectorize(lane(a, 3), lane(a, 2), lane(a, 1), lane(a, 0), lane(a, 7), lane(a, 6), lane(a, 5), lane(a, 4)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 1), lane(b, 1), lane(a, 4), lane(a, 4), lane(b, 5), lane(b, 5))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
|
||||
invalid = [UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 4), lane(b, 5)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 5), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 4), lane(a, 4), lane(a, 4), lane(a, 5)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 0), lane(b, 0), lane(a, 4), lane(a, 4), lane(b, 4), lane(a, 4))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
def test_vshufpd(self):
|
||||
|
|
@ -93,16 +96,16 @@ class TestIselX86(unittest.TestCase):
|
|||
d = UOp.variable("d", 0, 0, dtypes.float64)
|
||||
|
||||
valid = [UOp.vectorize(c, d),
|
||||
UOp.vectorize(a.gep(0), c),
|
||||
UOp.vectorize(a.gep(1), b.gep(1)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
|
||||
UOp.vectorize(lane(a, 0), c),
|
||||
UOp.vectorize(lane(a, 1), lane(b, 1)),
|
||||
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 1), lane(a, 1), lane(a, 3), lane(a, 3))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
invalid = [UOp.vectorize(c, c, c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 2), lane(b, 3), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 0), lane(b, 1))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
def test_vinsertps(self):
|
||||
|
|
@ -111,31 +114,31 @@ class TestIselX86(unittest.TestCase):
|
|||
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
|
||||
d = UOp.variable("e", 0, 0, dtypes.float32)
|
||||
# moving 0th element to position 0 does nothing so only 1 vinsertps is generated
|
||||
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
|
||||
n = self.isel_rewrite(UOp.vectorize(lane(a, 0), d))
|
||||
self.assertIs(n.arg, X86Ops.VINSERTPS)
|
||||
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
|
||||
|
||||
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
|
||||
valid = [UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 3), lane(b, 2), lane(c, 1), d)]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
|
||||
|
||||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(a + 1).load()
|
||||
n = self.isel_rewrite(load)
|
||||
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
|
||||
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
|
||||
|
||||
def test_fold_load(self):
|
||||
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
load1 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
load2 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 1)).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
self.assertTrue(len(n.src) == 5)
|
||||
|
||||
# don't fold when used multiple times
|
||||
def test_dont_fold_load(self):
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
# used by multiple users
|
||||
n = self.isel_rewrite(load + 1 + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG])
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
|
||||
for i,u in enumerate(ranges):
|
||||
if skip and i in skip: continue
|
||||
|
|
@ -161,7 +161,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
|
||||
uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0],
|
||||
[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
|
||||
accs = [u for u in uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
|
||||
stores = [u for u in uops if u.op is Ops.STORE]
|
||||
assert len(accs) == 0 # it's removed now
|
||||
assert len(stores) == 1
|
||||
|
|
@ -210,14 +210,14 @@ class TestLinearizer(unittest.TestCase):
|
|||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
realized_ast = a.schedule_linear().src[-1].src[0]
|
||||
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
|
||||
assert local[0].dtype.base == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
realized_ast = c.schedule_linear().src[-1].src[0]
|
||||
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
|
||||
self.assertEqual(local[0].dtype.base, expected_dtype)
|
||||
|
||||
tests = (
|
||||
|
|
@ -243,7 +243,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
r = (x@y).relu()
|
||||
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
|
||||
ast = helper_linearizer_opt(r, [opt])
|
||||
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
# the uops graph is reg BUFFER -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
|
||||
|
|
@ -361,7 +361,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
ast = helper_linearizer_opt(out, opts=[opt])
|
||||
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
|
||||
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
|
||||
local_stores = [u for u in uops if u.op is Ops.STORE and any(
|
||||
x.op is Ops.BUFFER and x.addrspace is AddrSpace.LOCAL for x in get_recursive(u.src[0]))]
|
||||
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.PARAM for x in get_recursive(u.src[0]))]
|
||||
barrier = [u for u in uops if u.op is Ops.BARRIER]
|
||||
assert len(barrier) == 1
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import tempfile, unittest
|
||||
from tinygrad import Tensor, Context, Device, dtypes, UOp
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.codegen import to_program
|
||||
|
|
@ -80,7 +81,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
|
|||
with Context(QUANTIZE=1):
|
||||
linear = run_onnx({"input":inp})["output"].schedule_linear()
|
||||
prg = to_program(linear.src[-2].src[0], renderer=Device[Device.DEFAULT].renderer)
|
||||
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.DEFINE_REG]
|
||||
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
|
||||
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
|
||||
|
|
|
|||
|
|
@ -1402,7 +1402,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
|
|||
acc_dt = dtypes.int32 if is_int_out else dtypes.float32
|
||||
# Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
|
||||
# The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
|
||||
tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))
|
||||
tmp = UOp.placeholder((n_a_elems + n_b_elems,), dtypes.uint32, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
|
||||
if is_i8:
|
||||
|
|
@ -1425,7 +1425,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
|
|||
mant = h & UOp.const(dtypes.uint32, 0x3FF)
|
||||
# Use bf16 path: shift left by 16 to create bf16 bits, then shift mantissa and adjust exponent in float domain
|
||||
# bf16 bits = (sign << 15) | (exp_bf16 << 7) | mant_bf16 -- but f16 and bf16 have different formats
|
||||
# Instead: construct f32 bits properly, use a DEFINE_LOCAL uint32 array to force materialization
|
||||
# Instead: construct f32 bits properly, use a local uint32 array to force materialization
|
||||
f32_bits = (sign << UOp.const(dtypes.uint32, 31)) | \
|
||||
((exp + UOp.const(dtypes.uint32, 112)) << UOp.const(dtypes.uint32, 23)) | \
|
||||
(mant << UOp.const(dtypes.uint32, 13))
|
||||
|
|
|
|||
|
|
@ -51,7 +51,6 @@ def wer_helper(result: str, reference: str)->float:
|
|||
wer, _, _ = metrics.word_error_rate([result], [reference])
|
||||
return wer
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
|
||||
# TODO: WEBGPU GPU dispatch dimensions limit
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GPU dispatch dimensions limit")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -529,6 +529,26 @@ class TestFunctionTuple(unittest.TestCase):
|
|||
|
||||
np.testing.assert_allclose(g(a).numpy(), 14.0)
|
||||
|
||||
def test_custom_kernel_inplace_output_is_implicit(self):
|
||||
# a custom_kernel output the kernel also READS (in-place add) is not write-only, so it must be captured as an input
|
||||
def inplace_add(C:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0)
|
||||
return C[i].store(C[i].load() + A[i]).end(i).sink(arg=KernelInfo(name="inplace_add"))
|
||||
@function(precompile=True, allow_implicit=False)
|
||||
def f(a:Tensor): return Tensor.custom_kernel(Tensor.empty(*a.shape, dtype=a.dtype, device=a.device), a, fxn=inplace_add)[0]
|
||||
with self.assertRaisesRegex(RuntimeError, "implicit buffer"): f(Tensor([1., 2., 3., 4.]).contiguous().realize())
|
||||
|
||||
def test_custom_kernel_write_only_persistent_output_is_implicit(self):
|
||||
# a write-only custom_kernel output that is a realized buffer must be captured
|
||||
def write(C:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0)
|
||||
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="write"))
|
||||
state = Tensor([100., 200., 300., 400.], device="CPU").contiguous().realize()
|
||||
@function(precompile=True, allow_implicit=True)
|
||||
def f(a:Tensor): return Tensor.custom_kernel(state, a, fxn=write)[0]
|
||||
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
|
||||
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
|
||||
|
||||
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ import itertools
|
|||
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
||||
from tinygrad.uop.ops import ParamArg
|
||||
from tinygrad.uop.render import pyrender
|
||||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@ -36,15 +35,12 @@ pm_index_is_shrink = PatternMatcher([
|
|||
|
||||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
# rewrite PARAM to non pointer
|
||||
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
|
||||
(UPat((Ops.PARAM, Ops.BUFFER), name="buf"), lambda buf:
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
|
||||
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
|
||||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
|
||||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def do_number_param(ctx:list[int], x:UOp):
|
||||
|
|
@ -187,6 +183,8 @@ def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
|
|||
# isa renderers need to allocate registers
|
||||
if isinstance(ctx, ISARenderer):
|
||||
if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
|
||||
# register definitions (INS without srcs) move to the top so regalloc sees their live ranges span the whole program (callee saved regs)
|
||||
lst = sorted(lst, key=lambda u: u.op is not Ops.INS or bool(u.src))
|
||||
regalloc_ctx = LinearScanRegallocContext(lst, ctx)
|
||||
lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
|
||||
lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)
|
||||
|
|
|
|||
|
|
@ -245,11 +245,15 @@ def no_vectorized_alu(alu:UOp):
|
|||
return UOp(Ops.STACK, alu.dtype, alus)
|
||||
|
||||
def no_vectorized_buf(buf:UOp):
|
||||
if not isinstance(buf.dtype, PtrDType): return None
|
||||
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
|
||||
# TODO: this fails on regs
|
||||
#assert buf.max_numel() == buf.ptrdtype.size
|
||||
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.addrspace)).cast(buf.dtype)
|
||||
sz = buf.ptrdtype.size*buf.ptrdtype.count
|
||||
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(sz, buf.addrspace), src=(UOp.const(dtypes.int, sz),)).cast(buf.dtype)
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
||||
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
|
||||
cnt = cast.dtype.count
|
||||
if bcast is not None and bcast.op is Ops.GEP:
|
||||
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
|
||||
|
|
@ -264,12 +268,10 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
|||
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
|
||||
|
||||
devectorize_buf_and_index = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
(UPat(Ops.BUFFER, name="buf"), no_vectorized_buf),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), no_vectorized_index),
|
||||
])
|
||||
|
||||
devectorize_alu = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import heapq
|
|||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
|
|
@ -23,9 +24,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
match u.op:
|
||||
# the order and placement of these defines is important
|
||||
case Ops.PARAM: priority, extra = -20, u.arg.slot
|
||||
case Ops.BUFFER: priority = -18
|
||||
case Ops.DEFINE_REG: priority = -18
|
||||
case Ops.DEFINE_LOCAL: priority = -17
|
||||
case Ops.BUFFER: priority = -17 if u.addrspace == AddrSpace.LOCAL else -18
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.STORE: priority = 1 # place stores late
|
||||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
from tinygrad.helpers import dedup
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.isa import ISARenderer, Register
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
|
||||
|
||||
|
|
@ -49,8 +49,9 @@ class LinearScanRegallocContext:
|
|||
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
|
||||
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
|
||||
if v not in self.spills:
|
||||
# the value of a BUFFER is its 64bit address
|
||||
dt = self.vdef(v).dtype
|
||||
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
|
||||
sz = 8 if self.vdef(v).op is Ops.BUFFER else dt.scalar().itemsize * dt.count
|
||||
offset = self.stack_size + (sz - self.stack_size % sz) % sz
|
||||
self.spills[v] = UOp.const(dtypes.int32, offset)
|
||||
self.stack_size = offset + sz
|
||||
|
|
@ -83,9 +84,9 @@ class LinearScanRegallocContext:
|
|||
self.reals.setdefault(i, {})[v] = live[v]
|
||||
|
||||
# allocate stack array
|
||||
if u.op is Ops.DEFINE_LOCAL:
|
||||
if u.op is Ops.BUFFER:
|
||||
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
|
||||
self.stack_size += u.dtype.nbytes()
|
||||
self.stack_size += u.max_numel() * u.dtype.itemsize
|
||||
|
||||
# loop prologue, avoid loading inside the loop
|
||||
if u.op is Ops.RANGE:
|
||||
|
|
@ -116,7 +117,7 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|||
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
|
||||
else: nsrc.append(s)
|
||||
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
|
||||
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
|
||||
if x.op is Ops.BUFFER: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], tag=ndefs))
|
||||
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
|
||||
|
||||
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
|
||||
|
|
@ -132,6 +133,5 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|||
return nx, before + [nx] + after
|
||||
|
||||
pm_regalloc_rewrite = PatternMatcher([
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
|
||||
regalloc_rewrite),
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.BUFFER, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"), regalloc_rewrite),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
|
|||
replaces: dict[UOp, UOp] = {}
|
||||
for u in included:
|
||||
for s in u.src:
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.BUFFER}: continue
|
||||
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
|
||||
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
|
||||
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, pathlib
|
||||
from tinygrad import nn
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context, fetch, profile_marker
|
||||
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context, fetch, profile_marker, getenv
|
||||
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
||||
from tinygrad.llm.model import Transformer
|
||||
|
||||
|
|
@ -214,9 +214,13 @@ def main():
|
|||
for i in range(args.benchmark):
|
||||
profile_marker(f"decode @ {i}")
|
||||
GlobalCounters.reset()
|
||||
if (log:=getenv("BENCHMARK_LOG", "")): from extra.bench_log import WallTimeEvent, BenchEvent
|
||||
with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
|
||||
f" {GlobalCounters.global_mem//1000000}/{GlobalCounters.mem_used//1000000} MB -- "+\
|
||||
tok.decode(toks).replace("\n", "\\n")): next(gen)
|
||||
tok.decode(toks).replace("\n", "\\n")):
|
||||
if log:
|
||||
with WallTimeEvent(BenchEvent.STEP): next(gen)
|
||||
else: next(gen)
|
||||
exit(0)
|
||||
|
||||
# interactive chat
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
|
|||
for u in sink.toposort():
|
||||
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
|
||||
elif u.op is Ops.PARAM: n_bufs += 1
|
||||
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.BUFFER and u.addrspace is AddrSpace.LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
|
||||
code_bytes = b"".join(inst.to_bytes() for inst in insts)
|
||||
arch = next(v for k, v in _arch_map.items() if arch.startswith(k))
|
||||
|
|
|
|||
|
|
@ -158,10 +158,9 @@ class CStyleLanguage(Renderer):
|
|||
return f"({self[buf]}+{strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]})"
|
||||
|
||||
def render_buffer(self, x:UOp):
|
||||
shp = x.src[0].as_shape
|
||||
lanes = 1
|
||||
prefix = f"{self.smem_align}{self.smem_prefix}" if x.addrspace == AddrSpace.LOCAL else ""
|
||||
suffix = f"[{shp[0]}]" if len(shp) else ""
|
||||
suffix = f"[{x.max_numel()}]"
|
||||
return f"{prefix}{self._render_dtype(x.dtype, sz=lanes)} {self[x]}{suffix};"
|
||||
|
||||
def _render_dtype(self, dtype:DType, sz:int=1, addrspace=AddrSpace.ALU, mutable=True, override_ptr=False):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# allow semicolons to put multiple ops on one line
|
||||
import sys, struct, functools
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, DType, truncate, AddrSpace
|
||||
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
|
||||
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
|
||||
|
|
@ -12,8 +12,8 @@ from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
|
|||
|
||||
class X86Ops(FastEnum):
|
||||
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
|
||||
# these aren't real instructions
|
||||
FRAME_INDEX = auto(); LABEL = auto()
|
||||
# these aren't real instructions, DEFINE is a register placeholder that defines a register without emitting an instruction
|
||||
FRAME_INDEX = auto(); LABEL = auto(); DEFINE = auto()
|
||||
# index
|
||||
LEA = auto()
|
||||
# register / memory / immediate moves
|
||||
|
|
@ -171,50 +171,32 @@ extra_matcher = PatternMatcher([
|
|||
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
|
||||
])
|
||||
|
||||
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
|
||||
|
||||
pm_x86_style = PatternMatcher([
|
||||
# buffers are pointers, scalar PARAMs (variables) keep their shape src
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
|
||||
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
|
||||
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
|
||||
# SHRINK is a vectorized INDEX
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
|
||||
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
|
||||
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
|
||||
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
|
||||
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
|
||||
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
|
||||
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
|
||||
# restore vec dtypes from structure
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
|
||||
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
|
||||
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
|
||||
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
|
||||
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
|
||||
])
|
||||
|
||||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
|
||||
return ptr.cast(cast.dtype).load(dtype=x.dtype)
|
||||
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
|
||||
return UOp.placeholder((count,), elem_dt, slot, AddrSpace.LOCAL)
|
||||
|
||||
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
|
||||
return ptr.cast(cast.dtype).store(val)
|
||||
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64)
|
||||
# the selected address is a 64bit value, the AFTER orders the load after the scratch store and carries the element dtype for the encoder
|
||||
sel = gate.where(addr.replace(dtype=dtypes.uint64), local_idx)
|
||||
ptr = UOp(Ops.AFTER, addr.dtype, (sel, (local_idx if x.dtype.count == 1 else local).store(alt)))
|
||||
return ptr.load(dtype=x.dtype)
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
pre_isel_matcher = pm_x86_style + PatternMatcher([
|
||||
def gated_store(addr:UOp, gate:UOp, val:UOp):
|
||||
local = scratch_buffer(addr.src[0].dtype.scalar(), val.dtype.count, -1)
|
||||
sel = gate.where(addr.replace(dtype=dtypes.uint64), local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64))
|
||||
return UOp(Ops.AFTER, addr.dtype, (sel,)).store(val)
|
||||
|
||||
# legalize the new style graph for isel. NOTE: this runs after the spec is verified, some of these rewrites violate it
|
||||
pre_isel_matcher = PatternMatcher([
|
||||
# x86 registers are typed by their width, materialize the structural width of the graph into vec dtypes (this is still valid new style)
|
||||
(UPat(Ops.SHRINK, src=(UPat(), UPat(), UPat.cvar("c"))).load(allow_any_len=True, name="x"), lambda x,c:
|
||||
x.replace(dtype=x.dtype.scalar().vec(c.arg)) if c.arg > x.dtype.count else None),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
|
||||
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
|
||||
if (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
|
||||
# zero extending scalar 32bit int is a noop
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
|
||||
# cast between signed and unsigned int is a noop
|
||||
|
|
@ -229,11 +211,12 @@ pre_isel_matcher = pm_x86_style + PatternMatcher([
|
|||
# noop of a noop is removed
|
||||
(UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
|
||||
# moving elements of a single register to another without shuffling is a noop
|
||||
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
|
||||
# gated load/store become a conditional move on the index, the load/store are unconditional
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
(UPat(Ops.STACK, src=(UPat.var("y").index(UPat()),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, (y,)) if all(s.op is Ops.INDEX and len(s.src) == 2 and s.src[0] is y \
|
||||
and s.src[1].op is Ops.CONST and s.src[1].arg == i for i,s in enumerate(x.src)) else None),
|
||||
# gated load/store become a conditional move on the address, the load/store are unconditional
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").store(UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
# TODO: remove this once we allow all flag producing ops in cmove
|
||||
# if gate in scalar int cmove is not a comparison need to add one to set the flag
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
|
||||
|
|
@ -264,10 +247,10 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
|
|||
# ***** X86 instruction selection *****
|
||||
# if s is used multiple times we don't fold
|
||||
def is_foldable(ctx:IselContext, x:UOp, s:UOp) -> bool: return len(ctx.uses[s]) == x.src.count(s) == 1
|
||||
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
|
||||
def lane(x:UOp, i:int) -> int: return s.arg[0] if (s:=x.src[i]).op is Ops.GEP else 0
|
||||
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.INDEX else s
|
||||
def lane(x:UOp, i:int) -> int: return s.src[1].arg if (s:=x.src[i]).op is Ops.INDEX else 0
|
||||
def to_int(dt:DType): return {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}[dt]
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.DEFINE_REG, dt, tag=None if reg is None else (reg,))
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, dt, arg=X86Ops.DEFINE, tag=None if reg is None else (reg,))
|
||||
def imm(dt:DType, v:int) -> UOp: return UOp.const(dt, truncate[dt](v)).rtag()
|
||||
def to_imm(c:UOp) -> UOp|None:
|
||||
if c.op is not Ops.CONST: return None
|
||||
|
|
@ -348,41 +331,52 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
|
|||
# this move "cleanses" the register constraints (rax/rdx) of idiv as that only applies on definition and not on the uses of idiv
|
||||
return x.ins(X86Ops.MOV, src=(idiv,))
|
||||
|
||||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
|
||||
# a memory address operand is (base, index, displacement, size). size is the element size, it scales the index and is the memory operand width.
|
||||
# it is materialized as an immediate so the address stays correct if the base register is ever spilled and refilled
|
||||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp, UOp]:
|
||||
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
|
||||
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
|
||||
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
|
||||
base, idx = x.src
|
||||
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
|
||||
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale))
|
||||
return (base, _cast(idx), _disp(0))
|
||||
if x.op not in {Ops.INDEX, Ops.SHRINK}: return (x, UOp(Ops.NOOP), _disp(0), imm(dtypes.uint8, x.dtype.itemsize))
|
||||
base, idx = x.src[0], x.src[1]
|
||||
# buffers are indexed by element, everything else (the stack pointer) by byte
|
||||
scale = base.dtype.itemsize if base.op in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else 1
|
||||
sz = imm(dtypes.uint8, base.dtype.itemsize)
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * scale), sz)
|
||||
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * scale), sz)
|
||||
return (base, _cast(idx), _disp(0), sz)
|
||||
|
||||
def abi(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
if isinstance(x.tag, tuple): return None
|
||||
i = ctx.func_args.index(x)
|
||||
def _stack_arg(disp:int): return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp))
|
||||
if sys.platform == "win32": src = (x.replace(tag=((RCX, RDX, GPR[8], GPR[9])[i],)),) if i < 4 else _stack_arg((i-3)*8+32)
|
||||
else: src = (x.replace(tag=((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i],)),) if i < 6 else _stack_arg((i-5)*8)
|
||||
# buffer params hold addresses, their value moves as a 64bit int
|
||||
dt = dtypes.uint64 if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.GLOBAL else x.dtype
|
||||
# the shape srcs of a PARAM are not values, tag them so they aren't materialized into registers
|
||||
def _reg_arg(r:Register) -> tuple[UOp, ...]: return (x.replace(dtype=dt, src=tuple(s.rtag() for s in x.src), tag=(r,)),)
|
||||
def _stack_arg(disp:int):
|
||||
return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp), imm(dtypes.uint8, 8))
|
||||
if sys.platform == "win32": src = _reg_arg((RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
|
||||
else: src = _reg_arg((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
|
||||
# this move "cleanses" the abi register constraint
|
||||
return x.ins(X86Ops.MOV, src=src)
|
||||
return x.ins(X86Ops.MOV, dtype=dt, src=src)
|
||||
|
||||
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
# real registers
|
||||
if x.op is Ops.DEFINE_REG and x.tag is not None: return None
|
||||
# register placeholders with real registers
|
||||
if x.arg is X86Ops.DEFINE and x.tag is not None: return None
|
||||
# this is an immediate
|
||||
if x.arg is X86Ops.FRAME_INDEX: return None
|
||||
# no register definition
|
||||
if x.dtype is dtypes.void: return None
|
||||
# already allocated vregs
|
||||
if isinstance(x.tag, tuple) and x.tag[0]._cons: return None
|
||||
# allocate vreg definitions
|
||||
# allocate vreg definitions, the value of a BUFFER is its address so it lives in a gpr
|
||||
defs = []
|
||||
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
|
||||
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
|
||||
elif x.op is Ops.BUFFER or x.dtype in dtypes.ints+(dtypes.bool,): defs = [ctx.vreg(WGPR)]
|
||||
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
|
||||
# TODO: add this once the scheduler can track register pressure
|
||||
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
|
||||
# the size src of a BUFFER is not a value, tag it so it isn't materialized into a register
|
||||
if x.op is Ops.BUFFER: return x.replace(src=tuple(s.rtag() for s in x.src), tag=tuple(defs))
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
||||
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
|
@ -393,11 +387,12 @@ dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize
|
|||
|
||||
isel_matcher = PatternMatcher([
|
||||
# **** Op -> Op ****
|
||||
# cast to pointer is a noop
|
||||
(UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
|
||||
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
|
||||
# cast of void is a noop
|
||||
(UPat.var("y").cast(name="x"), lambda y,x: y if y.dtype == dtypes.void else None),
|
||||
# extracting the 0th float element is a noop as it just moves the 0th element from one xmm register to another
|
||||
# this is done here to not interfere with shuffles
|
||||
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
|
||||
(UPat(dtype=dtypes.floats).index(UPat(Ops.CONST, arg=0), name="x"),
|
||||
lambda x: x.replace(op=Ops.NOOP, src=x.src[:1]) if x.src[0].dtype.count > 1 else None),
|
||||
# range is lowered to acc, cmp, jmp after regalloc
|
||||
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
|
||||
|
|
@ -409,9 +404,6 @@ isel_matcher = PatternMatcher([
|
|||
if not x.src or x.src[0].arg is not X86Ops.RET else None),
|
||||
# function abi constraints
|
||||
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
|
||||
# these are treated the same for now
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
|
||||
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
|
||||
# constants that can't be immediates, move them to registers
|
||||
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
|
||||
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
|
||||
|
|
@ -479,12 +471,17 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
|
||||
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
|
||||
# gep
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
|
||||
# INDEX on a vector register value extracts a single element
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int16s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int32s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int64s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.floats).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, c.arg * x.dtype.itemsize))) if y.dtype.count > 1 else None),
|
||||
# fused multiply add
|
||||
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
|
|
@ -578,8 +575,9 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
|
||||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
|
||||
# index
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
|
||||
# index on a buffer (or the stack pointer) computes an address, addresses are 64bit values
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="x"),
|
||||
lambda x: x.ins(X86Ops.LEA, dtype=dtypes.uint64, src=fold_address(x)) if x.src[0].dtype.count == 1 else None),
|
||||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# copy, load, store
|
||||
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
|
||||
|
|
@ -608,7 +606,7 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, src=(UPat(name="a"),), name="y")), allow_any_len=True, name="x"), lambda ctx,y,a,x:
|
||||
x.replace(src=x.src[:2] + fold_address(a) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable(ctx, x, y) else None),
|
||||
# allocate virtual registers
|
||||
(UPat((Ops.INS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_vregs),
|
||||
(UPat((Ops.INS, Ops.BUFFER), name="x"), alloc_vregs),
|
||||
])
|
||||
|
||||
# ***** pre register allocation *****
|
||||
|
|
@ -656,14 +654,16 @@ post_regalloc_matcher = PatternMatcher([
|
|||
# ***** X86 instruction encoding *****
|
||||
|
||||
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
|
||||
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
|
||||
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, sz_uop:UOp|None=None,
|
||||
vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
|
||||
nonlocal reg, opc
|
||||
# get the encoding values of the different fields
|
||||
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
|
||||
rm = cast(Register, rm_uop.reg).index
|
||||
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
|
||||
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
|
||||
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
|
||||
# for a memory operand the rm size is the element size from the address, otherwise it's the size of the value in the register
|
||||
rm_sz = sz_uop.arg if sz_uop is not None else rm_uop.dtype.itemsize
|
||||
reg_sz = reg_uop.dtype.itemsize if reg_uop is not None else 0
|
||||
sz = reg_sz or rm_sz
|
||||
|
||||
# encode instruction
|
||||
|
|
@ -723,19 +723,19 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
|
|||
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
|
||||
address:tuple[UOp|None, ...]
|
||||
if x.arg in X86GroupOp.WriteMem:
|
||||
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x, None, None), x.src
|
||||
if len(x.src) > 4: address, rest = x.src[:4], x.src[4:]
|
||||
else: address, rest = (x, None, None, None), x.src
|
||||
return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
|
||||
|
||||
if x.arg in X86GroupOp.Rm1st:
|
||||
if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x.src[0], None, None), x.src[1:]
|
||||
if len(x.src) > 3: address, rest = x.src[:4], x.src[4:]
|
||||
else: address, rest = (x.src[0], None, None, None), x.src[1:]
|
||||
imm_uop = rest[:1] if rest and rest[0].op is Ops.CONST else (None,)
|
||||
return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
|
||||
|
||||
if x.arg in X86GroupOp.Rm2nd:
|
||||
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
|
||||
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
|
||||
if len(x.src) > 4: address, rest = x.src[1:5], x.src[:1] + x.src[5:]
|
||||
else: address, rest = (x.src[1], None, None, None), x.src[:1] + x.src[2:]
|
||||
# cmp/vucomiss reg, rm don't define a new register
|
||||
return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
|
||||
|
||||
|
|
@ -770,8 +770,9 @@ encodings = {
|
|||
X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
|
||||
X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
|
||||
X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
|
||||
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
# the int src is the 2nd src (the rm field), if it was folded into a memory operand its width is the element size of the address
|
||||
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
|
||||
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
|
||||
X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
|
||||
X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
|
||||
# int division
|
||||
|
|
@ -871,43 +872,41 @@ class X86Renderer(ISARenderer):
|
|||
self.compiler = X86Compiler()
|
||||
def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress
|
||||
def stack_pointer(self) -> UOp: return def_reg(dtypes.uint64, RSP)
|
||||
# nasty hacks to deal with pointers TODO: rm pointers
|
||||
# the value of a BUFFER is its address, it moves through registers and the stack as a 64bit int
|
||||
def copy(self, x:UOp, reg:Register):
|
||||
dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
|
||||
ret = isel_matcher.rewrite(UOp(Ops.COPY, dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, (x,), tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
return ret
|
||||
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp:
|
||||
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
|
||||
if x.op is Ops.BUFFER: x = x.replace(dtype=dtypes.uint64)
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(x))
|
||||
assert ret is not None
|
||||
return ret.replace(src=(s if s is not nx else x for s in ret.src))
|
||||
return ret
|
||||
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
|
||||
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
return ret
|
||||
|
||||
def asm_str(self, uops:list[UOp], function_name:str) -> str:
|
||||
def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
|
||||
def _format_operands(x:UOp) -> str:
|
||||
def _format(src:tuple[UOp, ...]) -> list[str]:
|
||||
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
|
||||
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize, o) if \
|
||||
(o:=str(s.reg)) in reg_strs else o for s in src if s.reg is not None]
|
||||
def _mem_adress(base:UOp, idx:UOp, disp:UOp) -> list[str]:
|
||||
return [f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
|
||||
def _mem_adress(base:UOp, idx:UOp, disp:UOp, sz:UOp) -> list[str]:
|
||||
return [f"[{base.reg}" + (f" + {idx.reg}*{sz.arg}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
|
||||
|
||||
if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:3]) + _format(x.src[3:])
|
||||
elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:3]) + _format(x.src[3:])
|
||||
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:4]) + _format(x.src[4:])
|
||||
if len(x.src) > 4 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:4]) + _format(x.src[4:])
|
||||
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:4]) + _format(x.src[4:])
|
||||
elif len(x.src) > 4 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:5]) + _format(x.src[5:])
|
||||
else: ret = _format((x,) + x.src)
|
||||
return ", ".join(ret)
|
||||
|
||||
asm = [f".{function_name}:"]
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
|
||||
if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
|
||||
elif u.arg is X86Ops.RET: asm.append(_format_op(u))
|
||||
else: asm.append(_format_op(u) + " " + _format_operands(u))
|
||||
|
|
@ -918,7 +917,7 @@ class X86Renderer(ISARenderer):
|
|||
jumps: dict[UOp, int] = {}
|
||||
binary = bytearray()
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
|
||||
if u.arg is X86Ops.LABEL:
|
||||
targets[u.tag] = len(binary)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ dsp_pm_late = PatternMatcher([
|
|||
(UPat.var("x")+UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")*UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")//UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat(Ops.DEFINE_REG, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]) if d.addrspace is AddrSpace.REG else None),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_claus
|
|||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -427,7 +427,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
ended_stores.append(store_target.replace(dtype=sdtype).store(store.src[1]).end(*end_rngs))
|
||||
return buf.after(*ended_stores)
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
# NOTE: the local BUFFER needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
||||
if x.src[0].op is Ops.SLICE:
|
||||
|
|
@ -561,11 +561,11 @@ rangeify_codegen = PatternMatcher([
|
|||
# TODO: this can be moved into codegen?
|
||||
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0] if len(x.src) else None),
|
||||
|
||||
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
|
||||
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
|
||||
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
|
||||
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
|
||||
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
|
||||
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,7 @@ class Ops(FastEnum):
|
|||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
||||
# define LOCAL/REG allocate things
|
||||
DEFINE_LOCAL = auto(); DEFINE_REG = auto()
|
||||
|
||||
# BUFFER is the new LOCAL/REG
|
||||
# BUFFER allocates global/local/register storage depending on its addrspace
|
||||
BUFFER = auto()
|
||||
|
||||
# ** 2 -- non op uops **
|
||||
|
|
@ -125,7 +122,7 @@ class GroupOp:
|
|||
# TODO: is BITCAST always Elementwise if it's shape changing?
|
||||
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||
|
||||
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
Defines = {Ops.PARAM, Ops.BUFFER}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
|
|
|||
|
|
@ -508,7 +508,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
|
|||
return PatternMatcher(pat)
|
||||
|
||||
pm_long_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
|
||||
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
|
||||
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
|
|
@ -531,7 +531,7 @@ pm_long_decomp = PatternMatcher([
|
|||
|
||||
# float decomposition patterns - ctx is (fr, to) tuple
|
||||
pm_float_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
|
||||
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
|
||||
# bitcasted load should just replace load
|
||||
|
|
|
|||
|
|
@ -279,7 +279,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
case Ops.GETADDR: return ()
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
|
||||
case Ops.BUFFER:
|
||||
if isinstance(self.arg, ParamArg):
|
||||
if len(self.src): return self.src[0].as_shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
|
||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||
return (self.arg,)
|
||||
case Ops.SLICE:
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
|
|
@ -288,13 +293,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
case Ops.STAGE:
|
||||
# STAGE adds the existing shape to the front, opposite of INDEX
|
||||
return tuple([int(r.vmax+1) for r in self.src[1:]])+self.src[0].shape
|
||||
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
|
||||
if len(self.src) >= 1:
|
||||
# NOTE: this is the same as PARAM
|
||||
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
if isinstance(self.dtype, PtrDType):
|
||||
return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
|
||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, ImageDType): return self.dtype.shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
|
|
@ -794,8 +792,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def addrspace(self) -> AddrSpace|None:
|
||||
if self.op is Ops.PARAM: return self.arg.addrspace
|
||||
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
|
||||
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
|
|
@ -1060,10 +1056,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
|
||||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
|
||||
if len(shape) > 1: ret = ret.reshape(shape)
|
||||
if addrspace is AddrSpace.GLOBAL:
|
||||
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
|
||||
else:
|
||||
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
|
||||
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
|
||||
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
||||
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
|
||||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||
|
|
@ -1152,7 +1151,7 @@ class ProgramInfo:
|
|||
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
if u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
|
|
|
|||
|
|
@ -79,17 +79,15 @@ spec_shared = PatternMatcher([
|
|||
|
||||
# PARAM
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg)),
|
||||
(UPat(Ops.BUFFER, src=(UPat(),), name="x"), lambda x:
|
||||
isinstance(x.arg, ParamArg) and x.addrspace in (AddrSpace.REG, AddrSpace.LOCAL)),
|
||||
|
||||
# GROUP of stores (or groups, or NOOPs)
|
||||
# TODO: remove UNROLL here, it's for SPEC=2
|
||||
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
|
||||
|
||||
# TOOD: these should be buffer with different addrspace everywhere.
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
|
||||
|
||||
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
|
||||
Ops.BITCAST, Ops.INS})),),
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.AFTER, Ops.MULTI, Ops.BITCAST, Ops.INS})),),
|
||||
allow_any_len=True), lambda: True),
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
|
|
@ -176,7 +174,7 @@ spec_tensor = PatternMatcher([
|
|||
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
|
||||
lambda root,x: root.dtype == x.dtype),
|
||||
|
||||
# TODO: this should not be here. STAGE is transformed to DEFINE_LOCAL later
|
||||
# TODO: this should not be here. STAGE is transformed to BUFFER later
|
||||
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
|
|
@ -196,7 +194,7 @@ spec_tensor = PatternMatcher([
|
|||
# these ops can exist in programs but not the tensor spec. example: LOAD
|
||||
spec_program = PatternMatcher([
|
||||
# no more of these in programs
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
|
||||
(UPat(Ops.GEP), lambda: False),
|
||||
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE
|
|||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#D8F9E4", Ops.STACK: "#D8F9E4",
|
||||
Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue