mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
110 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
687ade119e |
||
|
|
0a8e61d0c5 |
||
|
|
dfea9e7994 |
||
|
|
ce87d80911 |
||
|
|
5a2b3b7b06 |
||
|
|
116045cc8e |
||
|
|
7c1d0b6d9a |
||
|
|
c9dc1d63cc |
||
|
|
da98fae9e1 |
||
|
|
15988b5941 |
||
|
|
cbfcf36e44 |
||
|
|
f9c8c697d6 |
||
|
|
0138480910 |
||
|
|
33b635d23a |
||
|
|
625d8bbd0d |
||
|
|
fe9b19b12d |
||
|
|
267af9c601 |
||
|
|
97da54b9d6 |
||
|
|
fd0dc40689 |
||
|
|
2d8b802958 |
||
|
|
ba1d3baae8 |
||
|
|
d80a41d559 |
||
|
|
5164c21b44 |
||
|
|
58ff75272e |
||
|
|
b50da5c205 |
||
|
|
4618d27129 |
||
|
|
9ae0a93d0e |
||
|
|
30830850a9 |
||
|
|
8b07cca9f7 |
||
|
|
b2199c54a3 |
||
|
|
1822eed8d3 |
||
|
|
bba611bb59 |
||
|
|
67c3e589a1 |
||
|
|
649971f02a |
||
|
|
b05bea81ce |
||
|
|
97c2e7a3d9 |
||
|
|
d7b10c69bc |
||
|
|
091ec8d10d |
||
|
|
925c49ce99 |
||
|
|
05249466ed |
||
|
|
4a4b6956df |
||
|
|
eda0a402d1 |
||
|
|
5989d0b150 |
||
|
|
d37248c3ec |
||
|
|
d74f488376 |
||
|
|
d7a1022188 |
||
|
|
924bece1d5 |
||
|
|
b753fb5e4c |
||
|
|
31094a794f |
||
|
|
1720987dc7 |
||
|
|
bed0c343a3 |
||
|
|
e0fe6e542e |
||
|
|
a74b7130b4 |
||
|
|
df015ad541 |
||
|
|
1bd4551ee1 |
||
|
|
53a1226a49 |
||
|
|
aef85ddc4d |
||
|
|
1e08c0a07c |
||
|
|
1acc40600d |
||
|
|
0f0c622086 |
||
|
|
be9b570cb2 |
||
|
|
c7055d658f |
||
|
|
d631716858 |
||
|
|
36f6d1b064 |
||
|
|
1cb6b88d37 |
||
|
|
5644605d92 |
||
|
|
d5d59a2be6 |
||
|
|
f0998e9bba |
||
|
|
7d2b0b697d |
||
|
|
70cac72781 |
||
|
|
443f976305 |
||
|
|
aa2bef24a8 |
||
|
|
efd03d7153 |
||
|
|
4a0488ae97 |
||
|
|
41aa2fe119 |
||
|
|
10bdb9c9d0 |
||
|
|
f998b9930a |
||
|
|
4dc51aff6e |
||
|
|
2adedf5ccb |
||
|
|
a6d7fb9d4d |
||
|
|
b1fb39502d | ||
|
|
2e181f4259 |
||
|
|
5d5ead78da |
||
|
|
b00dd754a9 |
||
|
|
5a9227b30a |
||
|
|
8efc8d064f |
||
|
|
c43091a464 |
||
|
|
2e77bd01db |
||
|
|
bcdb988df0 |
||
|
|
6b8fdfe4ca |
||
|
|
67a4f129c2 |
||
|
|
8862c7549c |
||
|
|
9e72a6b376 |
||
|
|
aa32d309db |
||
|
|
96b86aad7b |
||
|
|
a35964493e |
||
|
|
3036b15ed9 |
||
|
|
b2e95b2db3 |
||
|
|
833cb37574 |
||
|
|
51100d2c5c |
||
|
|
76c10cd635 |
||
|
|
2bfdf85f87 |
||
|
|
fb74f75485 |
||
|
|
4d34590b7d |
||
|
|
12f4cf0e49 |
||
|
|
e770805d21 |
||
|
|
b8aec4cce7 |
||
|
|
762f50bd52 |
||
|
|
a2cec397f3 |
||
|
|
b97e3e01e3 |
288 changed files with 15713 additions and 2104 deletions
6
.github/actions/setup-tinygrad/action.yml
vendored
6
.github/actions/setup-tinygrad/action.yml
vendored
|
|
@ -80,7 +80,7 @@ runs:
|
|||
- name: Cache Python packages (PR)
|
||||
if: github.event_name == 'pull_request'
|
||||
id: restore-venv-pr
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /tmp/.uv-cache
|
||||
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -96,7 +96,7 @@ runs:
|
|||
|
||||
- name: Cache downloads (PR)
|
||||
if: inputs.key != '' && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
|
||||
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -203,7 +203,7 @@ runs:
|
|||
|
||||
- name: Cache apt (PR)
|
||||
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /var/cache/apt/archives/
|
||||
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}
|
||||
|
|
|
|||
152
.github/workflows/benchmark.yml
vendored
152
.github/workflows/benchmark.yml
vendored
|
|
@ -99,7 +99,6 @@ jobs:
|
|||
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
|
||||
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
|
|
@ -134,32 +133,10 @@ jobs:
|
|||
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run quantized LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
|
||||
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
|
||||
- name: Run quantized LLaMA3
|
||||
run: |
|
||||
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
|
||||
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
|
||||
#- name: Run LLaMA 7B on 4 (virtual) GPUs
|
||||
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run OLMoE
|
||||
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
|
||||
- name: Run llama3.2
|
||||
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run olmoe
|
||||
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
|
||||
|
||||
|
|
@ -235,9 +212,6 @@ jobs:
|
|||
- name: Symlink models and datasets
|
||||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -279,36 +253,16 @@ jobs:
|
|||
# TODO: too slow
|
||||
# - name: Run SDXL
|
||||
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run quantized LLaMA3
|
||||
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: Speed (NVIDIA)
|
||||
|
|
@ -402,10 +356,7 @@ jobs:
|
|||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -458,18 +409,10 @@ jobs:
|
|||
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
|
||||
- name: Run SDXL
|
||||
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA 7B
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA 7B with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
|
|
@ -478,16 +421,6 @@ jobs:
|
|||
# run: sudo modprobe amdgpu
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
@ -594,10 +527,10 @@ jobs:
|
|||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testqualcommbenchmark:
|
||||
name: comma Benchmark
|
||||
testcommalatest:
|
||||
name: comma Benchmark (0.11.0)
|
||||
runs-on: [self-hosted, Linux, comma]
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 10
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
|
|
@ -622,6 +555,27 @@ jobs:
|
|||
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.11.0 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testcommaold:
|
||||
name: comma Benchmark (0.10.1)
|
||||
runs-on: [self-hosted, Linux, comma]
|
||||
timeout-minutes: 10
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
if: github.repository_owner == 'tinygrad'
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
|
||||
run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.10.1 driving_vision
|
||||
|
|
@ -630,14 +584,44 @@ jobs:
|
|||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.10.1 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testqualcommdsp:
|
||||
name: DSP Benchmark
|
||||
runs-on: [self-hosted, Linux, comma4]
|
||||
timeout-minutes: 5
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
if: github.repository_owner == 'tinygrad'
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: benchmark MobileNetV2 on DSP
|
||||
run: |
|
||||
# generate quantized weights
|
||||
ln -s /data/home/tiny/tinygrad/extra/datasets/imagenet extra/datasets/imagenet
|
||||
ln -s /data/home/tiny/tinygrad/testsig-*.so .
|
||||
PYTHONPATH=. CC=clang-19 DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
|
||||
PYTHONPATH=. DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
|
||||
# benchmark on DSP with NOOPT=1, the devectorizer has issues
|
||||
PYTHONPATH=. CC=clang-19 DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
|
||||
PYTHONPATH=. DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
|
|||
122
.github/workflows/test.yml
vendored
122
.github/workflows/test.yml
vendored
|
|
@ -133,46 +133,26 @@ jobs:
|
|||
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
|
||||
- name: Test IMAGE support
|
||||
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
|
||||
- name: Test emulated METAL tensor cores
|
||||
- name: Test emulated tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::METAL'
|
||||
DEBUG: 2
|
||||
N: 64
|
||||
CNT: 1
|
||||
SHOULD_USE_TC: 1
|
||||
run: |
|
||||
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx1100'
|
||||
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
|
||||
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
|
||||
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
|
||||
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
||||
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
||||
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
|
||||
- name: Run additional tensor core tests
|
||||
run: |
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD MFMA tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx950'
|
||||
run: |
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD RDNA4 tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx1201'
|
||||
run: |
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated CUDA tensor cores
|
||||
run: |
|
||||
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
|
||||
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
||||
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
||||
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test device flop counts
|
||||
run: |
|
||||
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
|
||||
linter:
|
||||
|
|
@ -218,14 +198,15 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
amd: 'true'
|
||||
- name: Run NULL backend tests
|
||||
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
|
||||
- name: Run targeted tests on NULL backend
|
||||
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
run: |
|
||||
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
|
||||
# TODO: too slow
|
||||
# - name: Run SDXL on NULL backend
|
||||
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
|
|
@ -249,7 +230,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
pydeps: "pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Run pre-commit test hooks
|
||||
|
|
@ -266,13 +247,6 @@ jobs:
|
|||
run: python3 test/external/external_benchmark_schedule.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
- name: Regen dataset on test_tiny
|
||||
run: |
|
||||
test/external/process_replay/reset.py
|
||||
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
|
||||
python extra/optimization/extract_dataset.py
|
||||
gzip -c /tmp/sops > extra/datasets/sops.gz
|
||||
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
|
||||
- name: Repo line count < 25000 lines
|
||||
run: MAX_LINE_COUNT=25000 python sz.py
|
||||
|
||||
|
|
@ -337,31 +311,6 @@ jobs:
|
|||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testgpumisc:
|
||||
name: CL Misc tests
|
||||
runs-on: *linux
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: gen-dataset
|
||||
deps: testing
|
||||
opencl: 'true'
|
||||
- name: Generate Dataset
|
||||
run: DEV=CL extra/optimization/generate_dataset.sh
|
||||
- name: Run Kernel Count Test
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
|
||||
- name: Run fused optimizer tests
|
||||
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: sops.gz
|
||||
path: /tmp/sops.gz
|
||||
|
||||
testopenpilot:
|
||||
name: openpilot Compile Tests
|
||||
runs-on: *linux
|
||||
|
|
@ -378,7 +327,7 @@ jobs:
|
|||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp32 (test correctness)
|
||||
run: |
|
||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||
|
|
@ -421,7 +370,6 @@ jobs:
|
|||
with:
|
||||
key: optim
|
||||
deps: testing
|
||||
pydeps: "tensorflow==2.19"
|
||||
opencl: 'true'
|
||||
#- name: Test Optimization Helpers
|
||||
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
|
|
@ -430,7 +378,7 @@ jobs:
|
|||
- name: Test Beam Search
|
||||
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
||||
- name: Test MLPerf stuff
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||
- name: DEV=NULL beautiful_mnist_multigpu
|
||||
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
||||
- name: Test Bert training
|
||||
|
|
@ -468,7 +416,7 @@ jobs:
|
|||
# ****** Models Tests ******
|
||||
|
||||
testmodels:
|
||||
name: Models (llvm+cpu+gpu)
|
||||
name: Models
|
||||
runs-on: *linux
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
|
@ -479,34 +427,12 @@ jobs:
|
|||
with:
|
||||
key: models
|
||||
deps: testing
|
||||
opencl: 'true'
|
||||
llvm: 'true'
|
||||
- name: Test models (llvm)
|
||||
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (opencl)
|
||||
run: DEV=CL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (cpu)
|
||||
run: DEV=CPU python -m pytest -n=auto test/models --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testmetalmodels:
|
||||
name: Models (metal)
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: metal
|
||||
deps: testing
|
||||
- name: Test models (Metal)
|
||||
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test LLaMA compile speed
|
||||
run: DEV=METAL python test/external/external_test_speed_llama.py
|
||||
|
||||
# ****** Feature Tests ******
|
||||
|
||||
testdsp:
|
||||
|
|
@ -715,7 +641,7 @@ jobs:
|
|||
|
||||
unittestmacos:
|
||||
name: MacOS (unit)
|
||||
runs-on: *macos
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
|
|||
Throw in an optimizer, a data loader, and some compute, and you have all you need.
|
||||
|
||||
```python
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad import Tensor, nn, Context
|
||||
|
||||
class LinearNet:
|
||||
def __init__(self):
|
||||
|
|
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
|
|||
|
||||
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for i in range(10):
|
||||
optim.zero_grad()
|
||||
loss = model(x).sparse_categorical_crossentropy(y).backward()
|
||||
|
|
|
|||
|
|
@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
|
|||
Now we have everything we need to start training our neural network.
|
||||
We will be training for 1000 steps with a batch size of 64.
|
||||
|
||||
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
|
||||
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
|
||||
Upon exit, the flag is restored to its previous value by the context manager.
|
||||
|
||||
```python
|
||||
from tinygrad import Context
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for step in range(1000):
|
||||
# random sample a batch
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(64))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Tuple
|
||||
import time
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
from tinygrad import Tensor, TinyJit, nn, Context
|
||||
import gymnasium as gym
|
||||
from tinygrad.helpers import trange
|
||||
import numpy as np # TODO: remove numpy import
|
||||
|
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
|||
|
||||
@TinyJit
|
||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
log_dist, value = model(x)
|
||||
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()
|
||||
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ if __name__ == "__main__":
|
|||
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(idxs:Tensor) -> Tensor:
|
||||
X, Y = X_train[idxs], Y_train[idxs]
|
||||
if len(GPUS) > 1:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ class Model:
|
|||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
|||
|
||||
@TinyJit
|
||||
def train_step() -> Tensor:
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import itertools
|
||||
from typing import Callable
|
||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
|
||||
from tinygrad.helpers import getenv, trange, partition
|
||||
|
||||
class Model:
|
||||
|
|
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
|||
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def microbatch():
|
||||
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
|
||||
for t in params: t.grad = None
|
||||
|
|
|
|||
|
|
@ -359,7 +359,7 @@ def train_cifar():
|
|||
i = 0
|
||||
eval_acc_pct = 0.0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
st = time.monotonic()
|
||||
while i <= STEPS:
|
||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
import os, math, time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
|
||||
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
|
|
@ -177,7 +177,7 @@ if __name__ == "__main__":
|
|||
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def step(x:Tensor, y:Tensor) -> Tensor:
|
||||
_, loss = model(x, y)
|
||||
optimizer.zero_grad()
|
||||
|
|
@ -204,4 +204,3 @@ if __name__ == "__main__":
|
|||
top_k = 40
|
||||
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
||||
print(decode(y[0].tolist()))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# much taken from https://github.com/cloneofsimo/minRF
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
|
||||
from tinygrad.helpers import getenv, trange
|
||||
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
||||
|
||||
|
|
@ -135,7 +135,7 @@ if __name__ == "__main__":
|
|||
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step():
|
||||
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
|
||||
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])
|
||||
|
|
|
|||
|
|
@ -358,7 +358,7 @@ def eval_stable_diffusion():
|
|||
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
|
||||
return batch, unpadded_bs
|
||||
|
||||
@Tensor.train(mode=False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
|
||||
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
|
||||
# Eval is divided into 5 jits, one per model
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
|
|||
from pathlib import Path
|
||||
import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
||||
|
|
@ -614,7 +614,7 @@ def train_retinanet():
|
|||
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
|
||||
with Tensor.train(mode=False):
|
||||
with Context(TRAINING=0):
|
||||
if not RUNMLPERF:
|
||||
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
|
||||
else:
|
||||
|
|
@ -784,7 +784,7 @@ def train_unet3d():
|
|||
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(model, x, y):
|
||||
optim.zero_grad()
|
||||
|
||||
|
|
@ -795,7 +795,7 @@ def train_unet3d():
|
|||
optim.step()
|
||||
return loss.realize()
|
||||
|
||||
@Tensor.train(mode=False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_step(model, x, y):
|
||||
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
|
||||
y_hat, y = Tensor(y_hat), Tensor(y)
|
||||
|
|
@ -1282,7 +1282,7 @@ def train_bert():
|
|||
previous_step = i
|
||||
|
||||
def train_llama3():
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE, MXFP8
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
from examples.mlperf.optim import GradAccClipAdamW
|
||||
|
|
@ -1447,7 +1447,12 @@ def train_llama3():
|
|||
idx = next(j for j, p in enumerate(optim.params) if p is w)
|
||||
master = optim.master_params[idx]
|
||||
inv = w._inv_scale if w._inv_scale.device == master.device else w._inv_scale.to(master.device)
|
||||
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import _mx_block_scale
|
||||
bs = _mx_block_scale(inv.reshape(-1, inv.shape[-1])).reshape(w.shape)
|
||||
master.assign((master * bs).contiguous())
|
||||
else:
|
||||
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
|
||||
|
||||
# realize everything here
|
||||
if optim.master_params: Tensor.realize(*optim.master_params)
|
||||
|
|
@ -1485,7 +1490,7 @@ def train_llama3():
|
|||
return lr_cpu, grad_norm_cpu
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train(False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_step(tokens:Tensor):
|
||||
if is_dp: tokens = tokens.to(None).shard(device, 0)
|
||||
if is_mp: tokens = tokens.shard(device)
|
||||
|
|
@ -1798,7 +1803,7 @@ if __name__ == "__main__":
|
|||
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
|
||||
else: bench_log_manager = contextlib.nullcontext()
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ FUSED_ADD_NORM_MUL_QUANTIZE = getenv("FUSED_ADD_NORM_MUL_QUANTIZE", 0)
|
|||
FUSED_SILU_W13 = getenv("FUSED_SILU_W13", 0)
|
||||
SPLIT_W13 = getenv("SPLIT_W13", 0)
|
||||
COLUMNWISE_WEIGHT_SCALE = getenv("COLUMNWISE_WEIGHT_SCALE", 0)
|
||||
MXFP8 = getenv("MXFP8", 0)
|
||||
|
||||
FP8_DTYPE = dtypes.fp8e4m3
|
||||
FP8_GRAD_DTYPE = dtypes.fp8e5m2
|
||||
|
|
@ -36,35 +37,48 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
|||
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
|
||||
|
||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
|
||||
if not fp8:
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
|
||||
return (x @ w.T,)
|
||||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
||||
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
|
||||
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
|
||||
if can_use_asm_gemm(x_q, w.T):
|
||||
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
||||
mx_w_stored=True).reshape(*l_shape, w.shape[0])
|
||||
else:
|
||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
|
||||
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
||||
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
||||
if x_fp8 is None:
|
||||
if FUSED_INPUT_QUANTIZE and amax_x is not None:
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
|
||||
x_fp8, x_scale, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
else:
|
||||
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
x_fp8, _, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x_fp8, w.T):
|
||||
assert amax_x is not None
|
||||
if COLUMNWISE_WEIGHT_SCALE:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
else:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
return out, x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_x.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
|
||||
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, x_normed, rrms, ret
|
||||
x_normed, rrms = rmsnorm(x, eps)
|
||||
out, *ret = matmul(x_normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
|
|
@ -74,8 +88,8 @@ def add_norm_quantize_matmul(x:Tensor, residual:Tensor, norm:Tensor, w:Tensor, w
|
|||
grad_amax_state:Tensor|None=None):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_add_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, h, x_normed, rrms, ret
|
||||
h = x + residual
|
||||
x_normed, rrms = rmsnorm(h, eps)
|
||||
|
|
@ -87,8 +101,8 @@ def silu_w13_quantize_matmul(x_w13:Tensor, w2:Tensor, s_2:Tensor,
|
|||
grad_amax_xw13:Tensor, grad_amax_xout:Tensor):
|
||||
if FUSED_SILU_W13:
|
||||
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
|
||||
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
x2_fp8, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, amax_x=amax_x2, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
return out, ret
|
||||
hidden = x_w13.shape[-1] // 2
|
||||
x_w1, x_w3 = x_w13[..., :hidden], x_w13[..., hidden:]
|
||||
|
|
@ -140,12 +154,17 @@ class FlatTransformer:
|
|||
self._fp8_grad_amax = {name: [_amax() for _ in range(n_layers)] for name in grad_names}
|
||||
w_scales = [("wqkv", s_qkv), ("wo", s_o), ("w2", s_2)]
|
||||
w_scales += [("w1", s_1), ("w3", s_3)] if SPLIT_W13 else [("w13", s_13)]
|
||||
self._fp8_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
||||
if w is None:
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||
return w_q.reshape(self.n_layers, out_features, in_features), w_e8.reshape(self.n_layers, out_features, in_features // 32)
|
||||
amax = (w.abs().max(axis=2) if COLUMNWISE_WEIGHT_SCALE else w.abs().flatten(1).max(1)).detach()
|
||||
scale = FP8_MAX / (amax + 1e-8)
|
||||
inv_scale = (amax + 1e-8) / FP8_MAX
|
||||
|
|
@ -171,7 +190,7 @@ class FlatTransformer:
|
|||
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
|
||||
if getenv("HK_FLASH_ATTENTION"):
|
||||
from extra.thunder.amd.fa import flash_attention
|
||||
attn, *save = flash_attention(xq, xk, xv, is_causal=True)
|
||||
attn, *save = flash_attention(xq, xk, xv, is_causal=True, write_flat=True)
|
||||
saves.extend(save)
|
||||
else:
|
||||
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
|
||||
|
|
@ -197,8 +216,15 @@ class FlatTransformer:
|
|||
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
||||
amaxs.append(new_amax)
|
||||
saves.extend([*s, x_w3])
|
||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||
grad_amax_state=kwargs["grad_amax_xout"])
|
||||
if FUSED_SILU_W13 and MXFP8:
|
||||
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
|
||||
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
|
||||
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
|
||||
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
|
||||
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
|
||||
else:
|
||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||
grad_amax_state=kwargs["grad_amax_xout"])
|
||||
amaxs.append(new_amax)
|
||||
saves.extend([*s, out])
|
||||
else:
|
||||
|
|
@ -228,20 +254,30 @@ class FlatTransformer:
|
|||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
def _shard_fp8(name:str, axis:int, std:float=0.02):
|
||||
w = getattr(self, name)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
|
||||
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
|
||||
w.replace(w_q)
|
||||
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
else:
|
||||
w.shard_(device, axis=axis)
|
||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
sstd = 0.02 / math.sqrt(2 * self.n_layers)
|
||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
|
||||
if SPLIT_W13:
|
||||
_shard_fp8("w1", 1)
|
||||
_shard_fp8("w3", 1)
|
||||
else:
|
||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
|
|||
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
|
||||
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
|
||||
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
|
||||
MXFP8 = getenv("MXFP8", 0)
|
||||
|
||||
def stochastic_round_bf16(x:Tensor) -> Tensor:
|
||||
bits = x.bitcast(dtypes.uint32)
|
||||
|
|
@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
|
|||
out = stochastic_round_bf16(new_w)
|
||||
return out.shard_like(t) if offloaded else out
|
||||
if t.dtype in dtypes.fp8s:
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(new_w.reshape(-1, new_w.shape[-1]))
|
||||
new_e8 = w_e8.reshape(t._inv_scale.shape)
|
||||
t._inv_scale.assign(new_e8.shard_like(t._inv_scale) if offloaded else new_e8)
|
||||
ret = w_q.reshape(new_w.shape)
|
||||
return ret.shard_like(t) if offloaded else ret
|
||||
from examples.mlperf.models.flat_llama import FP8_MAX
|
||||
if IMMEDIATE_SCALE:
|
||||
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import trange
|
||||
from tinygrad.helpers import trange, Context
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ if __name__ == "__main__":
|
|||
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# - symbolic removal
|
||||
|
||||
from examples.beautiful_mnist import Model
|
||||
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
|
||||
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
|
||||
from tinygrad.nn.datasets import mnist
|
||||
from tinygrad.helpers import trange
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ if __name__ == "__main__":
|
|||
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
||||
print("*** got samples")
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
"""
|
||||
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
|
||||
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Tuple, Dict, List, Optional
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.dtype import DType, dtypes, AddrSpace
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
|
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
|
|||
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
|
||||
info = prg.arg
|
||||
functions[info.function_name] = prg.src[3].arg
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
|
||||
statements.append((info.function_name, cargs, info.global_size, info.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
|
||||
|
|
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
|
|||
symbolic_vars = OrderedDict()
|
||||
for i, (_, args, global_size, _) in enumerate(statements):
|
||||
for j, var in enumerate(args):
|
||||
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
|
||||
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
|
||||
if var not in symbolic_vars:
|
||||
symbolic_vars[var] = var.arg[0]
|
||||
symbolic_vars[var] = var.expr
|
||||
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
|
||||
statements[i][1][j] = symbolic_vars[var]
|
||||
|
||||
if global_size:
|
||||
for j, dim in enumerate(global_size):
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
|
||||
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
|
||||
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
|
||||
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
|
||||
global_size[j] = f"_{name.expr}[0] + {val.arg}"
|
||||
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ def test_matmul():
|
|||
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
|||
|
||||
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(acc.zeros_like()))
|
||||
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
|
||||
|
||||
if use_wmma:
|
||||
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
|||
P_lds = QP_lds[:, :BLOCK_N]
|
||||
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
|
||||
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
|
||||
rw1 = UOp.range(TM, 296, AxisType.LOOP)
|
||||
rw2 = UOp.range(TN, 297, AxisType.LOOP)
|
||||
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
|
||||
|
|
|
|||
|
|
@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
insts = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
|
||||
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
|
||||
|
|
@ -2630,7 +2630,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
@functools.cache
|
||||
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
|
||||
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
|
||||
scales, extra = args[:n_scales], args[n_scales:]
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
|
|
@ -2660,7 +2660,9 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
block_size = 256
|
||||
threads = UOp.special(64 * 8, "lidx0")
|
||||
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
|
||||
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
|
||||
e_a = extra[0].base if len(extra) >= 1 else scale_A.base
|
||||
e_b = extra[1].base if len(extra) >= 2 else scale_B.base
|
||||
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, e_a, e_b, threads, workgroups)
|
||||
sink = UOp.sink(*sink_inputs,
|
||||
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
|
|
@ -2672,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
|
||||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
rows, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
*batch, K = x.shape
|
||||
scale_K = K // 32
|
||||
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
||||
x_scaled = x.float() * qscale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
||||
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
|
||||
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
|
||||
|
||||
def mx_pack(e8:Tensor) -> Tensor:
|
||||
rows, scale_K = e8.shape
|
||||
|
|
@ -2750,7 +2752,8 @@ def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
|
|||
assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
|
||||
threads = UOp.special(64 * num_warps, "lidx0")
|
||||
workgroups = UOp.special((M // block_m) * (N // block_n), "gidx0")
|
||||
sink = UOp.sink(C.base, A.base, B.base, threads, workgroups,
|
||||
b_extra = args[0].base if len(args) >= 1 else B.base
|
||||
sink = UOp.sink(C.base, A.base, B.base, b_extra, threads, workgroups,
|
||||
arg=KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
src = (kittens_path/"gemm_bf16.cpp").read_text()
|
||||
|
|
@ -2785,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
|
|||
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
|
||||
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
reduce_out = False
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
ndev = len(a.device)
|
||||
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
|
||||
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
|
||||
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
out = Tensor(inv.uop.multi(out_axis), device=a.device)
|
||||
dname = a.device[0]
|
||||
else:
|
||||
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
|
||||
dname = a.device
|
||||
dname = dname.split(":")[0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
|
||||
if is_multi: out = out.sum(0)
|
||||
if reduce_out: out = out.sum(0)
|
||||
return out.squeeze(0) if out.ndim == 3 else out
|
||||
|
||||
|
||||
|
|
@ -2805,13 +2814,15 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
out, a, b = inputs[:3]
|
||||
i = 3
|
||||
s_x = inputs[i]; i += 1
|
||||
has_w = n_scales == 2
|
||||
has_w = n_scales >= 2
|
||||
s_w = inputs[i] if has_w else None; i += has_w
|
||||
s_g = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
|
||||
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
|
||||
w_post = inputs[i] if has_w_post else None
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
s_x_t = Tensor(s_x, device=a.device)
|
||||
s_w_t = Tensor(s_w, device=a.device) if has_w else None
|
||||
s_g_t = Tensor(s_g, device=a.device) if s_g is not None else None
|
||||
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
|
||||
g_t = g_t[:a.shape[0]]
|
||||
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
|
||||
|
|
@ -2836,7 +2847,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
store_effect = grad_amax_state.store(new_grad_amax.uop)
|
||||
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
|
||||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
|
||||
if s_g_t is not None: g_scale = g_scale * s_g_t
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
|
|
@ -2844,7 +2856,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
|
||||
|
|
@ -2876,7 +2888,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
|
||||
# ** mxfp8 gemm backward
|
||||
|
||||
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
||||
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=False):
|
||||
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
|
||||
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
|
||||
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
|
||||
|
|
@ -2890,14 +2902,14 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
|||
grad_b = asm_gemm(g.T, a_phys, mx=True)
|
||||
|
||||
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
|
||||
grad_b = grad_b * _mx_block_scale(be8)
|
||||
if not w_stored: grad_b = grad_b * _mx_block_scale(be8)
|
||||
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
|
||||
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
|
||||
|
||||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False, g_scale:Tensor|None=None) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@ -2939,13 +2951,13 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
b_q, b_e8, b_si = quantize_mxfp8(b.T)
|
||||
has_w_post = w_post_scale is not None
|
||||
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
|
||||
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
|
||||
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post, w_stored=mx_w_stored)
|
||||
extra = [w_post_scale] if w_post_scale is not None else []
|
||||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
elif a.dtype == FP8_DTYPE:
|
||||
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
|
||||
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
|
||||
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
|
||||
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
|
||||
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ def test_matmul():
|
|||
def asm_kernel(A, B, C):
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(THREADS, "lidx0")]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
|
||||
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -1,145 +0,0 @@
|
|||
from __future__ import annotations
|
||||
import time
|
||||
from typing import cast
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.jit import GraphRunner
|
||||
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
|
||||
from tinygrad.helpers import round_up, ceildiv
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
|
||||
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
|
||||
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
|
||||
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
|
||||
|
||||
# **************** insert deps ****************
|
||||
|
||||
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
|
||||
src = []
|
||||
for j, call in enumerate(linear.src):
|
||||
call = call.replace(tag=j)
|
||||
_, _, bufs, _ = ctx.calls[j]
|
||||
outs, ins = get_call_outs_ins(call)
|
||||
deps = ctx._access_resources([bufs[i] for i in outs + ins], list(range(len(outs))), call)
|
||||
src.append(UOp(Ops.AFTER, call.dtype, (call, *deps), tag=call.tag))
|
||||
return linear.replace(src=tuple(src))
|
||||
pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
|
||||
|
||||
pm_replace_params = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
|
||||
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
|
||||
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
|
||||
])
|
||||
|
||||
# **************** graph-only passes ****************
|
||||
|
||||
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
|
||||
if q.arg in ctx.queue_sigs: return None
|
||||
dev = q.arg[0][0] # TODO: multi device
|
||||
buf = Buffer(dev, 0x100, dtypes.uint8, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
|
||||
ctx.queue_sig_bufs.append(buf)
|
||||
ctx.queue_sigs[q.arg] = UOp.from_buffer(buf, dev)
|
||||
return None
|
||||
pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
|
||||
|
||||
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
|
||||
wrapper, deps, call_idx = after.src[0], after.src[1:], after.tag
|
||||
def store(q_arg, v): return ctx.queue_sigs[q_arg].store(UOp.const(dtypes.uint32, v))
|
||||
waits = tuple(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep.src[0].arg], UOp.const(dtypes.uint32, dep.tag),
|
||||
store(dep.src[0].arg, dep.tag))) for dep in deps)
|
||||
return wrapper.replace(src=tuple(q.replace(src=(*waits, *q.src, store(q.arg, call_idx))) for q in wrapper.src))
|
||||
pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
|
||||
|
||||
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
|
||||
src, seen, pending, queue_sig = [], {}, {}, ctx.queue_sigs[queue.arg]
|
||||
for x in queue.src:
|
||||
if x.op is Ops.WAIT:
|
||||
sig, val = x.src[0], x.src[1]
|
||||
if sig is queue_sig or seen.get(sig, -1) >= val.arg: continue
|
||||
if (old:=pending.get(sig)) is None or old.src[1].arg < val.arg: pending[sig] = x
|
||||
continue
|
||||
for wait in pending.values():
|
||||
src.append(wait)
|
||||
seen[wait.src[0]] = wait.src[1].arg
|
||||
pending.clear()
|
||||
src.append(x)
|
||||
src += pending.values()
|
||||
return queue.replace(src=tuple(src)) if tuple(src) != queue.src else None
|
||||
pm_optimize_queue_deps = PatternMatcher([
|
||||
(UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
|
||||
])
|
||||
|
||||
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
|
||||
live = {u.src[2] for u in outer.toposort() if u.op is Ops.WAIT}
|
||||
return outer.replace(src=tuple(q.replace(src=tuple(x for x in q.src if x.op is not Ops.STORE or x in live)) for q in outer.src))
|
||||
pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
|
||||
|
||||
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
|
||||
if not ctx.queue_sig_bufs or cmdbuf.tag not in ("compute", "copy"): return None
|
||||
resets = tuple((b:=UOp.from_buffer(sig)).index(UOp.const(dtypes.int, 0), dtype=b.dtype.ptr())
|
||||
.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0)) for sig in ctx.queue_sig_bufs)
|
||||
return x.replace(src=x.src + resets)
|
||||
pm_add_queue_sig_resets = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"),
|
||||
add_queue_sig_resets)])
|
||||
|
||||
# **************** Graph ****************
|
||||
|
||||
class HCQ2Graph(GraphRunner):
|
||||
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
|
||||
super().__init__(linear, input_uops)
|
||||
self.dev = cast(HCQ2Compiled, Device[self.device])
|
||||
self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")
|
||||
|
||||
self.input_addrs = Buffer("CPU", max(len(input_uops), 1), dtypes.uint64, preallocate=True)
|
||||
self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")
|
||||
|
||||
self.linear = graph_rewrite(self.linear, pm_insert_deps, ctx=self, name="hcq: insert deps", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_replace_params, ctx=self, name="hcq: replace params", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_prep_runtime, ctx=self.hcq_ctx, name="hcq: prepare runtime")
|
||||
self.linear = graph_rewrite(self.linear, pm_lower_ops, ctx=self.hcq_ctx, name="hcq: lower ops")
|
||||
|
||||
# per-queue signal state — populated as a side-effect by pm_alloc_queue_sigs walking the lowered linear.
|
||||
self.queue_sig_bufs:list[Buffer] = []
|
||||
self.queue_sigs:dict[tuple[str, str], UOp] = {}
|
||||
graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)
|
||||
|
||||
self.linear = graph_rewrite(self.linear, pm_lower_queue_deps, ctx=self, name="hcq: lower queue deps")
|
||||
self.linear = graph_rewrite(self.linear, pm_split_into_queues, ctx=self.hcq_ctx, name="hcq: split into queues")
|
||||
self.linear = graph_rewrite(self.linear, pm_add_barriers, ctx=self.hcq_ctx, name="hcq: add barriers", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_optimize_queue_deps, ctx=self, name="hcq: optimize queue deps", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_drop_dead_stores, ctx=self, name="hcq: drop dead stores")
|
||||
self.linear = graph_rewrite(self.linear, pm_add_signals, ctx=self.hcq_ctx, name="hcq: add signals", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_add_timeline_inc, ctx=self.hcq_ctx, name="hcq: add submit", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, self.dev.pm_lower, ctx=self.hcq_ctx, name=f"hcq: encode cmdbuf {self.dev.device}", walk=True)
|
||||
|
||||
graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=(sizes:={}), name=None)
|
||||
for dev_name, sz in sizes.items():
|
||||
buf = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
|
||||
self.hcq_ctx.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(buf, dev_name), UOp.const(dtypes.uint64, buf._buf.va_addr))
|
||||
|
||||
self.linear = graph_rewrite(self.linear, pm_bufferize, ctx=self.hcq_ctx, bottom_up=True, name="realize binaries")
|
||||
self.linear = graph_rewrite(self.linear, pm_lift_patches_to_cmdbuf, ctx=self.hcq_ctx, bottom_up=False, name="lift patches to cmdbuf")
|
||||
self.linear = graph_rewrite(self.linear, pm_resolve_patches, ctx=self.hcq_ctx, bottom_up=False, name="simplify patches")
|
||||
self.linear = graph_rewrite(self.linear, pm_add_queue_sig_resets, ctx=self, name="hcq: add queue sig resets", walk=True)
|
||||
self.linear = graph_rewrite(self.linear, pm_parametrize_host_buffers, ctx=self.hcq_ctx, bottom_up=True, name="parametrize host buffers")
|
||||
self.host_call = graph_rewrite(self.linear, pm_callify, ctx=self.hcq_ctx, name="hcq: callify")
|
||||
|
||||
self.host_rt, self.host_globals = get_runtime("CPU", self.host_call.src[0]), self.host_call.src[0].arg.globals
|
||||
|
||||
def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
|
||||
addrs = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
|
||||
for i, u in enumerate(input_uops):
|
||||
buf = next(b for b in u.buffer.bufs if b.device == self.dev.device) if isinstance(u.buffer, MultiBuffer) else u.buffer
|
||||
addrs[i] = buf._buf.va_addr
|
||||
self.host_rt(*[self.hcq_ctx.inputs[i].get_buf("CPU") for i in self.host_globals], vals=self.host_call.src[0].arg.vals(var_vals), wait=True)
|
||||
if wait:
|
||||
st = time.perf_counter()
|
||||
self.dev.synchronize()
|
||||
return time.perf_counter() - st
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
all_devs = GraphRunner._all_devs(batch_devs, new_call)
|
||||
return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(all_devs) == 1 and isinstance(all_devs[0], HCQ2Compiled)
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Callable, TypeVar, Generic, Any
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
|
||||
from dataclasses import replace, dataclass, field
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
|
||||
from tinygrad.helpers import to_tuple, round_up
|
||||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -28,7 +27,7 @@ class HCQ2Compiled(Compiled):
|
|||
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
|
||||
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
])
|
||||
|
||||
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
|
||||
|
|
@ -137,21 +136,24 @@ def unwrap_after(uop):
|
|||
while uop.op is Ops.AFTER: uop = uop.src[0]
|
||||
return uop
|
||||
|
||||
def make_getaddr(u, dev=None):
|
||||
def make_getaddr(u, device=None):
|
||||
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=device or to_tuple(u.device)[0])))
|
||||
|
||||
def make_ins(op, *srcs):
|
||||
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
||||
|
||||
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
|
||||
dt = dtype or val.dtype
|
||||
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
|
||||
|
||||
def make_cmdbuf(lin, devs, tag):
|
||||
blob, patches = b'', []
|
||||
for s in (s for ins in lin.src for s in ins.src):
|
||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
|
||||
|
||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||
|
||||
|
|
@ -172,6 +174,10 @@ class HCQInfo:
|
|||
name:str = ""
|
||||
estimates:Estimates = Estimates()
|
||||
outs:tuple[int, ...] = ()
|
||||
devs:tuple[str, ...] = ()
|
||||
|
||||
params:tuple[int, ...] = ()
|
||||
inputs:int|None = None
|
||||
|
||||
@staticmethod
|
||||
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
|
||||
|
|
@ -208,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
|||
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
||||
|
||||
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
||||
data, info = prg.arg
|
||||
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
|
||||
dtypes.uint64) for i,gi in enumerate(info.globals)] \
|
||||
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||
|
||||
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
|
||||
|
||||
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
|
||||
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
|
||||
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
|
||||
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
|
||||
|
||||
pm_prep_runtime = PatternMatcher([
|
||||
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
||||
|
|
@ -269,6 +271,14 @@ class DepsCtx:
|
|||
deps:DepsTracker = field(default_factory=DepsTracker)
|
||||
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
|
||||
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
|
||||
params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
|
||||
|
||||
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
|
||||
# TODO: should this be a part of DepsTracker?
|
||||
if u.op is Ops.PARAM: return ctx.params.setdefault((u.arg.slot, lane), Buffer("NULL", u.max_numel(), u.dtype.base))
|
||||
if u.op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
|
||||
if u.op in (Ops.SLICE, Ops.MSELECT): return get_dep_buf(ctx, u.src[0], u.arg if u.op is Ops.MSELECT else lane)
|
||||
return b.bufs[lane] if isinstance(b:=u.buffer, MultiBuffer) else b
|
||||
|
||||
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
new_src = []
|
||||
|
|
@ -277,16 +287,26 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
new_src.append(call)
|
||||
continue
|
||||
|
||||
q = get_submit(call.src[0]).src[0]
|
||||
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
|
||||
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
|
||||
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
|
||||
|
||||
deps = []
|
||||
refs = [b.buffer for b in get_call_arg_uops(call)]
|
||||
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
|
||||
# per-lane deps, tracked per (device, queue). skip self
|
||||
dep_lanes:list[tuple[UOp, int]] = []
|
||||
for lane, d in enumerate(qdevs):
|
||||
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
|
||||
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
|
||||
|
||||
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
# drop self-queue waits, queue self-orders
|
||||
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
|
||||
|
||||
# keep latest dep per lane, group lanes
|
||||
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
|
||||
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
|
||||
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
||||
|
||||
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
||||
|
|
@ -298,11 +318,16 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
zero = UOp.const(dtypes.int, 0)
|
||||
tl = make_signal_value(devs)
|
||||
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
|
||||
# queue is inc with deps
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
|
||||
|
||||
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
|
||||
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
|
||||
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
|
||||
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
|
||||
return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
|
||||
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
parts:dict[str, list[UOp]] = collections.defaultdict(list)
|
||||
|
|
@ -319,21 +344,22 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
|
||||
|
||||
waits = []
|
||||
for dep in deps.src[1:]:
|
||||
devs, queue = dep.arg
|
||||
for lanes, dep in zip(deps.arg, deps.src[1:]):
|
||||
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
|
||||
ctx.add(dep.tag) # mark op to update signal.
|
||||
|
||||
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
|
||||
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
|
||||
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
|
||||
lanes = set(lanes)
|
||||
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
|
||||
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(sig.wait(val + dep.tag))
|
||||
return cur.replace(src=tuple(waits) + cur.src)
|
||||
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
|
||||
src = q.src
|
||||
if q.tag in ctx:
|
||||
devs, queue = q.arg
|
||||
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
|
||||
if q.tag not in ctx: return None
|
||||
devs, queue = q.arg
|
||||
src = q.src + (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
|
||||
return submit.replace(src=(q.replace(src=src, tag=None),))
|
||||
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
||||
|
||||
|
|
@ -342,34 +368,36 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
|
|||
|
||||
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
|
||||
|
||||
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
|
||||
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
|
||||
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
|
||||
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
|
||||
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
|
||||
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
|
||||
return new_sink.substitute({new_submit: new_anchor})
|
||||
def merge_sink(sinks:list[UOp]) -> UOp:
|
||||
if len(sinks) == 1: return sinks[0]
|
||||
submits = [get_submit(sink) for sink in sinks]
|
||||
queues = [submit.src[0] for submit in submits]
|
||||
anchor = submits[-1].replace(src=(queues[-1].replace(src=tuple(x for q in queues for x in q.src)),))
|
||||
for sink, submit in zip(sinks[:-1], submits[:-1]):
|
||||
if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
|
||||
return sinks[-1].substitute({submits[-1]: anchor}, walk=True)
|
||||
|
||||
def merge_queues(linear:UOp) -> UOp:
|
||||
new_src:list[UOp] = []
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
|
||||
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
|
||||
aux = call.arg.aux
|
||||
new_rec = ([new_sink], call.arg.aux)
|
||||
if (old:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
|
||||
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
|
||||
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
|
||||
else:
|
||||
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = (new_sink, aux)
|
||||
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
|
||||
opened_qs[(devs, queue)] = new_rec
|
||||
|
||||
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
|
||||
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
|
||||
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
||||
|
||||
# *****************
|
||||
|
|
@ -384,6 +412,29 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
|||
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
|
||||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 4.3. annotate exec devs
|
||||
|
||||
pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
|
||||
lambda call: call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=get_submit(call.src[0]).src[0].arg[0]))))])
|
||||
|
||||
# *****************
|
||||
# 4.4. replace params with per-submit input address loads
|
||||
|
||||
def replace_params(call:UOp) -> UOp|None:
|
||||
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
|
||||
|
||||
# fill new info
|
||||
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
|
||||
|
||||
inputs = UOp.new_buffer(get_submit(call.src[0]).src[0].arg[0], len(hcqinfo.params), dtypes.uint64).rtag("inputs")
|
||||
|
||||
slot2idx = {s:i for i,s in enumerate(hcqinfo.params)}
|
||||
body = call.src[0].substitute({u:inputs.index(UOp.const(dtypes.int, slot2idx[s])).load() for u,s in params.items()})
|
||||
|
||||
return call.replace(src=(body, *call.src[1:], inputs), arg=replace(call.arg, aux=hcqinfo))
|
||||
pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
|
||||
|
||||
# *****************
|
||||
# 5.1. encode cmdbufs
|
||||
|
||||
|
|
@ -411,16 +462,28 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
|
|||
])
|
||||
|
||||
# *****************
|
||||
# 6. bufferize placeholders: replace placeholders with real buffers.
|
||||
# 5.3. pack placeholders buffers
|
||||
|
||||
def bufferize_buf(buf:UOp) -> UOp|None:
|
||||
if buf.tag is None: return None
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
|
||||
return make_mstack(uops)
|
||||
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
||||
def pack_hcq_placeholders(call:UOp) -> UOp|None:
|
||||
bufs = [b for b in call.src[0].toposort() if b.op is Ops.BUFFER and b.tag in (maxtags:={"scratch"}) | (sumtags:={"program", "kernargs"})]
|
||||
|
||||
off_per_buf:dict[UOp, int] = {}
|
||||
size_per_tag:dict[str, int] = {}
|
||||
for b in bufs:
|
||||
if b.tag in maxtags: size_per_tag[b.tag] = max(size_per_tag.get(b.tag, 0), b.arg)
|
||||
elif b.tag in sumtags:
|
||||
off_per_buf[b] = round_up(size_per_tag.get(b.tag, 0), {"program": 0x1000}.get(b.tag, 128))
|
||||
size_per_tag[b.tag] = off_per_buf[b] + b.arg
|
||||
|
||||
count_per_tag = collections.Counter(b.tag for b in bufs)
|
||||
ref_bufs = {b.tag:b for b in bufs if count_per_tag[b.tag] > 1}
|
||||
bases = {tag:UOp.new_buffer(b.src[1].arg, size_per_tag[tag], b.dtype).rtag(tag) for tag,b in ref_bufs.items()}
|
||||
subs = {b:UOp(Ops.SLICE, b.dtype, (bases[b.tag], UOp.const(dtypes.weakint, off_per_buf.get(b, 0))), b.arg) for b in bufs if b.tag in bases}
|
||||
return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:])) if subs else None
|
||||
pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
|
||||
|
||||
# *****************
|
||||
# 7.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
|
||||
def hold_call_buffers(call:UOp) -> UOp|None:
|
||||
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
|
||||
|
|
@ -428,57 +491,83 @@ def hold_call_buffers(call:UOp) -> UOp|None:
|
|||
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
|
||||
|
||||
# *****************
|
||||
# 7.2. resolve patches
|
||||
# 6. bufferize placeholders: replace placeholders with real buffers.
|
||||
|
||||
def bufferize_buf(buf:UOp) -> UOp|None:
|
||||
if buf.tag is None: return None
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), "CPU") for dev in to_tuple(buf.src[1].arg))
|
||||
return make_mstack(uops)
|
||||
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
||||
|
||||
# *****************
|
||||
# 7. resolve patches
|
||||
|
||||
def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
|
||||
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
|
||||
|
||||
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
|
||||
for b in (buf.src if buf.op is Ops.MSTACK else (buf,)): b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
|
||||
for b in (mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)): b.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
|
||||
return UOp(Ops.NOOP)
|
||||
|
||||
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
||||
for b, v in zip((buf.src if buf.op is Ops.MSTACK else (buf,)), (val.src if val.op is Ops.STACK else (val,))):
|
||||
struct.pack_into(f'<{v.dtype.fmt}', b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * b.dtype.base.itemsize, v.arg)
|
||||
for b, v in zip((bs:=mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)), val.src if val.op is Ops.STACK else (val,)*len(bs)):
|
||||
struct.pack_into(f'<{v.dtype.fmt}', b.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * buf.dtype.base.itemsize, v.arg)
|
||||
return UOp(Ops.NOOP)
|
||||
|
||||
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
|
||||
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
|
||||
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
|
||||
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
|
||||
devs, b = to_tuple(g.src[1].arg), buf.buffer
|
||||
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
|
||||
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
|
||||
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
|
||||
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
|
||||
|
||||
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
|
||||
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
|
||||
|
||||
pm_resolve_patches = PatternMatcher([
|
||||
# multi
|
||||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||
|
||||
# shrink on slice is shrink on base at offset
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
|
||||
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
|
||||
|
||||
# getaddr
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
||||
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
|
||||
# folders
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
fold_const_store),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
|
||||
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
|
||||
]) + symbolic_simple
|
||||
|
||||
# *****************
|
||||
# 8. callify hcq programs
|
||||
|
||||
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
|
||||
bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
|
||||
if ref not in bufs: bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, bufs.index(ref))
|
||||
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
|
||||
|
||||
def parametrize_host_buffers(call:UOp) -> UOp:
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
|
||||
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
|
||||
# preserve original order of args
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=list(get_call_arg_uops(call))), bottom_up=True, name="parametrize host buffers")
|
||||
|
||||
# move vars to new slots
|
||||
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
|
||||
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
|
||||
|
||||
return call.replace(src=(body, *bufs) + tuple(x for x in call.src[1:] if x.op is Ops.BIND))
|
||||
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
||||
|
||||
def callify_hcq(call:UOp) -> UOp:
|
||||
prg = to_program(call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1), Device["CPU"].renderer)
|
||||
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
|
||||
def hcq_schedule(linear:UOp) -> UOp:
|
||||
|
|
@ -492,14 +581,17 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_annotate_devs, name="annotate devs")
|
||||
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_pack_placeholders, walk=True, name="pack placeholders")
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
|
||||
# realize starts from here
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
||||
|
||||
return linear
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast
|
||||
from typing import cast, Any, Callable
|
||||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -89,25 +89,25 @@ def memory_barrier(ctx):
|
|||
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
|
||||
acquire_mem(ctx)))
|
||||
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
|
||||
|
||||
def pm4_barrier(ctx): return memory_barrier(ctx)
|
||||
|
||||
def pm4_store(ctx, dst, val):
|
||||
if val.op is Ops.BINARY: return None
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
def pm4_timestamp(ctx, dst):
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
ctx.pm4.int_sel__mec_release_mem__none)
|
||||
|
||||
def pm4_program(ctx, prg):
|
||||
data, info = prg.arg
|
||||
lib_gpu, args = prg.src
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
|
||||
args_addr = make_getaddr(args, ctx.device)
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
|
||||
args_addr = make_getaddr(args, ctx.devs)
|
||||
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
|
|
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
|||
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
|
||||
|
||||
def sdma_copy(ctx, dst, src, copy):
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
|
||||
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
|
||||
|
|
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
|
|||
def sdma_wait(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
|
||||
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
|
||||
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
|
||||
def sdma_store(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
|
||||
def sdma_timestamp(ctx, dst):
|
||||
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
|
||||
|
||||
pm_sdma_opsel = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
|
||||
|
|
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
|
|||
|
||||
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
|
||||
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
|
||||
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
|
||||
|
||||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
|
||||
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
|
||||
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
|
|||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
return (None, None, grad_xw13_uop, None, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, new_amax)
|
||||
# grad_amax_state: delayed amax for grad_xw13 fp8 quantization in the backward.
|
||||
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
|
||||
MBS, SEQ, H2 = xw13.shape
|
||||
|
|
@ -72,5 +72,4 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_
|
|||
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname_of(xw13.device))
|
||||
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, grad_amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf)
|
||||
return fp8_out, scalar_amax(amax_buf)
|
||||
|
|
|
|||
|
|
@ -112,8 +112,8 @@ def _fused_add_bwd(*args, **kwargs):
|
|||
grad_h, grad_w = _bwd_common(fp8_grad_u, h_grad_u, x_u, x_normed_u, rrms_u, weight_u, amax_state_u, kernel)
|
||||
return (None, None, None, None, None, grad_h, grad_h, grad_w, None)
|
||||
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, inv_scale, new_amax, x_normed, rrms).
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, new_amax, x_normed, rrms).
|
||||
# x_normed + rrms are saved for the rmsnorm backward (also recomputed here from x regs).
|
||||
assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
|
||||
|
|
@ -127,13 +127,12 @@ def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, e
|
|||
fxn = functools.partial(_custom_fwd, dname=dname_of(x.device), eps_val=eps)
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
|
||||
def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor, amax_state:Tensor,
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: h = x + residual; y_normed = rmsnorm(h); fp8 = quantize(y_normed * weight).
|
||||
# Returns (fp8, inv_scale, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# Returns (fp8, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# reuse it without recomputing x+residual — eliminates the separate residual-add kernel.
|
||||
assert x.dtype == dtypes.bfloat16 and residual.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape == residual.shape
|
||||
|
|
@ -149,5 +148,4 @@ def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor,
|
|||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, x, residual, weight, amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_add_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
|
|
|
|||
104
extra/llama_kernels/fused_silu_mul_quantize_mxfp8/__init__.py
Normal file
104
extra/llama_kernels/fused_silu_mul_quantize_mxfp8/__init__.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
||||
|
||||
BLK = 32
|
||||
PACK = 4
|
||||
LOG2E = 1.4426950408889634
|
||||
|
||||
@functools.cache
|
||||
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
n_super = n_elems // (BLK * PACK)
|
||||
sk4 = scale_K // PACK
|
||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
||||
nwg = n_super // THREADS_PER_WG
|
||||
|
||||
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
e8_out = e8_out.reshape(rows * scale_K)
|
||||
si_out = si_out.reshape(sk4 * rows)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
||||
|
||||
super_idx = wg * THREADS_PER_WG + tid
|
||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
||||
|
||||
w1 = x_w1[idx].cast(dtypes.float)
|
||||
w3 = x_w3[idx].cast(dtypes.float)
|
||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
||||
act = w1 * sig * w3
|
||||
abs_a = (act < 0.0).where(-act, act)
|
||||
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
|
||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
||||
qscale = (127.0 - e8f).exp2()
|
||||
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
e8u8 = e8f.cast(dtypes.uint8)
|
||||
|
||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
||||
row, col4 = super_idx // sk4, super_idx % sk4
|
||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
@functools.cache
|
||||
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
VEC = 8
|
||||
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
|
||||
nwg = n_elems // (THREADS_PER_WG * VEC)
|
||||
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
|
||||
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
lane = UOp.range(VEC, 2, AxisType.UNROLL)
|
||||
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
|
||||
|
||||
e8v = e8[idx // BLK].cast(dtypes.float)
|
||||
qscale = (127.0 - e8v).exp2()
|
||||
ga = grad_aq[idx].cast(dtypes.float) * qscale
|
||||
w1 = x_w1[idx].cast(dtypes.float)
|
||||
w3 = x_w3[idx].cast(dtypes.float)
|
||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
||||
s = w1 * sig
|
||||
sprime = sig * (1.0 + w1 * (1.0 - sig))
|
||||
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
|
||||
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
|
||||
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
|
||||
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
|
||||
device = x_w1.device
|
||||
rows, K = x_w1.shape
|
||||
axis = x_w1.axis if isinstance(device, tuple) else None
|
||||
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
||||
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
||||
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
|
||||
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
|
||||
fxn=_custom_silu_mul_bwd_mxfp8)
|
||||
return (None, None, None, gx1.uop, gx3.uop)
|
||||
|
||||
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
|
||||
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
|
||||
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
|
||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
|
||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
|
||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
|
||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
|
||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
|
||||
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
|
||||
return fp8_out, e8_out, si_out
|
||||
|
|
@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
|
|||
step = THREADS_PER_WG // 2
|
||||
while step:
|
||||
active = tid < step
|
||||
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||
other = lds[(tid + step).valid(active)].load()
|
||||
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
|
||||
step //= 2
|
||||
|
||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||
|
|
|
|||
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
||||
|
||||
BLK = 32
|
||||
PACK = 4
|
||||
|
||||
@functools.cache
|
||||
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
n_super = n_elems // (BLK * PACK)
|
||||
sk4 = scale_K // PACK
|
||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
||||
nwg = n_super // THREADS_PER_WG
|
||||
|
||||
x = x.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
e8_out = e8_out.reshape(rows * scale_K)
|
||||
si_out = si_out.reshape(sk4 * rows)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
||||
|
||||
super_idx = wg * THREADS_PER_WG + tid
|
||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
||||
|
||||
x_f = x[idx].cast(dtypes.float)
|
||||
abs_x = (x_f < 0.0).where(-x_f, x_f)
|
||||
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
|
||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
||||
qscale = (127.0 - e8f).exp2()
|
||||
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
e8u8 = e8f.cast(dtypes.uint8)
|
||||
|
||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
||||
|
||||
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
|
||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
||||
row, col4 = super_idx // sk4, super_idx % sk4
|
||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
|
||||
_, e8_out, _, x = kernel.src[1:]
|
||||
device = x.device
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
e8 = Tensor(e8_out, device=device).reshape(rows, scale_K)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, BLK).reshape(rows, K)
|
||||
grad_x = (Tensor(gradient, device=device).float() * qscale).cast(dtypes.bfloat16)
|
||||
return (None, None, None, grad_x.uop)
|
||||
|
||||
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
|
||||
assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
|
||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
axis = x.uop.axis if isinstance(x.device, tuple) else None
|
||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
|
||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
|
||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
|
||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
|
||||
return fp8_out, e8_out, si_out
|
||||
|
|
@ -9,8 +9,8 @@ EXAMPLES = {
|
|||
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"plus":"test/test_tiny.py TestTiny.test_plus",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_lds_sync",
|
||||
"handwritten":"test/amd/test_custom_kernel.py TestCustomKernel.test_handwritten",
|
||||
"sync":"test/amd/test_asm_kernel.py TestAsmKernel.test_lds_sync",
|
||||
"handwritten":"test/amd/test_asm_kernel.py TestAsmKernel.test_handwritten",
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ def _sign(serial_num, out_dir):
|
|||
final = _build_elf(ehdr, allph, segs)
|
||||
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out = os.path.join(out_dir, "testsig-0x%08X.so" % serial_num)
|
||||
out = os.path.join(out_dir, "testsig-0x%08x.so" % serial_num)
|
||||
with open(out, 'wb') as f:
|
||||
f.write(final)
|
||||
print("Signing complete! Output saved at %s" % out)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
|
|||
return None, None, dq.uop, dk.uop, dv.uop
|
||||
return grad
|
||||
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
|
||||
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
|
||||
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
|
||||
assert attn_mask is None, "attn_mask not supported"
|
||||
assert is_causal, "only causal attention supported"
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
arch = Device[single_device].renderer.target.arch
|
||||
|
||||
attn = _sharded_empty_like(xq, axis=shard_axis)
|
||||
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
|
||||
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
|
||||
|
||||
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ constexpr int NUM_THREADS = WARP_THREADS * NUM_WARPS;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr) {
|
||||
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr, bf16 *b_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
static_assert(M % BLOCK_SIZE == 0 && N % BLOCK_SIZE == 0 && K % K_STEP == 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -99,12 +99,14 @@ using G = kittens::group<NUM_WARPS>;
|
|||
#endif
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr
|
||||
#if SCALE_MODE == 1
|
||||
#if SCALE_MODE & 1
|
||||
, float *x_scale_ptr
|
||||
#elif SCALE_MODE == 2
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
, float *w_scale_ptr
|
||||
#elif SCALE_MODE == 3
|
||||
, float *x_scale_ptr, float *w_scale_ptr
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
, float *g_scale_ptr
|
||||
#endif
|
||||
) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
|
@ -346,20 +348,19 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
|
|||
}
|
||||
|
||||
// apply x_scale * w_scale before bf16 store to prevent overflow
|
||||
#if SCALE_MODE == 1
|
||||
float scale = *x_scale_ptr;
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
mul(cD, cD, scale);
|
||||
#elif SCALE_MODE == 2
|
||||
float scale = *w_scale_ptr;
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
mul(cD, cD, scale);
|
||||
#elif SCALE_MODE == 3
|
||||
float scale = *x_scale_ptr * *w_scale_ptr;
|
||||
#if SCALE_MODE != 0
|
||||
float scale = 1.0f;
|
||||
#if SCALE_MODE & 1
|
||||
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
|
||||
scale *= x_scale;
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
scale *= *w_scale_ptr;
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
scale *= *g_scale_ptr;
|
||||
#endif
|
||||
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter) {
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
|
||||
const uint8_t *__restrict__ a_e8_unused,
|
||||
const uint8_t *__restrict__ b_e8_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
||||
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
|
@ -38,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
constexpr int k_iters = K / BLOCK_K;
|
||||
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
|
||||
__shared__ ST_A As[2][2];
|
||||
__shared__ ST_B Bs[2][2];
|
||||
__shared__ uint8_t smem_scales[2048];
|
||||
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
|
||||
|
||||
RT_A a;
|
||||
RT_B b0, b1;
|
||||
|
|
@ -79,54 +83,87 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
||||
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
||||
|
||||
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
|
||||
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
|
||||
const int a_row_stride = A.template stride<2>() * sizeof(T);
|
||||
const int b_row_stride = B.template stride<2>() * sizeof(T);
|
||||
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
|
||||
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
|
||||
|
||||
const int wid = warpid() % NUM_WARPS;
|
||||
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
|
||||
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
|
||||
int a_row_h0 = warp_m * REG_M;
|
||||
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
||||
int b_row_h0 = warp_n * REG_N;
|
||||
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
||||
|
||||
int tic = 0, toc = 1;
|
||||
int tic_scales = 0, toc_scales = 1;
|
||||
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
|
||||
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
|
||||
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
|
||||
|
||||
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
|
||||
|
||||
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int sa_idx = block_row, sb_idx = block_col;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
asm volatile("s_waitcnt lgkmcnt(8)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -135,10 +172,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -148,7 +184,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -160,19 +196,24 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
|
||||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
if (k + 1 < k_iters) {
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
@ -215,17 +256,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
tic ^= 1; toc ^= 1;
|
||||
tic_scales ^= 1; toc_scales ^= 1;
|
||||
}
|
||||
|
||||
{ // Epilogue k = k_iters - 1 (final)
|
||||
int k = k_iters - 1;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
|
|
@ -261,6 +302,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
} else {
|
||||
#pragma unroll 1
|
||||
for (int kk = 0; kk < k_iters; kk++) {
|
||||
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
|
||||
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
|
||||
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
|
||||
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
|
||||
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
}
|
||||
|
||||
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
||||
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
|
|
|
|||
|
|
@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
|
|||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
|
||||
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
|
||||
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
|
||||
template<typename T = u32x3>
|
||||
__device__ __forceinline__ T buffer_load_dwordx3(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
T result;
|
||||
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
|
|
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
|
|||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ubyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sbyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ushort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sshort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
|
||||
// of the destination register, leaving the other half untouched: read-modify-write.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T = u32x3>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 8 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the high 16 bits' low 8 bits (bits [23:16]) of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 16 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the high 16 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR >= 256) {
|
||||
|
|
@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
|
|||
|
||||
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
|
||||
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
|
||||
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
|
||||
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
|
||||
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
|
|
@ -248,7 +248,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%3\n"
|
||||
"ds_read_b64_tr_b16 %1, %2 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "i"(offset), "i"(offset + 4 * ST::underlying_subtile_row_bytes)
|
||||
: "memory"
|
||||
|
|
@ -259,7 +259,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%4\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "v"(next_addr), "i"(offset)
|
||||
: "memory"
|
||||
|
|
@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
|
|||
}
|
||||
}
|
||||
|
||||
template<ducks::rt::col_layout RT, ducks::st::all ST>
|
||||
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
|
||||
static_assert(RT::cols == ST::rows,
|
||||
"col_layout load with col_offset: ST.rows must equal RT::cols");
|
||||
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
|
||||
|
||||
using T2 = RT::dtype;
|
||||
using U = ST::dtype;
|
||||
using U2 = base_types::packing<U>::packed_type;
|
||||
|
||||
const int laneid = kittens::laneid();
|
||||
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
|
||||
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
|
||||
|
||||
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
|
||||
const int block_id = laneid / 16;
|
||||
const int l_within = laneid % 16;
|
||||
const int tr_k_grp = l_within / 2;
|
||||
const int m_half = l_within & 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
|
||||
|
||||
const int subtile_base = block_id;
|
||||
uint32_t addr = src_ptr
|
||||
+ (uint32_t)(subtile_base * subtile_bytes)
|
||||
+ src.swizzle({tr_k_grp, m_col});
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
|
||||
addr ^= 1088u;
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
|
||||
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
|
||||
const int col_offset_lane = (laneid % 4) * 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
|
||||
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
|
||||
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
|
||||
|
||||
if constexpr (RT::base_tile_stride == 8) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
|
||||
: "v"(addr), "v"(next_addr)
|
||||
: "memory"
|
||||
);
|
||||
} else if constexpr (RT::base_tile_stride == 4) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
|
||||
: "v"(addr)
|
||||
: "memory"
|
||||
);
|
||||
} else {
|
||||
static_assert(false, "Unsupported stride for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(false, "Unsupported type for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kittens
|
||||
|
|
@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
|
||||
|
||||
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
|
||||
|
|
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale tile data.
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
__device__ __forceinline__ fp8e8m0_4 pack_scales(
|
||||
const fp8e8m0 *smem_scales, int row_offset) {
|
||||
int lid = laneid();
|
||||
int r16 = lid % 16;
|
||||
int k_sub = lid / 16;
|
||||
|
||||
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
|
||||
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
|
||||
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
|
||||
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
|
||||
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
|
||||
|
||||
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
|
||||
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
|
||||
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
|
||||
|
||||
return lo | (hi << 16);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base dot product operation for row layout.
|
||||
*
|
||||
|
|
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
|
||||
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
|
||||
*/
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
|
|
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
|
|||
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt(D &d,
|
||||
const A &a,
|
||||
const B &b, // notice row and (M, K) instead of col and (K, M)
|
||||
|
|
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
|
||||
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
|
||||
*/
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt_scaled(D &d,
|
||||
const A &a,
|
||||
const B &b,
|
||||
|
|
@ -9,6 +9,5 @@
|
|||
#include "maps.cuh"
|
||||
#include "reductions.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "scales.cuh"
|
||||
|
||||
#include "assembly/tile.cuh"
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue