mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into x86_cleanups
This commit is contained in:
commit
fc45e15afb
226 changed files with 14227 additions and 909 deletions
152
.github/workflows/benchmark.yml
vendored
152
.github/workflows/benchmark.yml
vendored
|
|
@ -99,7 +99,6 @@ jobs:
|
|||
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
|
||||
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
|
|
@ -134,32 +133,10 @@ jobs:
|
|||
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run quantized LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
|
||||
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
|
||||
- name: Run quantized LLaMA3
|
||||
run: |
|
||||
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
|
||||
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
|
||||
#- name: Run LLaMA 7B on 4 (virtual) GPUs
|
||||
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run OLMoE
|
||||
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
|
||||
- name: Run llama3.2
|
||||
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run olmoe
|
||||
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
|
||||
|
||||
|
|
@ -235,9 +212,6 @@ jobs:
|
|||
- name: Symlink models and datasets
|
||||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -279,36 +253,16 @@ jobs:
|
|||
# TODO: too slow
|
||||
# - name: Run SDXL
|
||||
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run quantized LLaMA3
|
||||
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: Speed (NVIDIA)
|
||||
|
|
@ -402,10 +356,7 @@ jobs:
|
|||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -458,18 +409,10 @@ jobs:
|
|||
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
|
||||
- name: Run SDXL
|
||||
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA 7B
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA 7B with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
|
|
@ -478,16 +421,6 @@ jobs:
|
|||
# run: sudo modprobe amdgpu
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
@ -594,10 +527,10 @@ jobs:
|
|||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testqualcommbenchmark:
|
||||
name: comma Benchmark
|
||||
testcommalatest:
|
||||
name: comma Benchmark (0.11.0)
|
||||
runs-on: [self-hosted, Linux, comma]
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 10
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
|
|
@ -622,6 +555,27 @@ jobs:
|
|||
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.11.0 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testcommaold:
|
||||
name: comma Benchmark (0.10.1)
|
||||
runs-on: [self-hosted, Linux, comma]
|
||||
timeout-minutes: 10
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
if: github.repository_owner == 'tinygrad'
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
|
||||
run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.10.1 driving_vision
|
||||
|
|
@ -630,14 +584,44 @@ jobs:
|
|||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.10.1 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testqualcommdsp:
|
||||
name: DSP Benchmark
|
||||
runs-on: [self-hosted, Linux, comma4]
|
||||
timeout-minutes: 5
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -e -o pipefail {0}
|
||||
if: github.repository_owner == 'tinygrad'
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: benchmark MobileNetV2 on DSP
|
||||
run: |
|
||||
# generate quantized weights
|
||||
ln -s /data/home/tiny/tinygrad/extra/datasets/imagenet extra/datasets/imagenet
|
||||
ln -s /data/home/tiny/tinygrad/testsig-*.so .
|
||||
PYTHONPATH=. CC=clang-19 DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
|
||||
PYTHONPATH=. DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
|
||||
# benchmark on DSP with NOOPT=1, the devectorizer has issues
|
||||
PYTHONPATH=. CC=clang-19 DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
|
||||
PYTHONPATH=. DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
|
|||
9
.github/workflows/test.yml
vendored
9
.github/workflows/test.yml
vendored
|
|
@ -218,14 +218,15 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
amd: 'true'
|
||||
- name: Run NULL backend tests
|
||||
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
|
||||
- name: Run targeted tests on NULL backend
|
||||
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
run: |
|
||||
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
|
||||
# TODO: too slow
|
||||
# - name: Run SDXL on NULL backend
|
||||
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
|
|
@ -249,7 +250,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
pydeps: "pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Run pre-commit test hooks
|
||||
|
|
@ -378,7 +379,7 @@ jobs:
|
|||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp32 (test correctness)
|
||||
run: |
|
||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
|||
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
|
||||
|
||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||
if not fp8:
|
||||
if ASM_GEMM:
|
||||
|
|
@ -58,24 +58,25 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
if x_fp8 is None:
|
||||
if FUSED_INPUT_QUANTIZE and amax_x is not None:
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
|
||||
x_fp8, x_scale, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
else:
|
||||
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
x_fp8, _, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x_fp8, w.T):
|
||||
assert amax_x is not None
|
||||
if COLUMNWISE_WEIGHT_SCALE:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
else:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
return out, x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_x.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
|
||||
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, x_normed, rrms, ret
|
||||
x_normed, rrms = rmsnorm(x, eps)
|
||||
out, *ret = matmul(x_normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
|
|
@ -85,8 +86,8 @@ def add_norm_quantize_matmul(x:Tensor, residual:Tensor, norm:Tensor, w:Tensor, w
|
|||
grad_amax_state:Tensor|None=None):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_add_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, h, x_normed, rrms, ret
|
||||
h = x + residual
|
||||
x_normed, rrms = rmsnorm(h, eps)
|
||||
|
|
@ -98,8 +99,8 @@ def silu_w13_quantize_matmul(x_w13:Tensor, w2:Tensor, s_2:Tensor,
|
|||
grad_amax_xw13:Tensor, grad_amax_xout:Tensor):
|
||||
if FUSED_SILU_W13:
|
||||
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
|
||||
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
x2_fp8, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, amax_x=amax_x2, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
return out, ret
|
||||
hidden = x_w13.shape[-1] // 2
|
||||
x_w1, x_w3 = x_w13[..., :hidden], x_w13[..., hidden:]
|
||||
|
|
@ -125,8 +126,10 @@ class FlatTransformer:
|
|||
|
||||
# FeedForward
|
||||
if SPLIT_W13:
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
|
||||
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
|
||||
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
|
||||
else:
|
||||
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
|
||||
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
|
||||
|
|
@ -154,9 +157,10 @@ class FlatTransformer:
|
|||
self._fp8_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
||||
if w is None:
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Tuple, Dict, List, Optional
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.dtype import DType, dtypes, AddrSpace
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
|
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
|
|||
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
|
||||
info = prg.arg
|
||||
functions[info.function_name] = prg.src[3].arg
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
|
||||
statements.append((info.function_name, cargs, info.global_size, info.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
|
||||
|
|
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
|
|||
symbolic_vars = OrderedDict()
|
||||
for i, (_, args, global_size, _) in enumerate(statements):
|
||||
for j, var in enumerate(args):
|
||||
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
|
||||
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
|
||||
if var not in symbolic_vars:
|
||||
symbolic_vars[var] = var.arg[0]
|
||||
symbolic_vars[var] = var.expr
|
||||
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
|
||||
statements[i][1][j] = symbolic_vars[var]
|
||||
|
||||
if global_size:
|
||||
for j, dim in enumerate(global_size):
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
|
||||
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
|
||||
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
|
||||
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
|
||||
global_size[j] = f"_{name.expr}[0] + {val.arg}"
|
||||
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
|
|
|
|||
|
|
@ -2630,7 +2630,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
@functools.cache
|
||||
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
|
||||
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
|
||||
scales, extra = args[:n_scales], args[n_scales:]
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
|
|
@ -2752,7 +2752,8 @@ def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
|
|||
assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
|
||||
threads = UOp.special(64 * num_warps, "lidx0")
|
||||
workgroups = UOp.special((M // block_m) * (N // block_n), "gidx0")
|
||||
sink = UOp.sink(C.base, A.base, B.base, threads, workgroups,
|
||||
b_extra = args[0].base if len(args) >= 1 else B.base
|
||||
sink = UOp.sink(C.base, A.base, B.base, b_extra, threads, workgroups,
|
||||
arg=KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
src = (kittens_path/"gemm_bf16.cpp").read_text()
|
||||
|
|
@ -2787,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
|
|||
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
|
||||
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
reduce_out = False
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
ndev = len(a.device)
|
||||
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
|
||||
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
|
||||
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
|
||||
out = Tensor(inv.uop.multi(out_axis), device=a.device)
|
||||
dname = a.device[0]
|
||||
else:
|
||||
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
|
||||
dname = a.device
|
||||
dname = dname.split(":")[0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
|
||||
if is_multi: out = out.sum(0)
|
||||
if reduce_out: out = out.sum(0)
|
||||
return out.squeeze(0) if out.ndim == 3 else out
|
||||
|
||||
|
||||
|
|
@ -2807,13 +2814,15 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
out, a, b = inputs[:3]
|
||||
i = 3
|
||||
s_x = inputs[i]; i += 1
|
||||
has_w = n_scales == 2
|
||||
has_w = n_scales >= 2
|
||||
s_w = inputs[i] if has_w else None; i += has_w
|
||||
s_g = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
|
||||
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
|
||||
w_post = inputs[i] if has_w_post else None
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
s_x_t = Tensor(s_x, device=a.device)
|
||||
s_w_t = Tensor(s_w, device=a.device) if has_w else None
|
||||
s_g_t = Tensor(s_g, device=a.device) if s_g is not None else None
|
||||
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
|
||||
g_t = g_t[:a.shape[0]]
|
||||
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
|
||||
|
|
@ -2838,7 +2847,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
store_effect = grad_amax_state.store(new_grad_amax.uop)
|
||||
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
|
||||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
|
||||
if s_g_t is not None: g_scale = g_scale * s_g_t
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
|
|
@ -2846,7 +2856,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
|
||||
|
|
@ -2899,7 +2909,7 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=F
|
|||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False, g_scale:Tensor|None=None) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@ -2946,8 +2956,8 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
elif a.dtype == FP8_DTYPE:
|
||||
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
|
||||
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
|
||||
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
|
||||
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
|
||||
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)
|
||||
|
|
|
|||
|
|
@ -1,145 +0,0 @@
|
|||
from __future__ import annotations
|
||||
import time
|
||||
from typing import cast
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.jit import GraphRunner
|
||||
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
|
||||
from tinygrad.helpers import round_up, ceildiv
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
|
||||
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
|
||||
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
|
||||
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
|
||||
|
||||
# **************** insert deps ****************
|
||||
|
||||
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
  # Rebuild `linear` so every call is tagged with its position and wrapped in an AFTER UOp
  # whose trailing sources are the dependencies returned by ctx._access_resources.
  wrapped:list[UOp] = []
  for idx, orig_call in enumerate(linear.src):
    tagged = orig_call.replace(tag=idx)
    # ctx.calls[idx] is unpacked as a 4-tuple; only its third entry (indexed below as the call's buffers) is used
    _, _, call_bufs, _ = ctx.calls[idx]
    out_idxs, in_idxs = get_call_outs_ins(tagged)
    # buffers are handed over outputs-first; the second argument lists the positions of the outputs in that list
    touched = [call_bufs[i] for i in out_idxs + in_idxs]
    deps = ctx._access_resources(touched, list(range(len(out_idxs))), tagged)
    wrapped.append(UOp(Ops.AFTER, tagged.dtype, (tagged, *deps), tag=tagged.tag))
  return linear.replace(src=tuple(wrapped))

# apply insert_deps to LINEAR nodes
pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
|
||||
|
||||
def _param_to_input_addr(ctx, p):
  # a PARAM becomes an index into the graph's input-address table, at position p.arg
  return ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))

def _slice_to_input_addr(ctx, bv, addr, off):
  # SLICE(INDEX, CONST) becomes the indexed value cast to uint64 plus a byte offset:
  # the constant offset times the itemsize of the input selected by the INDEX's second source
  elem_size = ctx.input_uops[addr.src[1].arg].dtype.itemsize
  return addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * elem_size)

pm_replace_params = PatternMatcher([
  (UPat(Ops.PARAM, name="p"), _param_to_input_addr),
  (UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
   _slice_to_input_addr),
])
|
||||
|
||||
# **************** graph-only passes ****************
|
||||
|
||||
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
  # Allocate one signal buffer per distinct queue arg and record it on ctx.
  # Always returns None, so the matched UOp itself is never rewritten.
  if q.arg not in ctx.queue_sigs:
    device = q.arg[0][0] # TODO: multi device
    spec = BufferSpec(host=True, uncached=True, cpu_access=True)
    sig_buf = Buffer(device, 0x100, dtypes.uint8, options=spec, preallocate=True)
    ctx.queue_sig_bufs.append(sig_buf)
    ctx.queue_sigs[q.arg] = UOp.from_buffer(sig_buf, device)
  return None

# matches LINEAR nodes whose sources are PROGRAM / COPY
pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
|
||||
|
||||
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
  # Turn AFTER(wrapper, *deps) into explicit signal traffic on each queue of `wrapper`:
  # one WAIT per dep is prepended, and a store of this call's index to the queue's own signal is appended.
  wrapper, call_idx = after.src[0], after.tag

  def sig_store(q_arg, value): return ctx.queue_sigs[q_arg].store(UOp.const(dtypes.uint32, value))

  waits:list[UOp] = []
  for dep in after.src[1:]:
    dep_q = dep.src[0].arg
    # WAIT sources: (signal of the dep's queue, value to wait for, the store that produces that value)
    waits.append(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep_q], UOp.const(dtypes.uint32, dep.tag), sig_store(dep_q, dep.tag))))

  lowered:list[UOp] = []
  for q in wrapper.src:
    lowered.append(q.replace(src=(*waits, *q.src, sig_store(q.arg, call_idx))))
  return wrapper.replace(src=tuple(lowered))

pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
|
||||
|
||||
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
  # Prune redundant WAITs inside one queue. Returns None when the sources end up unchanged.
  own_sig = ctx.queue_sigs[queue.arg]
  emitted:list[UOp] = []
  satisfied:dict = {}  # signal -> value of the last WAIT already emitted for it
  deferred:dict = {}   # signal -> WAIT with the largest value seen since the last non-WAIT op
  for item in queue.src:
    if item.op is not Ops.WAIT:
      # flush deferred waits right before the first non-WAIT op that follows them
      for w in deferred.values():
        emitted.append(w)
        satisfied[w.src[0]] = w.src[1].arg
      deferred.clear()
      emitted.append(item)
      continue
    sig, target = item.src[0], item.src[1].arg
    # drop waits on this queue's own signal and waits already covered by an emitted one
    if sig is own_sig or satisfied.get(sig, -1) >= target: continue
    prev = deferred.get(sig)
    if prev is None or prev.src[1].arg < target: deferred[sig] = item
  # waits still deferred at the end of the queue are kept
  emitted.extend(deferred.values())
  new_src = tuple(emitted)
  return None if new_src == queue.src else queue.replace(src=new_src)

pm_optimize_queue_deps = PatternMatcher([
  (UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
])
|
||||
|
||||
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
  # Remove every STORE from the inner queues that is not referenced as the third source of some WAIT.
  waited_stores = set()
  for u in outer.toposort():
    if u.op is Ops.WAIT: waited_stores.add(u.src[2])

  def keep(item:UOp) -> bool: return item.op is not Ops.STORE or item in waited_stores

  pruned = [q.replace(src=tuple(filter(keep, q.src))) for q in outer.src]
  return outer.replace(src=tuple(pruned))

pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
|
||||
|
||||
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
  # Append, for every queue signal buffer, a store of uint64 zero at index 0 to the sources of `x`.
  # Only applies when signals exist and the command buffer is tagged "compute" or "copy".
  if not ctx.queue_sig_bufs: return None
  if cmdbuf.tag not in ("compute", "copy"): return None
  resets:list[UOp] = []
  for sig_buf in ctx.queue_sig_bufs:
    sig = UOp.from_buffer(sig_buf)
    slot = sig.index(UOp.const(dtypes.int, 0), dtype=sig.dtype.ptr())
    resets.append(slot.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0)))
  return x.replace(src=x.src + tuple(resets))

pm_add_queue_sig_resets = PatternMatcher([
  (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"), add_queue_sig_resets),
])
|
||||
|
||||
# **************** Graph ****************
|
||||
|
||||
class HCQ2Graph(GraphRunner):
  # Graph runner that pushes a LINEAR of calls through the hcq2 rewrite passes once, at construction,
  # and replays the resulting host call on every invocation.
  def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
    super().__init__(linear, input_uops)
    self.dev = cast(HCQ2Compiled, Device[self.device])
    self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")

    # CPU table of uint64 slots, one per input (never smaller than one slot); filled in __call__
    n_slots = max(len(input_uops), 1)
    self.input_addrs = Buffer("CPU", n_slots, dtypes.uint64, preallocate=True)
    self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")

    # graph-level preparation, then the shared hcq lowering
    self.linear = graph_rewrite(self.linear, pm_insert_deps, ctx=self, name="hcq: insert deps", walk=True)
    self.linear = graph_rewrite(self.linear, pm_replace_params, ctx=self, name="hcq: replace params", walk=True)
    self.linear = graph_rewrite(self.linear, pm_prep_runtime, ctx=self.hcq_ctx, name="hcq: prepare runtime")
    self.linear = graph_rewrite(self.linear, pm_lower_ops, ctx=self.hcq_ctx, name="hcq: lower ops")

    # per-queue signal state — populated as a side-effect by pm_alloc_queue_sigs walking the lowered linear.
    self.queue_sig_bufs:list[Buffer] = []
    self.queue_sigs:dict[tuple[str, str], UOp] = {}
    graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)

    # dependency lowering, queue splitting and signalling
    self.linear = graph_rewrite(self.linear, pm_lower_queue_deps, ctx=self, name="hcq: lower queue deps")
    self.linear = graph_rewrite(self.linear, pm_split_into_queues, ctx=self.hcq_ctx, name="hcq: split into queues")
    self.linear = graph_rewrite(self.linear, pm_add_barriers, ctx=self.hcq_ctx, name="hcq: add barriers", walk=True)
    self.linear = graph_rewrite(self.linear, pm_optimize_queue_deps, ctx=self, name="hcq: optimize queue deps", walk=True)
    self.linear = graph_rewrite(self.linear, pm_drop_dead_stores, ctx=self, name="hcq: drop dead stores")
    self.linear = graph_rewrite(self.linear, pm_add_signals, ctx=self.hcq_ctx, name="hcq: add signals", walk=True)
    self.linear = graph_rewrite(self.linear, pm_add_timeline_inc, ctx=self.hcq_ctx, name="hcq: add submit", walk=True)
    self.linear = graph_rewrite(self.linear, self.dev.pm_lower, ctx=self.hcq_ctx, name=f"hcq: encode cmdbuf {self.dev.device}", walk=True)

    # the rewrite fills `sizes` (device name -> size); one cpu-accessible buffer of that size is created per device
    sizes:dict = {}
    graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=sizes, name=None)
    for dev_name, sz in sizes.items():
      buf = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
      base_addr = UOp.const(dtypes.uint64, buf._buf.va_addr)
      self.hcq_ctx.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(buf, dev_name), base_addr)

    # materialize buffers, resolve patches, and build the host call
    self.linear = graph_rewrite(self.linear, pm_bufferize, ctx=self.hcq_ctx, bottom_up=True, name="realize binaries")
    self.linear = graph_rewrite(self.linear, pm_lift_patches_to_cmdbuf, ctx=self.hcq_ctx, bottom_up=False, name="lift patches to cmdbuf")
    self.linear = graph_rewrite(self.linear, pm_resolve_patches, ctx=self.hcq_ctx, bottom_up=False, name="simplify patches")
    self.linear = graph_rewrite(self.linear, pm_add_queue_sig_resets, ctx=self, name="hcq: add queue sig resets", walk=True)
    self.linear = graph_rewrite(self.linear, pm_parametrize_host_buffers, ctx=self.hcq_ctx, bottom_up=True, name="parametrize host buffers")
    self.host_call = graph_rewrite(self.linear, pm_callify, ctx=self.hcq_ctx, name="hcq: callify")

    host_prog = self.host_call.src[0]
    self.host_rt, self.host_globals = get_runtime("CPU", host_prog), host_prog.arg.globals

  def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
    # Write each input's address into the table, run the host call, and optionally time a device synchronize.
    table = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
    for slot, inp in enumerate(input_uops):
      target = inp.buffer
      # for a MultiBuffer, use the member that lives on this graph's device
      if isinstance(target, MultiBuffer): target = next(b for b in target.bufs if b.device == self.dev.device)
      table[slot] = target._buf.va_addr

    host_args = [self.hcq_ctx.inputs[i].get_buf("CPU") for i in self.host_globals]
    self.host_rt(*host_args, vals=self.host_call.src[0].arg.vals(var_vals), wait=True)

    if not wait: return None
    # the returned time covers only the device synchronize below
    started = time.perf_counter()
    self.dev.synchronize()
    return time.perf_counter() - started

  @staticmethod
  def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
    # Accept only PROGRAM / COPY calls that involve exactly one device, and that device is an HCQ2Compiled.
    all_devs = GraphRunner._all_devs(batch_devs, new_call)
    if new_call.src[0].op not in (Ops.PROGRAM, Ops.COPY): return False
    if len(all_devs) != 1: return False
    return isinstance(all_devs[0], HCQ2Compiled)
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Callable, TypeVar, Generic, Any
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
|
||||
from dataclasses import replace, dataclass, field
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
|
||||
from tinygrad.helpers import to_tuple, round_up
|
||||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -28,7 +27,7 @@ class HCQ2Compiled(Compiled):
|
|||
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
|
||||
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
|
||||
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
|
||||
])
|
||||
|
||||
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
|
||||
|
|
@ -137,9 +136,9 @@ def unwrap_after(uop):
|
|||
while uop.op is Ops.AFTER: uop = uop.src[0]
|
||||
return uop
|
||||
|
||||
def make_getaddr(u, dev=None):
|
||||
def make_getaddr(u, device=None):
|
||||
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=device or to_tuple(u.device)[0])))
|
||||
|
||||
def make_ins(op, *srcs):
  # Build an INS UOp with `op` as its arg; python ints become uint32 constants, anything else is cast to uint32.
  def as_u32(s): return UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32)
  operands = tuple(as_u32(s) for s in srcs)
  return UOp(Ops.INS, dtypes.void, operands, op)
|
||||
|
|
@ -149,7 +148,7 @@ def make_cmdbuf(lin, devs, tag):
|
|||
for s in (s for ins in lin.src for s in ins.src):
|
||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
|
||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||
|
||||
|
|
@ -172,6 +171,10 @@ class HCQInfo:
|
|||
name:str = ""
|
||||
estimates:Estimates = Estimates()
|
||||
outs:tuple[int, ...] = ()
|
||||
devs:tuple[str, ...] = ()
|
||||
|
||||
params:tuple[int, ...] = ()
|
||||
inputs:int|None = None
|
||||
|
||||
@staticmethod
|
||||
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
|
||||
|
|
@ -269,6 +272,14 @@ class DepsCtx:
|
|||
deps:DepsTracker = field(default_factory=DepsTracker)
|
||||
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
|
||||
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
|
||||
params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
|
||||
|
||||
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
  # Resolve the Buffer used for dependency tracking of `u` on the given lane.
  # TODO: should this be a part of DepsTracker?
  op = u.op
  if op is Ops.PARAM:
    # params have no concrete buffer here: a "NULL"-device Buffer is kept per (slot, lane) in ctx.params
    return ctx.params.setdefault((u.arg.slot, lane), Buffer("NULL", u.max_numel(), u.dtype.base))
  # MSTACK: descend into the lane's own source, which is then addressed as lane 0
  if op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
  # MSELECT: its arg selects the lane of the underlying source
  if op is Ops.MSELECT: return get_dep_buf(ctx, u.src[0], u.arg)
  # SLICE: same lane of the underlying source
  if op is Ops.SLICE: return get_dep_buf(ctx, u.src[0], lane)
  backing = u.buffer
  if isinstance(backing, MultiBuffer): return backing.bufs[lane]
  return backing
|
||||
|
||||
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
new_src = []
|
||||
|
|
@ -277,15 +288,25 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
new_src.append(call)
|
||||
continue
|
||||
|
||||
q = get_submit(call.src[0]).src[0]
|
||||
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
|
||||
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
|
||||
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
|
||||
|
||||
deps = []
|
||||
refs = [b.buffer for b in get_call_arg_uops(call)]
|
||||
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
|
||||
# per-lane deps, tracked per (device, queue). skip self
|
||||
dep_lanes:list[tuple[UOp, int]] = []
|
||||
for lane, d in enumerate(qdevs):
|
||||
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
|
||||
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
|
||||
|
||||
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
|
||||
# drop self-queue waits, queue self-orders
|
||||
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
|
||||
|
||||
# keep latest dep per lane, group lanes
|
||||
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
|
||||
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
|
||||
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
||||
|
||||
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
|
@ -298,11 +319,16 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
zero = UOp.const(dtypes.int, 0)
|
||||
tl = make_signal_value(devs)
|
||||
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
|
||||
# queue is inc with deps
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
|
||||
|
||||
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
|
||||
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
|
||||
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
|
||||
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
|
||||
return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
|
||||
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
parts:dict[str, list[UOp]] = collections.defaultdict(list)
|
||||
|
|
@ -319,21 +345,22 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
|
||||
|
||||
waits = []
|
||||
for dep in deps.src[1:]:
|
||||
devs, queue = dep.arg
|
||||
for lanes, dep in zip(deps.arg, deps.src[1:]):
|
||||
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
|
||||
ctx.add(dep.tag) # mark op to update signal.
|
||||
|
||||
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
|
||||
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
|
||||
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
|
||||
lanes = set(lanes)
|
||||
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
|
||||
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(sig.wait(val + dep.tag))
|
||||
return cur.replace(src=tuple(waits) + cur.src)
|
||||
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
|
||||
src = q.src
|
||||
if q.tag in ctx:
|
||||
devs, queue = q.arg
|
||||
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
|
||||
if q.tag not in ctx: return None
|
||||
devs, queue = q.arg
|
||||
src = q.src + (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
|
||||
return submit.replace(src=(q.replace(src=src, tag=None),))
|
||||
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
||||
|
||||
|
|
@ -342,34 +369,36 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
|
|||
|
||||
def get_submit(ast:UOp) -> UOp:
  # First CUSTOM_FUNCTION with arg "submit" in toposort order; raises StopIteration when there is none.
  def is_submit(u:UOp) -> bool: return u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit"
  return next(filter(is_submit, ast.toposort()))
|
||||
|
||||
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
  # Merge two sinks into one: the new submit's queue gets the old queue's sources placed in front of its own.
  old_submit = get_submit(old_sink)
  new_submit = get_submit(new_sink)
  combined_src = old_submit.src[0].src + new_submit.src[0].src
  merged_submit = new_submit.replace(src=(new_submit.src[0].replace(src=combined_src),))
  # the old sink's root with its submit swapped for the merged one
  old_root = old_sink.src[0].substitute({old_submit: merged_submit})
  if old_sink.src[0] is old_submit: new_anchor = merged_submit
  else: new_anchor = old_root
  return new_sink.substitute({new_submit: new_anchor})
|
||||
def merge_sink(sinks:list[UOp]) -> UOp:
  # Collapse several sinks into one. A single sink is returned as is.
  if len(sinks) == 1: return sinks[0]
  submits = [get_submit(s) for s in sinks]
  # the last submit's queue receives the sources of all queues, in sink order
  all_ops:list[UOp] = []
  for sub in submits: all_ops.extend(sub.src[0].src)
  last_submit = submits[-1]
  anchor = last_submit.replace(src=(last_submit.src[0].replace(src=tuple(all_ops)),))
  # chain the roots of the earlier sinks (when they are more than the bare submit) on top of the anchor
  for sink, sub in zip(sinks[:-1], submits[:-1]):
    root = sink.src[0]
    if root is not sub: anchor = root.substitute({sub: anchor}, walk=True)
  return sinks[-1].substitute({last_submit: anchor}, walk=True)
|
||||
|
||||
def merge_queues(linear:UOp) -> UOp:
|
||||
new_src:list[UOp] = []
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
|
||||
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
|
||||
aux = call.arg.aux
|
||||
new_rec = ([new_sink], call.arg.aux)
|
||||
if (old:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
|
||||
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
|
||||
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
|
||||
else:
|
||||
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = (new_sink, aux)
|
||||
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
|
||||
opened_qs[(devs, queue)] = new_rec
|
||||
|
||||
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
|
||||
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
|
||||
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
||||
|
||||
# *****************
|
||||
|
|
@ -384,6 +413,29 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
|||
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
|
||||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 4.3. annotate exec devs
|
||||
|
||||
def _annotate_call_devs(call):
  # copy the first element of the submit queue's arg into the call's aux `devs` field
  queue_devs = get_submit(call.src[0]).src[0].arg[0]
  return call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=queue_devs)))

pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), _annotate_call_devs)])
|
||||
|
||||
# *****************
|
||||
# 4.4. replace params with per-submit input address loads
|
||||
|
||||
def replace_params(call:UOp) -> UOp|None:
  # Replace every GLOBAL-addrspace PARAM in the call body with a load from a per-call "inputs" buffer.
  # Returns None when the body has no such PARAM.
  body = call.src[0]
  slot_of = {u:u.arg.slot for u in body.toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}
  if not slot_of: return None

  # fill new info: sorted unique slots, plus the number of call argument uops
  used_slots = tuple(sorted(set(slot_of.values())))
  info = replace(call.arg.aux, params=used_slots, inputs=len(get_call_arg_uops(call)))

  # one uint64 entry per used slot, created on the first element of the submit queue's arg
  table = UOp.new_buffer(get_submit(body).src[0].arg[0], len(info.params), dtypes.uint64).rtag("inputs")

  index_of = {slot:i for i, slot in enumerate(info.params)}
  loads = {}
  for u, slot in slot_of.items():
    loads[u] = table.index(UOp.const(dtypes.int, index_of[slot])).load()
  new_body = body.substitute(loads)

  # the inputs buffer is appended as the last source of the call
  return call.replace(src=(new_body, *call.src[1:], table), arg=replace(call.arg, aux=info))

pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
|
||||
|
||||
# *****************
|
||||
# 5.1. encode cmdbufs
|
||||
|
||||
|
|
@ -411,7 +463,28 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
|
|||
])
|
||||
|
||||
# *****************
|
||||
# 5.3. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
# 5.3. pack placeholders buffers
|
||||
|
||||
def pack_hcq_placeholders(call:UOp) -> UOp|None:
  # Pack tagged placeholder BUFFERs of one call into a single base buffer per tag, replacing each by a SLICE of it.
  # Only tags that occur more than once are packed; returns None when nothing is substituted.
  max_tags, sum_tags = {"scratch"}, {"program", "kernargs"}
  packable = max_tags | sum_tags
  bufs = [u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u.tag in packable]

  offsets:dict[UOp, int] = {}
  sizes:dict[str, int] = {}
  for buf in bufs:
    tag = buf.tag
    if tag in max_tags:
      # these buffers get no offset entry (SLICE offset 0 below); the tag's size is the largest request
      sizes[tag] = max(sizes.get(tag, 0), buf.arg)
    else:
      # these buffers are laid out back to back, each start rounded up to the tag's alignment
      align = 0x1000 if tag == "program" else 128
      offsets[buf] = round_up(sizes.get(tag, 0), align)
      sizes[tag] = offsets[buf] + buf.arg

  per_tag = collections.Counter(buf.tag for buf in bufs)
  # the last buffer seen for a tag provides the device arg and dtype of that tag's base buffer
  ref:dict[str, UOp] = {}
  for buf in bufs:
    if per_tag[buf.tag] > 1: ref[buf.tag] = buf
  bases:dict[str, UOp] = {}
  for tag, buf in ref.items():
    bases[tag] = UOp.new_buffer(buf.src[1].arg, sizes[tag], buf.dtype).rtag(tag)

  subs:dict[UOp, UOp] = {}
  for buf in bufs:
    if buf.tag not in bases: continue
    subs[buf] = UOp(Ops.SLICE, buf.dtype, (bases[buf.tag], UOp.const(dtypes.weakint, offsets.get(buf, 0))), buf.arg)

  if not subs: return None
  return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:]))

pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
|
||||
|
||||
# *****************
|
||||
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
|
||||
def hold_call_buffers(call:UOp) -> UOp|None:
|
||||
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
|
||||
|
|
@ -423,7 +496,7 @@ pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
|
|||
|
||||
def bufferize_buf(buf:UOp) -> UOp|None:
|
||||
if buf.tag is None: return None
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), "CPU") for dev in to_tuple(buf.src[1].arg))
|
||||
return make_mstack(uops)
|
||||
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
||||
|
||||
|
|
@ -434,31 +507,42 @@ def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
|
|||
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
|
||||
|
||||
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
|
||||
for b in (buf.src if buf.op is Ops.MSTACK else (buf,)): b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
|
||||
for b in (mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)): b.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
|
||||
return UOp(Ops.NOOP)
|
||||
|
||||
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
||||
for b, v in zip((buf.src if buf.op is Ops.MSTACK else (buf,)), (val.src if val.op is Ops.STACK else (val,))):
|
||||
struct.pack_into(f'<{v.dtype.fmt}', b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * b.dtype.base.itemsize, v.arg)
|
||||
for b, v in zip((bs:=mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)), val.src if val.op is Ops.STACK else (val,)*len(bs)):
|
||||
struct.pack_into(f'<{v.dtype.fmt}', b.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * buf.dtype.base.itemsize, v.arg)
|
||||
return UOp(Ops.NOOP)
|
||||
|
||||
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
|
||||
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
|
||||
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
|
||||
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
|
||||
devs, b = to_tuple(g.src[1].arg), buf.buffer
|
||||
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
|
||||
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
|
||||
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
|
||||
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
|
||||
|
||||
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
  # getaddr(slice(base, off)) -> getaddr(base) + off * itemsize
  base = bv.src[0]
  # when the base (as reported by unwrap_after) is a BUFFER/SLICE/MSTACK/MSELECT the offset is scaled by the base dtype,
  # otherwise by the slice's own dtype
  scale_by_base = unwrap_after(base).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT)
  itemsize = base.dtype.itemsize if scale_by_base else bv.dtype.itemsize
  base_addr = UOp(Ops.GETADDR, dtypes.uint64, src=(base, dev))
  return base_addr + UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
|
||||
|
||||
pm_resolve_patches = PatternMatcher([
|
||||
# multi
|
||||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||
|
||||
# index on slice is index
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
|
||||
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
|
||||
|
||||
# getaddr
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
||||
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
|
||||
# folders
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
fold_const_store),
|
||||
]) + symbolic_simple
|
||||
|
||||
|
|
@ -466,19 +550,25 @@ pm_resolve_patches = PatternMatcher([
|
|||
# 8. callify hcq programs
|
||||
|
||||
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
|
||||
bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
|
||||
if ref not in bufs: bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, bufs.index(ref))
|
||||
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
|
||||
|
||||
def parametrize_host_buffers(call:UOp) -> UOp:
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
|
||||
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
|
||||
# preserve original order of args
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=list(get_call_arg_uops(call))), bottom_up=True, name="parametrize host buffers")
|
||||
|
||||
# move vars to new slots
|
||||
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
|
||||
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
|
||||
|
||||
return call.replace(src=(body, *bufs) + tuple(x for x in call.src[1:] if x.op is Ops.BIND))
|
||||
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
||||
|
||||
def callify_hcq(call:UOp) -> UOp:
  # compile the call body into a CPU program named "hcq_submit", then wrap it in an "hcq" custom function
  # that is called with the original call's remaining sources and aux
  body, *args = call.src
  submit_sink = body.sink(arg=KernelInfo("hcq_submit"), tag=1)
  prg = to_program(submit_sink, Device["CPU"].renderer)
  hcq_fxn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq")
  return hcq_fxn.call(*args, aux=call.arg.aux)
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
|
||||
def hcq_schedule(linear:UOp) -> UOp:
|
||||
|
|
@ -492,14 +582,17 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_annotate_devs, name="annotate devs")
|
||||
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_pack_placeholders, walk=True, name="pack placeholders")
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
|
||||
# realize starts from here
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
||||
|
||||
return linear
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast
|
||||
from typing import cast, Any, Callable
|
||||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -89,25 +89,25 @@ def memory_barrier(ctx):
|
|||
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
|
||||
acquire_mem(ctx)))
|
||||
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
|
||||
|
||||
# Thin wrapper: returns memory_barrier(ctx) unchanged, giving the barrier the same (ctx, ...) call shape as the other pm4_* helpers.
def pm4_barrier(ctx): return memory_barrier(ctx)
|
||||
|
||||
def pm4_store(ctx, dst, val):
|
||||
if val.op is Ops.BINARY: return None
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
def pm4_timestamp(ctx, dst):
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
ctx.pm4.int_sel__mec_release_mem__none)
|
||||
|
||||
def pm4_program(ctx, prg):
|
||||
data, info = prg.arg
|
||||
lib_gpu, args = prg.src
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
|
||||
args_addr = make_getaddr(args, ctx.device)
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
|
||||
args_addr = make_getaddr(args, ctx.devs)
|
||||
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
|
|
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
|||
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
|
||||
|
||||
def sdma_copy(ctx, dst, src, copy):
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
|
||||
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
|
||||
|
|
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
|
|||
def sdma_wait(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
|
||||
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
|
||||
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
|
||||
def sdma_store(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
|
||||
def sdma_timestamp(ctx, dst):
|
||||
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
|
||||
|
||||
pm_sdma_opsel = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
|
||||
|
|
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
|
|||
|
||||
# Create an empty subclass of `iface`; the class is named `name`, or "MOCK" + iface.__name__ when no name is given.
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
|
||||
|
||||
@dataclass(frozen=True)
class AMDEncodeCtx:
  # Immutable bundle of encode-time constants for one queue. Field order is part of the interface (positional construction).
  devs: tuple[str, ...]    # devices of the queue; every cmdbuf address resolves into these
  target: tuple[int, ...]  # gfx version
  # packet / ip modules
  pm4: Any
  sdma: Any
  soc: Any
  gc: AMDIP
  nbio: AMDIP
  xccs: int
  max_copy_size: int
  tmpring_size: Callable
|
||||
|
||||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
|
||||
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
|
||||
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
|
|||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
return (None, None, grad_xw13_uop, None, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, new_amax)
|
||||
# grad_amax_state: delayed amax for grad_xw13 fp8 quantization in the backward.
|
||||
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
|
||||
MBS, SEQ, H2 = xw13.shape
|
||||
|
|
@ -72,5 +72,4 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_
|
|||
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname_of(xw13.device))
|
||||
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, grad_amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf)
|
||||
return fp8_out, scalar_amax(amax_buf)
|
||||
|
|
|
|||
|
|
@ -112,8 +112,8 @@ def _fused_add_bwd(*args, **kwargs):
|
|||
grad_h, grad_w = _bwd_common(fp8_grad_u, h_grad_u, x_u, x_normed_u, rrms_u, weight_u, amax_state_u, kernel)
|
||||
return (None, None, None, None, None, grad_h, grad_h, grad_w, None)
|
||||
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, inv_scale, new_amax, x_normed, rrms).
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, new_amax, x_normed, rrms).
|
||||
# x_normed + rrms are saved for the rmsnorm backward (also recomputed here from x regs).
|
||||
assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
|
||||
|
|
@ -127,13 +127,12 @@ def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, e
|
|||
fxn = functools.partial(_custom_fwd, dname=dname_of(x.device), eps_val=eps)
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
|
||||
def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor, amax_state:Tensor,
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: h = x + residual; y_normed = rmsnorm(h); fp8 = quantize(y_normed * weight).
|
||||
# Returns (fp8, inv_scale, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# Returns (fp8, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# reuse it without recomputing x+residual — eliminates the separate residual-add kernel.
|
||||
assert x.dtype == dtypes.bfloat16 and residual.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape == residual.shape
|
||||
|
|
@ -149,5 +148,4 @@ def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor,
|
|||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, x, residual, weight, amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_add_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
|
|
|
|||
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
||||
|
||||
BLK = 32
|
||||
PACK = 4
|
||||
|
||||
# Builds the kernel graph that quantizes x (rows, K) block-wise: one fp8 value per element, one uint8 exponent (e8)
# per BLK-element block, and the same exponents packed PACK-per-word into si_out in a transposed (sk4, rows) layout.
# Cached on its UOp arguments via functools.cache.
@functools.cache
|
||||
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
|
||||
rows, K = x.shape
|
||||
# number of BLK-wide scale blocks per row
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
# a "super-block" is PACK consecutive blocks (BLK * PACK elements); one (wg, tid) pair handles one super-block
n_super = n_elems // (BLK * PACK)
|
||||
# packed scale words per row
sk4 = scale_K // PACK
|
||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
||||
nwg = n_super // THREADS_PER_WG
|
||||
|
||||
# flatten every buffer so it can be addressed with a single linear index
x = x.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
e8_out = e8_out.reshape(rows * scale_K)
|
||||
si_out = si_out.reshape(sk4 * rows)
|
||||
|
||||
# ranges: workgroup (GLOBAL), thread in workgroup (LOCAL), block within super-block and element within block (both UNROLL)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
||||
|
||||
super_idx = wg * THREADS_PER_WG + tid
|
||||
# flat element index: super-block base + block offset + element offset
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
||||
|
||||
x_f = x[idx].cast(dtypes.float)
|
||||
# |x| written as a select
abs_x = (x_f < 0.0).where(-x_f, x_f)
|
||||
# per-block max of |x| (reduction over the lane range)
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
|
||||
# biased exponent: floor(log2(max(blk_max, 1e-38))) + 127, clamped to [0, 254]
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
||||
# multiplier applied before the fp8 cast: 2^(127 - e8f)
qscale = (127.0 - e8f).exp2()
|
||||
# scale, then clamp to [-FP8_MAX, FP8_MAX]
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
e8u8 = e8f.cast(dtypes.uint8)
|
||||
|
||||
# stores are chained with .after(): fp8 values first, then the per-block exponent, then the packed word
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
||||
|
||||
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
|
||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
||||
# NOTE(review): this (row, col4) split assumes n_super == rows * sk4, i.e. K is a multiple of BLK * PACK; that is not asserted here
row, col4 = super_idx // sk4, super_idx % sk4
|
||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
  # Backward of the fused mxfp8 quantize: only x (last kernel argument) receives a gradient, the three outputs get None.
  _fp8_uop, e8_uop, _si_uop, x_uop = kernel.src[1:]
  dev = x_uop.device
  n_rows, n_cols = x_uop.shape
  n_blocks = n_cols // BLK
  # per-block scale 2^(127 - e8), repeated BLK times along K so it lines up with the (rows, K) gradient
  exps = Tensor(e8_uop, device=dev).reshape(n_rows, n_blocks)
  block_scale = (127.0 - exps.cast(dtypes.float32)).exp2()
  elem_scale = block_scale.reshape(n_rows, n_blocks, 1).expand(n_rows, n_blocks, BLK).reshape(n_rows, n_cols)
  upstream = Tensor(gradient, device=dev).float()
  grad_x = (upstream * elem_scale).cast(dtypes.bfloat16)
  return (None, None, None, grad_x.uop)
|
||||
|
||||
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
  # Quantize a 2D bf16 tensor with the fused mxfp8 kernel. Returns (fp8_out, e8_out, si_out):
  #   fp8_out: (rows, K) FP8_DTYPE values
  #   e8_out:  (rows, K // BLK) uint8, one exponent per BLK-element block
  #   si_out:  (K // (BLK * PACK), rows) uint32, PACK exponents packed per word, transposed layout
  # For a multi-device x, outputs are allocated with x's shard axis (si_out uses the swapped axis since it is transposed).
  assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
  assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
  from extra.gemm.cdna_asm_gemm import FP8_DTYPE
  rows, K = x.shape
  # the kernel maps each BLK*PACK super-block to (row, col4) using // and % by K // (BLK * PACK); a ragged K would
  # silently misplace the packed scales (and K < BLK * PACK makes that divisor zero), so reject it up front
  assert K > 0 and K % (BLK * PACK) == 0, f"K={K} must be a positive multiple of BLK*PACK={BLK * PACK}"
  scale_K = K // BLK
  axis = x.uop.axis if isinstance(x.device, tuple) else None
  fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
  e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
  si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
  fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
  return fp8_out, e8_out, si_out
|
||||
|
|
@ -255,7 +255,7 @@ def _sign(serial_num, out_dir):
|
|||
final = _build_elf(ehdr, allph, segs)
|
||||
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out = os.path.join(out_dir, "testsig-0x%08X.so" % serial_num)
|
||||
out = os.path.join(out_dir, "testsig-0x%08x.so" % serial_num)
|
||||
with open(out, 'wb') as f:
|
||||
f.write(final)
|
||||
print("Signing complete! Output saved at %s" % out)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ constexpr int NUM_THREADS = WARP_THREADS * NUM_WARPS;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr) {
|
||||
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr, bf16 *b_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
static_assert(M % BLOCK_SIZE == 0 && N % BLOCK_SIZE == 0 && K % K_STEP == 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -99,12 +99,14 @@ using G = kittens::group<NUM_WARPS>;
|
|||
#endif
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr
|
||||
#if SCALE_MODE == 1
|
||||
#if SCALE_MODE & 1
|
||||
, float *x_scale_ptr
|
||||
#elif SCALE_MODE == 2
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
, float *w_scale_ptr
|
||||
#elif SCALE_MODE == 3
|
||||
, float *x_scale_ptr, float *w_scale_ptr
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
, float *g_scale_ptr
|
||||
#endif
|
||||
) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
|
@ -346,20 +348,19 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
|
|||
}
|
||||
|
||||
// apply x_scale * w_scale before bf16 store to prevent overflow
|
||||
#if SCALE_MODE == 1
|
||||
float scale = *x_scale_ptr;
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
mul(cD, cD, scale);
|
||||
#elif SCALE_MODE == 2
|
||||
float scale = *w_scale_ptr;
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
mul(cD, cD, scale);
|
||||
#elif SCALE_MODE == 3
|
||||
float scale = *x_scale_ptr * *w_scale_ptr;
|
||||
#if SCALE_MODE != 0
|
||||
float scale = 1.0f;
|
||||
#if SCALE_MODE & 1
|
||||
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
|
||||
scale *= x_scale;
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
scale *= *w_scale_ptr;
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
scale *= *g_scale_ptr;
|
||||
#endif
|
||||
|
||||
mul(cA, cA, scale);
|
||||
mul(cB, cB, scale);
|
||||
mul(cC, cC, scale);
|
||||
|
|
|
|||
|
|
@ -26,9 +26,7 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter,
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
|
||||
const uint8_t *__restrict__ a_e8_unused,
|
||||
const uint8_t *__restrict__ b_e8_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
|
@ -40,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
constexpr int k_iters = K / BLOCK_K;
|
||||
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
|
||||
__shared__ ST_A As[2][2];
|
||||
__shared__ ST_B Bs[2][2];
|
||||
__shared__ uint8_t smem_scales[2048];
|
||||
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
|
||||
|
||||
RT_A a;
|
||||
RT_B b0, b1;
|
||||
|
|
@ -81,54 +83,84 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
||||
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
||||
|
||||
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
|
||||
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
|
||||
const int a_row_stride = A.template stride<2>() * sizeof(T);
|
||||
const int b_row_stride = B.template stride<2>() * sizeof(T);
|
||||
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
|
||||
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
|
||||
|
||||
const int wid = warpid() % NUM_WARPS;
|
||||
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
|
||||
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
|
||||
int a_row_h0 = warp_m * REG_M;
|
||||
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
||||
int b_row_h0 = warp_n * REG_N;
|
||||
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
||||
|
||||
int tic = 0, toc = 1;
|
||||
int tic_scales = 0, toc_scales = 1;
|
||||
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
|
||||
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
|
||||
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
|
||||
|
||||
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
|
||||
|
||||
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
asm volatile("s_waitcnt lgkmcnt(8)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -137,10 +169,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -150,7 +181,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -162,19 +193,23 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
|
||||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
if (k + 1 < k_iters) {
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
@ -217,17 +252,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
tic ^= 1; toc ^= 1;
|
||||
tic_scales ^= 1; toc_scales ^= 1;
|
||||
}
|
||||
|
||||
{ // Epilogue k = k_iters - 1 (final)
|
||||
int k = k_iters - 1;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
|
|
@ -263,6 +298,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
} else {
|
||||
#pragma unroll 1
|
||||
for (int kk = 0; kk < k_iters; kk++) {
|
||||
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
|
||||
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
|
||||
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
|
||||
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
|
||||
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
}
|
||||
|
||||
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
||||
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
|
|
|
|||
|
|
@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Emits one `buffer_load_dwordx3` (offen addressing) whose destination is a
// hard-coded three-register range selected by the GPR_START template argument:
//   GPR_START >= 256 -> a[GPR_START-256 : GPR_START-254]
//   otherwise        -> v[GPR_START     : GPR_START+2]
// The asm statement has an empty output list, so the compiler is not told which
// registers get written.
// NOTE(review): callers presumably reserve/pin this register range themselves and
// wait on the load before reading it — confirm against the call sites.
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// Register numbers 256 and above are mapped onto the a[...] register file.
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// Register numbers below 256 address the v[...] register file directly.
} else {
|
||||
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
|
||||
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
|
||||
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
|
||||
// Compiler-allocated variant: issues `buffer_load_dwordx3` (offen addressing) into
// a register range chosen by the compiler for the "=v" output and returns it as T.
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): unlike the byte/short helpers there is no static_assert on sizeof(T)
// here; T is presumably always a 3-dword vector type — confirm.
template<typename T = u32x3>
|
||||
__device__ __forceinline__ T buffer_load_dwordx3(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
T result;
|
||||
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
|
|
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_ubyte` (offen addressing) and returns the destination
// register as T. T must be exactly 4 bytes wide (see the static_assert).
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): going by the mnemonic this is the unsigned-byte form (presumably
// zero-extended into the 32-bit result) — confirm against the ISA manual.
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ubyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_sbyte` (offen addressing) and returns the destination
// register as T. T must be exactly 4 bytes wide (see the static_assert).
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): going by the mnemonic this is the signed-byte form (presumably
// sign-extended into the 32-bit result) — confirm against the ISA manual.
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sbyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_ushort` (offen addressing) and returns the destination
// register as T. T must be exactly 4 bytes wide (see the static_assert).
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): going by the mnemonic this is the unsigned 16-bit form (presumably
// zero-extended into the 32-bit result) — confirm against the ISA manual.
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ushort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_sshort` (offen addressing) and returns the destination
// register as T. T must be exactly 4 bytes wide (see the static_assert).
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): going by the mnemonic this is the signed 16-bit form (presumably
// sign-extended into the 32-bit result) — confirm against the ISA manual.
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sshort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
|
||||
// of the destination register, leaving the other half untouched: read-modify-write.
|
||||
// Issues one `buffer_load_ubyte_d16` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_ubyte_d16_hi` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
// NOTE(review): the `_hi` suffix presumably targets the upper 16 bits of `dst` and
// leaves the lower 16 bits intact — confirm against the ISA manual.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_sbyte_d16` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_sbyte_d16_hi` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
// NOTE(review): the `_hi` suffix presumably targets the upper 16 bits of `dst` and
// leaves the lower 16 bits intact — confirm against the ISA manual.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_short_d16` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Issues one `buffer_load_short_d16_hi` (offen addressing) into `dst`.
// `dst` is bound as a read-write operand ("+v"), so its previous contents are an
// input to the instruction as well as its result. T must be exactly 4 bytes wide.
// NOTE(review): the `_hi` suffix presumably targets the upper 16 bits of `dst` and
// leaves the lower 16 bits intact — confirm against the ISA manual.
//   dst      - register value updated in place
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
// Emits one `buffer_store_dwordx3` (offen addressing) whose source is a hard-coded
// three-register range selected by the GPR_START template argument:
//   GPR_START >= 256 -> a[GPR_START-256 : GPR_START-254]
//   otherwise        -> v[GPR_START     : GPR_START+2]
// The source registers are named only inside the asm text (via "n" immediates), so
// the compiler does not see them as inputs.
// NOTE(review): callers presumably guarantee those registers hold the intended data
// when this executes — confirm against the call sites.
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
// Compiler-allocated variant: issues `buffer_store_dwordx3` (offen addressing) with
// `value` supplied as a "v" input operand.
//   value    - data to store; defaults to the u32x3 vector type
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
// NOTE(review): there is no static_assert on sizeof(T) here; T is presumably always
// a 3-dword vector type — confirm.
template<typename T = u32x3>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 8 bits of the value register.
|
||||
// Issues one `buffer_store_byte` (offen addressing) with `value` supplied as a
// "v" input operand. T must be exactly 4 bytes wide (see the static_assert).
//   value    - 32-bit register value that carries the byte to store
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores bits [23:16] of the value register, i.e. the low byte of its upper
// 16-bit half.
// Issues one `buffer_store_byte_d16_hi` (offen addressing) with `value` supplied as
// a "v" input operand. T must be exactly 4 bytes wide (see the static_assert).
//   value    - 32-bit register value that carries the byte to store
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 16 bits of the value register.
|
||||
// Issues one `buffer_store_short` (offen addressing) with `value` supplied as a
// "v" input operand. T must be exactly 4 bytes wide (see the static_assert).
//   value    - 32-bit register value that carries the 16 bits to store
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the high 16 bits of the value register.
|
||||
// Issues one `buffer_store_short_d16_hi` (offen addressing) with `value` supplied
// as a "v" input operand. T must be exactly 4 bytes wide (see the static_assert).
//   value    - 32-bit register value that carries the 16 bits to store
//   br       - buffer resource descriptor, reinterpreted as i32x4 for the "s" operand
//   v_offset - vector-register offset operand ("v")
//   s_offset - scalar offset operand ("s")
//   i_offset - value of the `offset:` field; bound with "i", so it must fold to a
//              compile-time constant at the inlined call site
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR >= 256) {
|
||||
|
|
@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
|
|||
|
||||
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
|
||||
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
|
||||
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
|
||||
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
|
||||
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
|
|
@ -248,7 +248,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%3\n"
|
||||
"ds_read_b64_tr_b16 %1, %2 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "i"(offset), "i"(offset + 4 * ST::underlying_subtile_row_bytes)
|
||||
: "memory"
|
||||
|
|
@ -259,7 +259,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%4\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "v"(next_addr), "i"(offset)
|
||||
: "memory"
|
||||
|
|
@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
|
|||
}
|
||||
}
|
||||
|
||||
template<ducks::rt::col_layout RT, ducks::st::all ST>
|
||||
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
|
||||
static_assert(RT::cols == ST::rows,
|
||||
"col_layout load with col_offset: ST.rows must equal RT::cols");
|
||||
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
|
||||
|
||||
using T2 = RT::dtype;
|
||||
using U = ST::dtype;
|
||||
using U2 = base_types::packing<U>::packed_type;
|
||||
|
||||
const int laneid = kittens::laneid();
|
||||
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
|
||||
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
|
||||
|
||||
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
|
||||
const int block_id = laneid / 16;
|
||||
const int l_within = laneid % 16;
|
||||
const int tr_k_grp = l_within / 2;
|
||||
const int m_half = l_within & 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
|
||||
|
||||
const int subtile_base = block_id;
|
||||
uint32_t addr = src_ptr
|
||||
+ (uint32_t)(subtile_base * subtile_bytes)
|
||||
+ src.swizzle({tr_k_grp, m_col});
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
|
||||
addr ^= 1088u;
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
|
||||
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
|
||||
const int col_offset_lane = (laneid % 4) * 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
|
||||
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
|
||||
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
|
||||
|
||||
if constexpr (RT::base_tile_stride == 8) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
|
||||
: "v"(addr), "v"(next_addr)
|
||||
: "memory"
|
||||
);
|
||||
} else if constexpr (RT::base_tile_stride == 4) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
|
||||
: "v"(addr)
|
||||
: "memory"
|
||||
);
|
||||
} else {
|
||||
static_assert(false, "Unsupported stride for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(false, "Unsupported type for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kittens
|
||||
|
|
@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
|
||||
|
||||
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
|
||||
|
|
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale tile data.
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
__device__ __forceinline__ fp8e8m0_4 pack_scales(
|
||||
const fp8e8m0 *smem_scales, int row_offset) {
|
||||
int lid = laneid();
|
||||
int r16 = lid % 16;
|
||||
int k_sub = lid / 16;
|
||||
|
||||
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
|
||||
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
|
||||
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
|
||||
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
|
||||
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
|
||||
|
||||
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
|
||||
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
|
||||
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
|
||||
|
||||
return lo | (hi << 16);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base dot product operation for row layout.
|
||||
*
|
||||
|
|
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
|
||||
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
|
||||
*/
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
|
|
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
|
|||
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt(D &d,
|
||||
const A &a,
|
||||
const B &b, // notice row and (M, K) instead of col and (K, M)
|
||||
|
|
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
|
||||
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
|
||||
*/
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt_scaled(D &d,
|
||||
const A &a,
|
||||
const B &b,
|
||||
|
|
@ -9,6 +9,5 @@
|
|||
#include "maps.cuh"
|
||||
#include "reductions.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "scales.cuh"
|
||||
|
||||
#include "assembly/tile.cuh"
|
||||
|
|
@ -41,7 +41,7 @@ struct st_16x16 {
|
|||
} else {
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -205,6 +205,23 @@ struct st_8x32 {
|
|||
}
|
||||
};
|
||||
|
||||
struct st_16x64 {
|
||||
static constexpr int rows = 16;
|
||||
static constexpr int cols = 64;
|
||||
|
||||
template<typename _T>
|
||||
static constexpr int bytes_per_thread() {
|
||||
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
|
||||
return 16;
|
||||
}
|
||||
|
||||
template<typename _T>
|
||||
__device__ __forceinline__ static const uint32_t swizzle (int2 coord) {
|
||||
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
|
||||
return sizeof(_T) * (coord.x * cols + coord.y);
|
||||
}
|
||||
};
|
||||
|
||||
struct st_16x128 {
|
||||
static constexpr int rows = 16;
|
||||
static constexpr int cols = 128;
|
||||
|
|
@ -236,12 +253,13 @@ struct st_16x128 {
|
|||
};
|
||||
|
||||
template<typename T>
|
||||
concept all = std::is_same_v<T, st_16x16> ||
|
||||
std::is_same_v<T, st_16x16_swizzled> ||
|
||||
std::is_same_v<T, st_32x32> ||
|
||||
std::is_same_v<T, st_16x32> ||
|
||||
std::is_same_v<T, st_32x16> ||
|
||||
concept all = std::is_same_v<T, st_16x16> ||
|
||||
std::is_same_v<T, st_16x16_swizzled> ||
|
||||
std::is_same_v<T, st_32x32> ||
|
||||
std::is_same_v<T, st_16x32> ||
|
||||
std::is_same_v<T, st_32x16> ||
|
||||
std::is_same_v<T, st_8x32> ||
|
||||
std::is_same_v<T, st_16x64> ||
|
||||
std::is_same_v<T, st_16x128>;
|
||||
|
||||
|
||||
|
|
@ -75,6 +75,7 @@ using st_32x32_s = ducks::st_shape::st_32x32;
|
|||
using st_16x32_s = ducks::st_shape::st_16x32;
|
||||
using st_32x16_s = ducks::st_shape::st_32x16;
|
||||
using st_8x32_s = ducks::st_shape::st_8x32;
|
||||
using st_16x64_s = ducks::st_shape::st_16x64;
|
||||
using st_16x128_s = ducks::st_shape::st_16x128;
|
||||
|
||||
}
|
||||
|
|
@ -5,9 +5,12 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
#if defined(KITTENS_CDNA4)
|
||||
#include "cdna4/includes.cuh"
|
||||
#elif defined(KITTENS_UDNA1)
|
||||
#include "udna1/includes.cuh"
|
||||
#endif
|
||||
|
||||
#include "pyutils/util.cuh"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief MXFP8 block scale loading and packing utilities.
|
||||
*
|
||||
* Provides functions for staging E8M0 block scales in LDS and packing them
|
||||
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.cuh"
|
||||
|
||||
namespace kittens {
|
||||
|
||||
/**
|
||||
* @brief Load iteration-major packed E8M0 scales from global memory into LDS.
|
||||
*
|
||||
* First 256 threads each load one uint32 (4 packed E8M0 bytes) for A and B.
|
||||
* A scales are placed at smem[0..1023], B scales at smem[1024..2047].
|
||||
*
|
||||
* @param smem_scales LDS buffer, must be >= 2048 bytes.
|
||||
* @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
|
||||
* @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
|
||||
* @param block_m Starting row offset for A within the current block.
|
||||
* @param block_n Starting row offset for B within the current block.
|
||||
* @param k_iter Current K iteration index.
|
||||
* @param M_dim M dimension of the matrix.
|
||||
* @param N_dim N dimension of the matrix.
|
||||
*/
|
||||
// Copies one K-iteration of packed scales for A and B from global memory into
// the LDS staging buffer. Threads with threadIdx.x < 256 each move one uint32
// per operand; all other threads do nothing.
//
// The global element index is formed in 64-bit arithmetic so that
// k_iter * M_dim (or k_iter * N_dim) cannot wrap a 32-bit int on large problems.
//
// No barrier is issued here.
// NOTE(review): callers presumably synchronize before reading smem_scales — confirm.
// NOTE(review): no bounds check is performed; assumes block_m + 255 and
// block_n + 255 are valid rows of the current k_iter slice — confirm against callers.
__device__ __forceinline__ void load_scales_to_lds(
    uint8_t *smem_scales,
    const uint32_t *__restrict__ scale_A_iter,
    const uint32_t *__restrict__ scale_B_iter,
    int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
    constexpr int kLoaderThreads = 256;   // number of threads that take part in the copy
    constexpr int kBRegionOffset = 1024;  // byte offset of the B scales inside smem_scales

    const int tid = threadIdx.x;
    if (tid >= kLoaderThreads) return;

    // 64-bit indices: the product is widened before the multiply happens.
    const long long a_idx = static_cast<long long>(k_iter) * M_dim + block_m + tid;
    const long long b_idx = static_cast<long long>(k_iter) * N_dim + block_n + tid;

    const uint32_t sa = scale_A_iter[a_idx];
    const uint32_t sb = scale_B_iter[b_idx];

    // Each thread owns one 4-byte slot in the A region and one in the B region.
    *(uint32_t *)&smem_scales[tid * 4] = sa;
    *(uint32_t *)&smem_scales[kBRegionOffset + tid * 4] = sb;
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* Each lane (r16 = laneid%16, k_sub = laneid/16) loads 4 dwords from
|
||||
* consecutive 16-row groups, then uses v_perm_b32 to extract byte k_sub
|
||||
* from each, producing the packed scale register for scaled MFMA.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale region.
|
||||
* @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
// Gathers four scale bytes from LDS into a single fp8e8m0_4 value.
//
// The lane id is split into (lane % 16, lane / 16). Four dwords are read at a
// stride of 16 dwords starting at row_offset + lane % 16, and byte number
// (lane / 16) of each dword is extracted with the AMD perm builtin.
// NOTE(review): the 0x0C selector bytes presumably force the upper two result
// bytes to zero — confirm against the v_perm_b32 ISA description.
__device__ __forceinline__ fp8e8m0_4 pack_scales(
    const uint8_t *smem_scales, int lds_base, int row_offset) {
    const int lane         = laneid();
    const int row_in_group = lane % 16;
    const int k_sub        = lane / 16;

    // View the selected LDS region as dwords and point at this lane's first row.
    const uint32_t *words = (const uint32_t *)(smem_scales + lds_base);
    const uint32_t *first = words + row_offset + row_in_group;

    const uint32_t w0 = first[0 * 16];
    const uint32_t w1 = first[1 * 16];
    const uint32_t w2 = first[2 * 16];
    const uint32_t w3 = first[3 * 16];

    // Same selector for both pairs: low byte picks byte k_sub of the first
    // operand, next byte picks byte k_sub of the second operand.
    const uint32_t sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
    const uint32_t lo  = __builtin_amdgcn_perm(w0, w1, sel);
    const uint32_t hi  = __builtin_amdgcn_perm(w2, w3, sel);

    // lo supplies bytes 0-1, hi supplies bytes 2-3 of the packed result.
    return (fp8e8m0_4)(lo | (hi << 16));
}
|
||||
} // namespace kittens
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "../ops/ops.cuh"
|
||||
#include <iostream>
|
||||
|
||||
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
|
||||
|
|
|
|||
449
extra/thunder/amd/include/udna1/common/base_ops.cuh
Normal file
449
extra/thunder/amd/include/udna1/common/base_ops.cuh
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Basic operations on generic types.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include "base_types.cuh"
|
||||
|
||||
namespace kittens {
|
||||
|
||||
/**
|
||||
* @namespace base_ops
|
||||
*
|
||||
* @brief A namespace for operations on basic data types.
|
||||
*/
|
||||
namespace base_ops {
|
||||
|
||||
/* ---------- CONST OPS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Represents the zero constant operation.
|
||||
*
|
||||
* This operation returns the zero value of the specified type.
|
||||
*
|
||||
* @tparam T The data type for which to return the zero value.
|
||||
* @return The zero value of type T.
|
||||
*/
|
||||
struct zero {
|
||||
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::zero(); }
|
||||
};
|
||||
/**
|
||||
* @brief Represents the ones constant operation.
|
||||
*
|
||||
* This operation returns the ones value of the specified type.
|
||||
*
|
||||
* @tparam T The data type for which to return the ones value.
|
||||
* @return The ones value of type T.
|
||||
*/
|
||||
struct ones {
|
||||
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::ones(); }
|
||||
};
|
||||
/**
|
||||
* @brief Represents the positive infinity constant operation.
|
||||
*
|
||||
* This operation returns the positive infinity value of the specified type.
|
||||
*
|
||||
* @tparam T The data type for which to return the positive infinity value.
|
||||
* @return The positive infinity value of type T.
|
||||
*/
|
||||
struct pos_infty {
|
||||
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
|
||||
};
|
||||
/**
|
||||
* @brief Represents the negative infinity constant operation.
|
||||
*
|
||||
* This operation returns the negative infinity value of the specified type.
|
||||
*
|
||||
* @tparam T The data type for which to return the negative infinity value.
|
||||
* @return The negative infinity value of type T.
|
||||
*/
|
||||
struct neg_infty {
|
||||
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
|
||||
};
|
||||
|
||||
|
||||
/* ---------- UNARY OPS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Exponential function operation.
|
||||
*
|
||||
* This operation calculates the exponential of the input value.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The exponential of the input value.
|
||||
*/
|
||||
struct exp {
|
||||
template<typename T> static __device__ inline T op(const T &x) { return exp(x); }
|
||||
};
|
||||
template<> __device__ inline float exp::op<float> (const float &x ) { return __expf(x); }
|
||||
template<> __device__ inline float2 exp::op<float2>(const float2 &x) { return float2{__expf(x.x), __expf(x.y)}; }
|
||||
template<> __device__ inline bf16 exp::op<bf16> (const bf16 &x ) { return hexp(x); }
|
||||
template<> __device__ inline bf16_2 exp::op<bf16_2>(const bf16_2 &x) { return h2exp(x); }
|
||||
template<> __device__ inline half exp::op<half> (const half &x ) { return hexp(x); }
|
||||
template<> __device__ inline half_2 exp::op<half_2>(const half_2 &x) { return h2exp(x); }
|
||||
|
||||
// /**
|
||||
// * @brief Exponential function operation, in base 2
|
||||
// *
|
||||
// * This operation calculates the exponential of the input value, in base 2.
|
||||
// *
|
||||
// * @tparam T The data type of the input and output values.
|
||||
// * @param x[in] The input value.
|
||||
// * @return The exponential of the input value.
|
||||
// */
|
||||
// struct exp2 {
|
||||
// template<typename T> static __device__ inline T op(const T &x) { return exp2f(x); }
|
||||
// };
|
||||
// template<> __device__ inline float exp2::op<float> (const float &x ) { return exp2f(x); }
|
||||
// template<> __device__ inline float2 exp2::op<float2>(const float2 &x) { return float2{exp2f(x.x), exp2f(x.y)}; }
|
||||
// template<> __device__ inline bf16 exp2::op<bf16> (const bf16 &x ) { return hexp2(x); }
|
||||
// template<> __device__ inline bf16_2 exp2::op<bf16_2>(const bf16_2 &x) { return h2exp2(x); }
|
||||
// template<> __device__ inline half exp2::op<half> (const half &x ) { return hexp2(x); }
|
||||
// template<> __device__ inline half_2 exp2::op<half_2>(const half_2 &x) { return h2exp2(x); }
|
||||
|
||||
|
||||
/**
|
||||
 * @brief Base-2 exponential operation using `__builtin_amdgcn_exp2f`
|
||||
*
|
||||
* Maps directly to `v_exp_f32_e32` on AMD, for highest performance.
|
||||
* Expects `x` to be in a safe numerical range (e.g., [-64, 88]).
|
||||
*/
|
||||
struct exp2 {
|
||||
template <typename T>
|
||||
static __device__ inline T op(const T &x) {
|
||||
return exp2f(x); // fallback
|
||||
}
|
||||
};
|
||||
|
||||
// Force hardware v_exp_f32 for float
|
||||
template<>
|
||||
__device__ inline float exp2::op<float>(const float &x) {
|
||||
return __builtin_amdgcn_exp2f(x); // Emits v_exp_f32_e32
|
||||
}
|
||||
|
||||
// Force hardware v_exp_f32 for float2
|
||||
template<>
|
||||
__device__ inline float2 exp2::op<float2>(const float2 &x) {
|
||||
return {
|
||||
__builtin_amdgcn_exp2f(x.x),
|
||||
__builtin_amdgcn_exp2f(x.y)
|
||||
};
|
||||
}
|
||||
|
||||
// Delegate to low-precision approximations
|
||||
template<> __device__ inline half exp2::op<half>(const half &x) { return hexp2(x); }
|
||||
template<> __device__ inline half_2 exp2::op<half_2>(const half_2 &x) { return h2exp2(x); }
|
||||
template<> __device__ inline bf16 exp2::op<bf16>(const bf16 &x) { return hexp2(x); }
|
||||
template<> __device__ inline bf16_2 exp2::op<bf16_2>(const bf16_2 &x) { return h2exp2(x); }
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @brief Natural log function operation.
|
||||
*
|
||||
* This operation calculates the natural logarithm of the input value.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The natural logarithm of the input value.
|
||||
*/
|
||||
struct log {
|
||||
template<typename T> static __device__ inline T op(const T &x) { return log(x); }
|
||||
};
|
||||
template<> __device__ inline float log::op<float> (const float &x ) { return __logf(x); }
|
||||
template<> __device__ inline float2 log::op<float2>(const float2 &x) { return float2{__logf(x.x), __logf(x.y)}; }
|
||||
template<> __device__ inline bf16 log::op<bf16> (const bf16 &x ) { return hlog(x); }
|
||||
template<> __device__ inline bf16_2 log::op<bf16_2>(const bf16_2 &x) { return h2log(x); }
|
||||
template<> __device__ inline half log::op<half> (const half &x ) { return hlog(x); }
|
||||
template<> __device__ inline half_2 log::op<half_2>(const half_2 &x) { return h2log(x); }
|
||||
/**
|
||||
* @brief Logarithm base 2 operation.
|
||||
*
|
||||
* This operation calculates the logarithm base 2 of the input value.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The logarithm base 2 of the input value.
|
||||
*/
|
||||
struct log2 {
|
||||
template<typename T> static __device__ inline T op(const T &x) { return log2(x); }
|
||||
};
|
||||
template<> __device__ inline float log2::op<float> (const float &x ) { return __log2f(x); }
|
||||
template<> __device__ inline float2 log2::op<float2>(const float2 &x) { return float2{__log2f(x.x), __log2f(x.y)}; }
|
||||
template<> __device__ inline bf16 log2::op<bf16> (const bf16 &x ) { return hlog2(x); }
|
||||
template<> __device__ inline bf16_2 log2::op<bf16_2>(const bf16_2 &x) { return h2log2(x); }
|
||||
template<> __device__ inline half log2::op<half> (const half &x ) { return hlog2(x); }
|
||||
template<> __device__ inline half_2 log2::op<half_2>(const half_2 &x) { return h2log2(x); }
|
||||
/**
|
||||
* @brief Absolute value operation.
|
||||
*
|
||||
* This operation calculates the absolute value of the input.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The absolute value of the input.
|
||||
*/
|
||||
struct abs {
|
||||
template<typename T> static __device__ inline T op(const T &x) { return abs(x); }
|
||||
};
|
||||
template<> __device__ inline float abs::op<float> (const float &x ) { return fabsf(x); }
|
||||
template<> __device__ inline float2 abs::op<float2>(const float2 &x) { return float2{fabsf(x.x), fabsf(x.y)}; }
|
||||
template<> __device__ inline bf16 abs::op<bf16> (const bf16 &x ) { return __habs(x); }
|
||||
template<> __device__ inline bf16_2 abs::op<bf16_2>(const bf16_2 &x) { return __habs2(x); }
|
||||
template<> __device__ inline half abs::op<half> (const half &x ) { return __habs(x); }
|
||||
template<> __device__ inline half_2 abs::op<half_2>(const half_2 &x) { return __habs2(x); }
|
||||
/**
|
||||
* @brief Rectified Linear Unit (ReLU) operation.
|
||||
*
|
||||
* This operation applies the ReLU function to the input, which is the
|
||||
* maximum of zero and the input value.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The result of ReLU function applied to the input.
|
||||
*/
|
||||
struct relu {
|
||||
template<typename T> static __device__ inline T op(const T &x) { return max(x, base_types::constants<T>::zero()); }
|
||||
};
|
||||
template<> __device__ inline float relu::op<float> (const float &x ) { return max(x, 0.f); }
|
||||
template<> __device__ inline float2 relu::op<float2>(const float2 &x) { return float2{max(x.x, 0.f), max(x.y, 0.f)}; }
|
||||
template<> __device__ inline bf16 relu::op<bf16> (const bf16 &x ) { return __hmax(x, base_types::constants<bf16>::zero()); }
|
||||
template<> __device__ inline bf16_2 relu::op<bf16_2>(const bf16_2 &x) { return __hmax2(x, base_types::constants<bf16_2>::zero()); }
|
||||
template<> __device__ inline half relu::op<half> (const half &x ) { return __hmax(x, base_types::constants<half>::zero()); }
|
||||
template<> __device__ inline half_2 relu::op<half_2>(const half_2 &x) { return half_2{__hmax(x.x, base_types::constants<half>::zero()),
|
||||
__hmax(x.y, base_types::constants<half>::zero())}; }
|
||||
|
||||
|
||||
constexpr float SQRT_2_OVER_PI = 0.7978845608028654f;
|
||||
constexpr float GELU_COEFF = 0.044715f;
|
||||
constexpr float GELU_INNER_COEFF = GELU_COEFF * SQRT_2_OVER_PI;
|
||||
constexpr float DGELU_COEFF = 3.0f * GELU_COEFF * SQRT_2_OVER_PI;
|
||||
|
||||
static __device__ inline float fast_tanh(float x) {
|
||||
x = fmaxf(fminf(x, 20.f), -20.f);
|
||||
float e2x = __builtin_amdgcn_exp2f(x * 2.8853900817779268f);
|
||||
return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gaussian Error Linear Unit (GELU) activation.
|
||||
*
|
||||
* Computes the GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The GELU activation applied to the input.
|
||||
*/
|
||||
struct gelu {
|
||||
template<typename T> static __device__ inline T op(const T &x);
|
||||
};
|
||||
template<> __device__ inline float gelu::op<float>(const float &x) {
|
||||
return x * (0.5f + 0.5f * fast_tanh(x * (SQRT_2_OVER_PI + GELU_INNER_COEFF * x * x)));
|
||||
}
|
||||
template<> __device__ inline float2 gelu::op<float2>(const float2 &x) {
|
||||
return float2{gelu::op<float>(x.x), gelu::op<float>(x.y)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Derivative of the GELU activation.
|
||||
*
|
||||
* Computes the derivative of the GELU approximation with respect to the input.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The derivative of GELU evaluated at the input.
|
||||
*/
|
||||
struct dgelu {
|
||||
template<typename T> static __device__ inline T op(const T &x);
|
||||
};
|
||||
template<> __device__ inline float dgelu::op<float>(const float &x) {
|
||||
float tanh_out = fast_tanh(SQRT_2_OVER_PI * x * (1.f + GELU_COEFF * x * x));
|
||||
return 0.5f * x * ((1.f - tanh_out * tanh_out) * (SQRT_2_OVER_PI + DGELU_COEFF * x * x)) +
|
||||
0.5f * (1.f + tanh_out);
|
||||
}
|
||||
template<> __device__ inline float2 dgelu::op<float2>(const float2 &x) {
|
||||
return float2{dgelu::op<float>(x.x), dgelu::op<float>(x.y)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy operation.
|
||||
*
|
||||
* This operation returns the input value unchanged.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The input value.
|
||||
* @return The same value as the input.
|
||||
*/
|
||||
struct copy { // for non-compile-time setters.
|
||||
template<typename T> static __device__ inline T op(const T &a) { return a; }
|
||||
};
|
||||
|
||||
|
||||
/* ---------- BINARY OPS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Copy2 operation.
|
||||
*
|
||||
* This operation returns the second input value unchanged.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value (ignored).
|
||||
* @param b[in] The second input value.
|
||||
* @return The same value as the second input.
|
||||
*/
|
||||
struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return b; }
|
||||
};
|
||||
/**
|
||||
* @brief Sum operation.
|
||||
*
|
||||
* This operation calculates the sum of two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The sum of the input values.
|
||||
*/
|
||||
struct sum {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a+b; }
|
||||
};
|
||||
template<> __device__ inline float2 sum::op<float2>(const float2 &a, const float2 &b) { return float2{a.x+b.x, a.y+b.y}; }
|
||||
template<> __device__ inline bf16 sum::op<bf16> (const bf16 &a, const bf16 &b) { return __hadd(a, b); }
|
||||
template<> __device__ inline bf16_2 sum::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hadd2(a, b); }
|
||||
template<> __device__ inline half sum::op<half> (const half &a, const half &b) { return __hadd(a, b); }
|
||||
template<> __device__ inline half_2 sum::op<half_2>(const half_2 &a, const half_2 &b) { return __hadd2(a, b); }
|
||||
/**
|
||||
* @brief Subtraction operation.
|
||||
*
|
||||
* This operation calculates the difference between two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The difference between the input values.
|
||||
*/
|
||||
struct sub {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a-b; }
|
||||
};
|
||||
template<> __device__ inline float2 sub::op<float2>(const float2 &a, const float2 &b) { return float2{a.x-b.x, a.y-b.y}; }
|
||||
template<> __device__ inline bf16 sub::op<bf16> (const bf16 &a, const bf16 &b) { return __hsub(a, b); }
|
||||
template<> __device__ inline bf16_2 sub::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hsub2(a, b); }
|
||||
template<> __device__ inline half sub::op<half> (const half &a, const half &b) { return __hsub(a, b); }
|
||||
template<> __device__ inline half_2 sub::op<half_2>(const half_2 &a, const half_2 &b) { return __hsub2(a, b); }
|
||||
/**
|
||||
* @brief Multiplication operation.
|
||||
*
|
||||
* This operation calculates the product of two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The product of the input values.
|
||||
*/
|
||||
struct mul {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a*b; }
|
||||
};
|
||||
template<> __device__ inline float2 mul::op<float2>(const float2 &a, const float2 &b) { return float2{a.x*b.x, a.y*b.y}; }
|
||||
template<> __device__ inline bf16 mul::op<bf16> (const bf16 &a, const bf16 &b) { return __hmul(a, b); }
|
||||
template<> __device__ inline bf16_2 mul::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmul2(a, b); }
|
||||
template<> __device__ inline half mul::op<half> (const half &a, const half &b) { return __hmul(a, b); }
|
||||
template<> __device__ inline half_2 mul::op<half_2>(const half_2 &a, const half_2 &b) { return __hmul2(a, b); }
|
||||
/**
|
||||
* @brief Division operation.
|
||||
*
|
||||
* This operation calculates the quotient of two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The quotient of the input values.
|
||||
*/
|
||||
struct div {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a/b; }
|
||||
};
|
||||
template<> __device__ inline float2 div::op<float2>(const float2 &a, const float2 &b) { return float2{a.x/b.x, a.y/b.y}; }
|
||||
template<> __device__ inline bf16 div::op<bf16> (const bf16 &a, const bf16 &b) { return __hdiv(a, b); }
|
||||
template<> __device__ inline bf16_2 div::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __h2div(a, b); } // this op is a special snowflake
|
||||
template<> __device__ inline half div::op<half> (const half &a, const half &b) { return __hdiv(a, b); }
|
||||
template<> __device__ inline half_2 div::op<half_2>(const half_2 &a, const half_2 &b) { return __h2div(a, b); }
|
||||
/**
|
||||
* @brief Maximum operation.
|
||||
*
|
||||
* This operation calculates the maximum of two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The maximum of the input values.
|
||||
*/
|
||||
struct max {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::max(a, b); }
|
||||
};
|
||||
template<> __device__ inline float2 max::op<float2>(const float2 &a, const float2 &b) { return float2{::max(a.x, b.x), ::max(a.y, b.y)}; }
|
||||
template<> __device__ inline bf16 max::op<bf16> (const bf16 &a, const bf16 &b) { return __hmax(a, b); }
|
||||
template<> __device__ inline bf16_2 max::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmax2(a, b); }
|
||||
template<> __device__ inline half max::op<half> (const half &a, const half &b) { return __hmax(a, b); }
|
||||
template<> __device__ inline half_2 max::op<half_2>(const half_2 &a, const half_2 &b) { return half_2{__hmax(a.x, b.x), __hmax(a.y, b.y)}; }
|
||||
/**
|
||||
* @brief Minimum operation.
|
||||
*
|
||||
* This operation calculates the minimum of two input values.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @return The minimum of the input values.
|
||||
*/
|
||||
struct min {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::min(a, b); }
|
||||
};
|
||||
template<> __device__ inline float2 min::op<float2>(const float2 &a, const float2 &b) { return float2{::min(a.x, b.x), ::min(a.y, b.y)}; }
|
||||
template<> __device__ inline bf16 min::op<bf16> (const bf16 &a, const bf16 &b) { return __hmin(a, b); }
|
||||
template<> __device__ inline bf16_2 min::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmin2(a, b); }
|
||||
template<> __device__ inline half min::op<half> (const half &a, const half &b) { return __hmin(a, b); }
|
||||
template<> __device__ inline half_2 min::op<half_2>(const half_2 &a, const half_2 &b) { return half_2{__hmin(a.x, b.x), __hmin(a.y, b.y)}; }
|
||||
|
||||
|
||||
/* ---------- TERNARY OPS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Fused multiply-add operation A * B + C.
|
||||
*
|
||||
 * This operation computes (A * B) + C as a multiply followed by an add; no explicit fused instruction is issued, so a single rounding is not guaranteed.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The second input value.
|
||||
* @param c[in] The third input value to be added.
|
||||
* @return The result of the fused multiply-add operation.
|
||||
*/
|
||||
struct fma_AxBtC {
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
|
||||
return sum::op<T>(mul::op<T>(a, b), c);
|
||||
}
|
||||
};
|
||||
/**
|
||||
* @brief Fused multiply-add operation A * C + B.
|
||||
*
|
||||
 * This operation computes (A * C) + B as a multiply followed by an add; no explicit fused instruction is issued, so a single rounding is not guaranteed.
|
||||
* This is particularly useful for attention mechanisms in neural networks.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param a[in] The first input value.
|
||||
* @param b[in] The third input value to be added.
|
||||
* @param c[in] The second input value.
|
||||
* @return The result of the fused multiply-add operation.
|
||||
*/
|
||||
struct fma_AxCtB { // this is the one needed for attention
|
||||
template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
|
||||
return sum::op<T>(mul::op<T>(a, c), b);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace base_ops
|
||||
|
||||
} // namespace kittens
|
||||
454
extra/thunder/amd/include/udna1/common/base_types.cuh
Normal file
454
extra/thunder/amd/include/udna1/common/base_types.cuh
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Declarations, manipulations, and wrappers for basic types.
|
||||
*
|
||||
* This file is a bunch of utilities for going back and forth between different types.
|
||||
*
|
||||
* Many of them are for the compiler, so as to clean up the code. It unfortunately
|
||||
* seems necessary when we have types we really care about that are less than word width.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
#include <hip/hip_fp4.h>
|
||||
#include <hip/amd_detail/amd_hip_ocp_types.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <string>
|
||||
#include <bit>
|
||||
|
||||
typedef uint32_t __amd_fp8x4_storage_t;
|
||||
|
||||
namespace kittens {
|
||||
|
||||
// /**
|
||||
// * @brief Bfloat16 floating-point type.
|
||||
// */
|
||||
using bf16 = __hip_bfloat16;
|
||||
/**
|
||||
* @brief Half-precision floating-point type.
|
||||
*/
|
||||
using half = __half;
|
||||
// /**
|
||||
// * @brief Packed word of two bfloat16 floating-point values.
|
||||
// */
|
||||
using bf16_2 = __hip_bfloat162;
|
||||
/**
|
||||
* @brief Packed word of two half-precision floating-point values.
|
||||
*/
|
||||
using half_2 = __half2;
|
||||
/**
|
||||
* @brief float8 floating-point type.
|
||||
*/
|
||||
using fp8e4m3 = __hip_fp8_e4m3;
|
||||
/**
|
||||
* @brief Packed word of two float8 floating-point values.
|
||||
*/
|
||||
using fp8e4m3_2 = __hip_fp8x2_e4m3;
|
||||
/**
|
||||
* @brief Packed word of four float8 floating-point values.
|
||||
*/
|
||||
using fp8e4m3_4 = __hip_fp8x4_e4m3;
|
||||
/**
|
||||
* @brief 8-bit exponent-only block-scaling scale type.
|
||||
*/
|
||||
using fp8e8m0 = __amd_scale_t;
|
||||
/**
|
||||
* @brief Packed word of two 8-bit exponent-only block-scaling scale values.
|
||||
*/
|
||||
using fp8e8m0_2 = __amd_fp8x2_storage_t;
|
||||
/**
|
||||
* @brief Packed word of four 8-bit exponent-only block-scaling scale values.
|
||||
*/
|
||||
using fp8e8m0_4 = __amd_fp8x4_storage_t;
|
||||
/**
|
||||
* @brief FP4 E2M1 floating-point type.
|
||||
*/
|
||||
using fp4e2m1 = __hip_fp4_e2m1;
|
||||
/**
|
||||
* @brief Packed word of two FP4 E2M1 floating-point values.
|
||||
*/
|
||||
using fp4e2m1_2 = __hip_fp4x2_e2m1;
|
||||
/**
|
||||
* @brief Packed word of four FP4 E2M1 floating-point values.
|
||||
*/
|
||||
using fp4e2m1_4 = __hip_fp4x4_e2m1;
|
||||
|
||||
namespace ducks {
|
||||
/**
|
||||
* @namespace base_types
|
||||
*
|
||||
* @brief A namespace for concepts for basic data types.
|
||||
*/
|
||||
namespace base_types {
|
||||
|
||||
template<typename T>
|
||||
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>
|
||||
|| std::is_same_v<T, fp4e2m1_4>;
|
||||
template<typename T>
|
||||
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>
|
||||
|| std::is_same_v<T, fp4e2m1>;
|
||||
|
||||
} // namespace base_types
|
||||
} // namespace ducks
|
||||
|
||||
/**
|
||||
* @namespace base_types
|
||||
*
|
||||
* @brief A namespace for ThunderKittens basic data types.
|
||||
*/
|
||||
namespace base_types {
|
||||
|
||||
/**
|
||||
* @brief Provides compile-time constants for different types.
|
||||
*
|
||||
* @tparam T The type for which to provide constants.
|
||||
*/
|
||||
template<typename T> struct constants {
|
||||
/**
|
||||
* @brief Zero
|
||||
* @return Constexpr zero with type T
|
||||
*/
|
||||
static __device__ inline constexpr T zero() { return T{0}; }
|
||||
/**
|
||||
* @brief Ones
|
||||
* @return Constexpr ones with type T
|
||||
*/
|
||||
static __device__ inline constexpr T ones() { return T{1}; }
|
||||
/**
|
||||
* @brief Positive infinity. Particularly useful for initializing before a min op.
|
||||
* @return Constexpr positive infinity with type T
|
||||
*/
|
||||
static __device__ inline constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work.
|
||||
/**
|
||||
* @brief Negative infinity. Particularly useful for initializing before a max op.
|
||||
* @return Constexpr negative infinity with type T
|
||||
*/
|
||||
static __device__ inline constexpr T neg_infty() { return T{-INFINITY}; }
|
||||
};
|
||||
template<> struct constants<float2> {
|
||||
static __device__ inline constexpr float2 zero() { return float2{0.f, 0.f}; }
|
||||
static __device__ inline constexpr float2 ones() { return float2{1.f, 1.f}; }
|
||||
static __device__ inline constexpr float2 pos_infty() { return float2{constants<float>::pos_infty(), constants<float>::pos_infty()}; }
|
||||
static __device__ inline constexpr float2 neg_infty() { return float2{constants<float>::neg_infty(), constants<float>::neg_infty()}; }
|
||||
};
|
||||
template<> struct constants<bf16> {
|
||||
static __device__ inline constexpr bf16 zero() { return std::bit_cast<bf16>(uint16_t(0x0000)); } // unfortunately __float2bf16_rn is not constexpr
|
||||
static __device__ inline constexpr bf16 ones() { return std::bit_cast<bf16>(uint16_t(0x3F80)); }
|
||||
static __device__ inline constexpr bf16 pos_infty() { return std::bit_cast<bf16>(uint16_t(0x7F80)); }
|
||||
static __device__ inline constexpr bf16 neg_infty() { return std::bit_cast<bf16>(uint16_t(0xFF80)); }
|
||||
};
|
||||
template<> struct constants<bf16_2> {
|
||||
static __device__ inline bf16_2 zero() { return bf16_2{constants<bf16>::zero(), constants<bf16>::zero()}; }
|
||||
static __device__ inline bf16_2 ones() { return bf16_2{constants<bf16>::ones(), constants<bf16>::ones()}; }
|
||||
static __device__ inline bf16_2 pos_infty() { return bf16_2{constants<bf16>::pos_infty(), constants<bf16>::pos_infty()}; }
|
||||
static __device__ inline bf16_2 neg_infty() { return bf16_2{constants<bf16>::neg_infty(), constants<bf16>::neg_infty()}; }
|
||||
};
|
||||
template<> struct constants<half> {
|
||||
static __device__ inline constexpr half zero() { return std::bit_cast<half>(uint16_t(0x0000)); }
|
||||
static __device__ inline constexpr half ones() { return std::bit_cast<half>(uint16_t(0x3C00)); }
|
||||
static __device__ inline constexpr half pos_infty() { return std::bit_cast<half>(uint16_t(0x7C00)); }
|
||||
static __device__ inline constexpr half neg_infty() { return std::bit_cast<half>(uint16_t(0xFC00)); }
|
||||
};
|
||||
template<> struct constants<half_2> {
|
||||
static __device__ inline constexpr half_2 zero() { return std::bit_cast<half_2>(uint32_t(0x00000000)); }
|
||||
static __device__ inline constexpr half_2 ones() { return std::bit_cast<half_2>(uint32_t(0x3C003C00)); }
|
||||
static __device__ inline constexpr half_2 pos_infty() { return std::bit_cast<half_2>(uint32_t(0x7C007C00)); }
|
||||
static __device__ inline constexpr half_2 neg_infty() { return std::bit_cast<half_2>(uint32_t(0xFC00FC00)); }
|
||||
};
|
||||
template<> struct constants<fp8e4m3> {
|
||||
static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<fp8e4m3>(uint8_t(0x00)); }
|
||||
static __device__ inline constexpr fp8e4m3 one() { return std::bit_cast<fp8e4m3>(uint8_t(0x38)); }
|
||||
};
|
||||
template<> struct constants<fp8e4m3_2> {
|
||||
static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x0000)); }
|
||||
static __device__ inline constexpr fp8e4m3_2 one() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x3838)); }
|
||||
};
|
||||
template<> struct constants<fp8e4m3_4> {
|
||||
static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
|
||||
static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x38383838)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0> {
|
||||
static __device__ inline constexpr fp8e8m0 zero() { return std::bit_cast<fp8e8m0>(uint8_t(0x00)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0 one() { return std::bit_cast<fp8e8m0>(uint8_t(0x7F)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0_2> {
|
||||
static __device__ inline constexpr fp8e8m0_2 zero() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x0000)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0_2 one() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x7F7F)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0_4> {
|
||||
static __device__ inline constexpr fp8e8m0_4 zero() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x00000000)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0_4 one() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x7F7F7F7F)); }
|
||||
};
|
||||
template<> struct constants<fp4e2m1> {
|
||||
static __device__ inline constexpr fp4e2m1 zero() { return std::bit_cast<fp4e2m1>(uint8_t(0x00)); }
|
||||
static __device__ inline constexpr fp4e2m1 one() { return std::bit_cast<fp4e2m1>(uint8_t(0x02)); }
|
||||
};
|
||||
template<> struct constants<fp4e2m1_4> {
|
||||
static __device__ inline constexpr fp4e2m1_4 zero() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x0000)); }
|
||||
static __device__ inline constexpr fp4e2m1_4 one() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x2222)); }
|
||||
};
|
||||
template<> struct constants<int> {
|
||||
static __device__ inline constexpr int zero() { return 0; }
|
||||
static __device__ inline constexpr int ones() { return 1; }
|
||||
};
|
||||
template<> struct constants<int2> {
|
||||
static __device__ inline constexpr int2 zero() { return int2{0, 0}; }
|
||||
static __device__ inline constexpr int2 ones() { return int2{1, 1}; }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Provides information about packing of elements for a given type.
|
||||
*
|
||||
* @tparam T The type for which to provide packing information.
|
||||
*/
|
||||
template<typename T> struct packing {
|
||||
/**
|
||||
* @brief The number of elements packed together.
|
||||
*
|
||||
* @return constexpr int representing number of elements within the type.
|
||||
*/
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
/**
|
||||
* @brief Packs a single T element twice (replicated) into its packed type.
|
||||
*
|
||||
* @param i[in] The element to pack.
|
||||
* @return The packed type.
|
||||
*/
|
||||
static __device__ inline constexpr T pack(const auto &i);
|
||||
};
|
||||
template<> struct packing<bf16> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = bf16;
|
||||
using packed_type = bf16_2;
|
||||
static __device__ inline bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; }
|
||||
};
|
||||
template<> struct packing<bf16_2> {
|
||||
static __host__ __device__ inline constexpr int num() { return 2; }
|
||||
using unpacked_type = bf16;
|
||||
using packed_type = bf16_2;
|
||||
static __device__ inline bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } // this replication makes code cleaner later.
|
||||
};
|
||||
template<> struct packing<half> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = half;
|
||||
using packed_type = half_2;
|
||||
static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; }
|
||||
};
|
||||
template<> struct packing<half_2> {
|
||||
static __host__ __device__ inline constexpr int num() { return 2; }
|
||||
using unpacked_type = half;
|
||||
using packed_type = half_2;
|
||||
static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } // this replication makes code cleaner later.
|
||||
};
|
||||
template<> struct packing<float> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = float;
|
||||
using packed_type = float2;
|
||||
static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; }
|
||||
};
|
||||
template<> struct packing<float2> {
|
||||
static __host__ __device__ inline constexpr int num() { return 2; }
|
||||
using unpacked_type = float;
|
||||
using packed_type = float2;
|
||||
static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } // this replication makes code cleaner later.
|
||||
};
|
||||
template<> struct packing<int> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = int;
|
||||
using packed_type = int2;
|
||||
static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
|
||||
};
|
||||
template<> struct packing<int2> {
|
||||
static __host__ __device__ inline constexpr int num() { return 2; }
|
||||
using unpacked_type = int;
|
||||
using packed_type = int2;
|
||||
static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
|
||||
};
|
||||
template<> struct packing<float4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
};
|
||||
template<> struct packing<int4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
};
|
||||
template<> struct packing<fp8e4m3> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = fp8e4m3;
|
||||
using packed_type = fp8e4m3_4;
|
||||
};
|
||||
template<> struct packing<fp8e4m3_4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
using unpacked_type = fp8e4m3;
|
||||
using packed_type = fp8e4m3_4;
|
||||
};
|
||||
template<> struct packing<fp8e8m0> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = fp8e8m0;
|
||||
using packed_type = fp8e8m0_4;
|
||||
};
|
||||
template<> struct packing<fp8e8m0_4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
using unpacked_type = fp8e8m0;
|
||||
using packed_type = fp8e8m0_4;
|
||||
};
|
||||
template<> struct packing<fp4e2m1> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = fp4e2m1;
|
||||
using packed_type = fp4e2m1_4;
|
||||
};
|
||||
template<> struct packing<fp4e2m1_4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
using unpacked_type = fp4e2m1;
|
||||
using packed_type = fp4e2m1_4;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Provides templated functionality to convert between different types.
|
||||
*
|
||||
* @tparam T The target type for conversion.
|
||||
* @tparam U The source type for conversion.
|
||||
*/
|
||||
template<typename T, typename U> struct convertor {
|
||||
/**
|
||||
* @brief Converts a value of type U to type T.
|
||||
*
|
||||
* @param u[in] The value of type U to convert.
|
||||
* @return T The converted value of type T.
|
||||
*/
|
||||
static __host__ __device__ inline T convert(const U & u) {
|
||||
return (T)u;
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float, bf16> {
|
||||
static __host__ __device__ inline float convert(const bf16 & u) {
|
||||
return __bfloat162float(u);
|
||||
}
|
||||
};
|
||||
// template<> struct convertor<bf16, float> {
|
||||
// static __host__ __device__ inline bf16 convert(const float & u) {
|
||||
// return __float2bfloat16(u);
|
||||
// }
|
||||
// };
|
||||
template<> struct convertor<bf16, float> {
|
||||
static __host__ __device__ inline bf16 convert(const float &u) {
|
||||
// Fast unsafe conversion (truncation only)
|
||||
return std::bit_cast<bf16>(
|
||||
static_cast<uint16_t>(
|
||||
std::bit_cast<uint32_t>(u) >> 16
|
||||
)
|
||||
);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float2, bf16_2> {
|
||||
static __host__ __device__ inline float2 convert(const bf16_2 & u) {
|
||||
return __bfloat1622float2(u);
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct convertor<bf16_2, float2> {
|
||||
static __host__ __device__ inline bf16_2 convert(const float2 &u) {
|
||||
uint32_t result;
|
||||
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
|
||||
: "=v"(result)
|
||||
: "v"(u.x), "v"(u.y));
|
||||
return *reinterpret_cast<bf16_2*>(&result);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<> struct convertor<float, half> {
|
||||
static __host__ __device__ inline float convert(const half & u) {
|
||||
return __half2float(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<half, float> {
|
||||
static __host__ __device__ inline half convert(const float & u) {
|
||||
return __float2half(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float2, half_2> {
|
||||
static __host__ __device__ inline float2 convert(const half_2 & u) {
|
||||
return __half22float2(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<half_2, float2> {
|
||||
static __host__ __device__ inline half_2 convert(const float2 & u) {
|
||||
return __float22half2_rn(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<bf16, half> {
|
||||
static __host__ __device__ inline bf16 convert(const half & u) {
|
||||
return __float2bfloat16(__half2float(u));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<half, bf16> {
|
||||
static __host__ __device__ inline half convert(const bf16 & u) {
|
||||
return __float2half(__bfloat162float(u));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<bf16_2, half_2> {
|
||||
static __host__ __device__ inline bf16_2 convert(const half_2 & u) {
|
||||
return __float22bfloat162_rn(__half22float2(u));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<half_2, bf16_2> {
|
||||
static __host__ __device__ inline half_2 convert(const bf16_2 & u) {
|
||||
return __float22half2_rn(__bfloat1622float2(u));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp8e4m3_4, float4> {
|
||||
static __host__ __device__ inline fp8e4m3_4 convert(const float4& u) {
|
||||
return fp8e4m3_4(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float4, fp8e4m3_4> {
|
||||
static __host__ __device__ inline float4 convert(const fp8e4m3_4& u) {
|
||||
fp8e4m3 *vals = reinterpret_cast<fp8e4m3*>(const_cast<fp8e4m3_4*>(&u));
|
||||
return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3]));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp8e4m3_2, float2> {
|
||||
static __host__ __device__ inline fp8e4m3_2 convert(const float2& u) {
|
||||
return fp8e4m3_2(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float2, fp8e4m3_2> {
|
||||
static __host__ __device__ inline float2 convert(const fp8e4m3_2& u) {
|
||||
fp8e4m3 *vals = reinterpret_cast<fp8e4m3*>(const_cast<fp8e4m3_2*>(&u));
|
||||
return make_float2(float(vals[0]), float(vals[1]));
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp8e4m3, float> {
|
||||
static __host__ __device__ inline fp8e4m3 convert(const float & u) {
|
||||
return fp8e4m3(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float, fp8e4m3> {
|
||||
static __host__ __device__ inline float convert(const fp8e4m3 & u) {
|
||||
return float(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp4e2m1, float> {
|
||||
static __host__ __device__ inline fp4e2m1 convert(const float & u) {
|
||||
return fp4e2m1(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float, fp4e2m1> {
|
||||
static __host__ __device__ inline float convert(const fp4e2m1 & u) {
|
||||
return float(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp4e2m1_4, float4> {
|
||||
static __host__ __device__ inline fp4e2m1_4 convert(const float4& u) {
|
||||
return fp4e2m1_4(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float4, fp4e2m1_4> {
|
||||
static __host__ __device__ inline float4 convert(const fp4e2m1_4& u) {
|
||||
return float4(u);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
13
extra/thunder/amd/include/udna1/common/common.cuh
Normal file
13
extra/thunder/amd/include/udna1/common/common.cuh
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief A collection of common resources on which HipKittens depends.
|
||||
*/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "util.cuh"
|
||||
#include "base_types.cuh"
|
||||
#include "base_ops.cuh"
|
||||
|
||||
#include "macros.cuh"
|
||||
56
extra/thunder/amd/include/udna1/common/debug.cuh
Normal file
56
extra/thunder/amd/include/udna1/common/debug.cuh
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#pragma once
|
||||
|
||||
// Reset
|
||||
#define TK_RESET "\033[0m"
|
||||
|
||||
// Foreground colors
|
||||
#define TK_FG_BLACK "\033[30m"
|
||||
#define TK_FG_RED "\033[31m"
|
||||
#define TK_FG_GREEN "\033[32m"
|
||||
#define TK_FG_YELLOW "\033[33m"
|
||||
#define TK_FG_BLUE "\033[34m"
|
||||
#define TK_FG_MAGENTA "\033[35m"
|
||||
#define TK_FG_CYAN "\033[36m"
|
||||
#define TK_FG_WHITE "\033[37m"
|
||||
|
||||
// Background colors
|
||||
#define TK_BG_BLACK "\033[40m"
|
||||
#define TK_BG_RED "\033[41m"
|
||||
#define TK_BG_GREEN "\033[42m"
|
||||
#define TK_BG_YELLOW "\033[43m"
|
||||
#define TK_BG_BLUE "\033[44m"
|
||||
#define TK_BG_MAGENTA "\033[45m"
|
||||
#define TK_BG_CYAN "\033[46m"
|
||||
#define TK_BG_WHITE "\033[47m"
|
||||
|
||||
// Bright foreground colors
|
||||
#define TK_FG_BRIGHT_BLACK "\033[90m"
|
||||
#define TK_FG_BRIGHT_RED "\033[91m"
|
||||
#define TK_FG_BRIGHT_GREEN "\033[92m"
|
||||
#define TK_FG_BRIGHT_YELLOW "\033[93m"
|
||||
#define TK_FG_BRIGHT_BLUE "\033[94m"
|
||||
#define TK_FG_BRIGHT_MAGENTA "\033[95m"
|
||||
#define TK_FG_BRIGHT_CYAN "\033[96m"
|
||||
#define TK_FG_BRIGHT_WHITE "\033[97m"
|
||||
|
||||
// Bright background colors
|
||||
#define TK_BG_BRIGHT_BLACK "\033[100m"
|
||||
#define TK_BG_BRIGHT_RED "\033[101m"
|
||||
#define TK_BG_BRIGHT_GREEN "\033[102m"
|
||||
#define TK_BG_BRIGHT_YELLOW "\033[103m"
|
||||
#define TK_BG_BRIGHT_BLUE "\033[104m"
|
||||
#define TK_BG_BRIGHT_MAGENTA "\033[105m"
|
||||
#define TK_BG_BRIGHT_CYAN "\033[106m"
|
||||
#define TK_BG_BRIGHT_WHITE "\033[107m"
|
||||
|
||||
// Text styles
|
||||
#define TK_BOLD "\033[1m"
|
||||
#define TK_DIM "\033[2m"
|
||||
#define TK_ITALIC "\033[3m"
|
||||
#define TK_UNDERLINE "\033[4m"
|
||||
#define TK_BLINK "\033[5m"
|
||||
#define TK_REVERSE "\033[7m"
|
||||
#define TK_HIDDEN "\033[8m"
|
||||
|
||||
// Macro to combine styles
|
||||
#define TK_STYLE(...) "\033[" #__VA_ARGS__ "m"
|
||||
1301
extra/thunder/amd/include/udna1/common/macros.cuh
Normal file
1301
extra/thunder/amd/include/udna1/common/macros.cuh
Normal file
File diff suppressed because it is too large
Load diff
437
extra/thunder/amd/include/udna1/common/util.cuh
Normal file
437
extra/thunder/amd/include/udna1/common/util.cuh
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief General utilities for ThunderKittens.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <memory>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "base_types.cuh"
|
||||
|
||||
#ifndef __forceinline__
|
||||
#define __forceinline__ __attribute__((always_inline))
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @namespace kittens
|
||||
*
|
||||
* @brief The main namespace of ThunderKittens.
|
||||
*/
|
||||
namespace kittens {
|
||||
|
||||
/* ---------- GENERAL CONSTANTS FOR KITTENS ---------- */
|
||||
/**
|
||||
* @brief Constant representing number of threads in a warp.
|
||||
*
|
||||
* gfx1250 (UDNA1) is wave-32.
|
||||
*/
|
||||
constexpr int WARP_THREADS{32};
|
||||
|
||||
/**
|
||||
|
||||
* @brief Get the warp ID of the current thread.
|
||||
* @return The warp ID.
|
||||
*/
|
||||
__device__ __forceinline__ int warpid() { return threadIdx.x >> 5; }
|
||||
|
||||
/**
|
||||
* @brief Get the number of warps in the threadblock.
|
||||
* @return The number of warps in the threadblock.
|
||||
*/
|
||||
__device__ __forceinline__ int num_warps() { return blockDim.x / WARP_THREADS; }
|
||||
|
||||
/**
|
||||
* @brief Get the lane ID of the current thread within its warp.
|
||||
* @return The lane ID.
|
||||
*/
|
||||
__device__ __forceinline__ int laneid() { return threadIdx.x & 0x1f; }
|
||||
|
||||
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
|
||||
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
|
||||
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
|
||||
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
|
||||
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
struct buffer_resource {
|
||||
uint64_t ptr;
|
||||
uint32_t range;
|
||||
uint32_t config;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute the ceiling division of two integers.
|
||||
* @param a The dividend.
|
||||
* @param b The divisor.
|
||||
* @return The ceiling division result.
|
||||
*/
|
||||
__host__ __device__ inline int ceil_div(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Transform a workgroup ID to a new workgroup ID based on the chunk size and number of XCDs.
|
||||
* @param workgroup_id The original workgroup ID.
|
||||
* @param num_workgroups The total number of workgroups.
|
||||
* @param num_xcds The number of XCDs.
|
||||
* @param chunk_size The chunk size.
|
||||
* @return The new workgroup ID.
|
||||
*/
|
||||
__host__ __device__ inline int chiplet_transform_chunked(
|
||||
int workgroup_id,
|
||||
int num_workgroups,
|
||||
int num_xcds,
|
||||
int chunk_size
|
||||
) {
|
||||
// Current XCD
|
||||
int xcd = workgroup_id % num_xcds;
|
||||
|
||||
// Largest full (NUM_XCDS*CHUNK_SIZE)-aligned block
|
||||
int block = num_xcds * chunk_size;
|
||||
int limit = (num_workgroups / block) * block;
|
||||
|
||||
// If pid beyond the last full block, leave unchanged
|
||||
if (workgroup_id > limit) return workgroup_id;
|
||||
|
||||
// Local PID (within round-robin assignment)
|
||||
int local_pid = workgroup_id / num_xcds;
|
||||
int chunk_idx = local_pid / chunk_size;
|
||||
int pos_in_chunk = local_pid % chunk_size;
|
||||
|
||||
// New PID
|
||||
return chunk_idx * block + xcd * chunk_size + pos_in_chunk;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief gfx1250 LDS capacity constants.
|
||||
*
|
||||
* On gfx1250, the **LDS scratchpad and the L1 data cache are one 384 KB SRAM
|
||||
* pool per Compute Unit (CU)**, partitioned into six 64 KB segments.
|
||||
* At least one segment must remain L1, leaving up to five segments
|
||||
* (320 KB) addressable as LDS.
|
||||
*
|
||||
* `MAX_SHARED_MEMORY_PER_SEGMENT` is one 64 KB segment; `MAX_SHARED_MEMORY` is
|
||||
* the full addressable LDS across all five segments. A kernel that fits in one
|
||||
* segment requests `MAX_SHARED_MEMORY_PER_SEGMENT`; one that needs more requests
|
||||
* a larger dynamic shared-memory size at launch via `hipFuncSetAttribute`.
|
||||
*/
|
||||
constexpr int MAX_SHARED_MEMORY_PER_SEGMENT = 65536;
|
||||
constexpr int SHARED_MEMORY_NUM_SEGMENTS = 5;
|
||||
constexpr int MAX_SHARED_MEMORY = MAX_SHARED_MEMORY_PER_SEGMENT * SHARED_MEMORY_NUM_SEGMENTS;
|
||||
constexpr int NUM_XCDS = 1;
|
||||
constexpr int CUS_PER_XCD = 64;
|
||||
constexpr int NUM_CUS = CUS_PER_XCD * NUM_XCDS;
|
||||
|
||||
/* ---------- CUSTOM TYPES ---------- */
|
||||
typedef uint32_t uint2_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
/* ---------- TYPE HELPERS ---------- */
|
||||
|
||||
/**
|
||||
* @namespace ducks
|
||||
*
|
||||
* @brief ThunderKittens' namespace for template metaprogramming.
|
||||
*
|
||||
* This includes primarily dummy types and concept wrappers, along
|
||||
* with a few additional utilities.
|
||||
*/
|
||||
namespace ducks {
|
||||
|
||||
/**
|
||||
* @brief A type representing an empty default for a template.
|
||||
*/
|
||||
struct default_type {};
|
||||
|
||||
// This macro can't be done as a template, so it doesn't really have a location in kittens.
|
||||
#define typeof(A) typename std::remove_const<typename std::remove_reference<decltype(A)>::type>::type
|
||||
|
||||
}
|
||||
|
||||
/* ---------- SHUFFLE UTILS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Mask constant for all active threads in a warp.
|
||||
*/
|
||||
static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF;
|
||||
|
||||
/**
|
||||
* @brief Perform a shuffle down operation on a packed type synchronously across a warp.
|
||||
* @tparam T The type of the value to be shuffled.
|
||||
* @param mask[in] The mask of active threads.
|
||||
* @param f[in] The value to be shuffled.
|
||||
* @param delta[in] The number of positions to shuffle down.
|
||||
* @return The result of the shuffle operation.
|
||||
*/
|
||||
template<typename T>
|
||||
__device__ static inline T packed_shfl_down(uint64_t mask, const T &f, int delta) {
|
||||
|
||||
if constexpr (std::is_same_v<T, bf16_2> || std::is_same_v<T, bf16>) {
|
||||
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
|
||||
union {
|
||||
__hip_bfloat162 bf162;
|
||||
unsigned int ui;
|
||||
} u;
|
||||
|
||||
if constexpr (std::is_same_v<T, bf16_2>) {
|
||||
u.bf162 = *reinterpret_cast<const __hip_bfloat162*>(&f);
|
||||
} else {
|
||||
u.bf162 = __hip_bfloat162{*reinterpret_cast<const __hip_bfloat16*>(&f),
|
||||
*reinterpret_cast<const __hip_bfloat16*>(&f)};
|
||||
}
|
||||
|
||||
u.ui = __shfl_down_sync<unsigned long long, unsigned int>(mask, u.ui, delta, 64);
|
||||
if constexpr (std::is_same_v<T, bf16>) {
|
||||
return *reinterpret_cast<const T*>(&u.bf162.x); // Extract single bf16 from the .x component
|
||||
} else {
|
||||
return u.bf162; // Return full bf162 for bf16_2 case
|
||||
}
|
||||
} else {
|
||||
return __shfl_down(f, delta);
|
||||
}
|
||||
}
|
||||
template<>
|
||||
__device__ inline float2 packed_shfl_down<float2>(uint64_t mask, const float2 &f, int delta) {
|
||||
float2 r;
|
||||
r.x = __shfl_down(f.x, delta);
|
||||
r.y = __shfl_down(f.y, delta);
|
||||
return r;
|
||||
}
|
||||
/**
|
||||
* @brief Perform a packed shuffle operation synchronously across a warp.
|
||||
* @tparam T The type of the value to be shuffled.
|
||||
* @param mask[in] The mask of active threads.
|
||||
* @param f[in] The value to be shuffled.
|
||||
* @param src[in] The source lane from which to shuffle.
|
||||
* @return The result of the shuffle operation.
|
||||
*/
|
||||
template<typename T>
|
||||
__device__ static inline T packed_shfl(uint64_t mask, const T &f, int src) {
|
||||
return __shfl(f, src);
|
||||
}
|
||||
template<>
|
||||
__device__ inline bf16 packed_shfl(uint64_t mask, const bf16 &f, int src) {
|
||||
float r = __shfl(base_types::convertor<float, bf16>::convert(f), src);
|
||||
return base_types::convertor<bf16, float>::convert(r);
|
||||
}
|
||||
template<>
|
||||
__device__ inline bf16_2 packed_shfl(uint64_t mask, const bf16_2 &f, int src) {
|
||||
float2 r;
|
||||
r.x = __shfl(base_types::convertor<float, bf16>::convert(f.x), src);
|
||||
r.y = __shfl(base_types::convertor<float, bf16>::convert(f.y), src);
|
||||
return base_types::convertor<bf16_2, float2>::convert(r);
|
||||
}
|
||||
template<>
|
||||
__device__ inline half packed_shfl(uint64_t mask, const half &f, int src) {
|
||||
float r = __shfl(base_types::convertor<float, half>::convert(f), src);
|
||||
return base_types::convertor<half, float>::convert(r);
|
||||
}
|
||||
template<>
|
||||
__device__ inline half_2 packed_shfl(uint64_t mask, const half_2 &f, int src) {
|
||||
float2 r;
|
||||
r.x = __shfl(base_types::convertor<float, half>::convert(f.x), src);
|
||||
r.y = __shfl(base_types::convertor<float, half>::convert(f.y), src);
|
||||
return base_types::convertor<half_2, float2>::convert(r);
|
||||
}
|
||||
template<>
|
||||
__device__ inline float2 packed_shfl<float2>(uint64_t mask, const float2 &f, int src) {
|
||||
float2 r;
|
||||
r.x = __shfl(f.x, src);
|
||||
r.y = __shfl(f.y, src);
|
||||
return r;
|
||||
}
|
||||
|
||||
using bytes_4 = HIP_vector_type<float, 1>;
|
||||
using bytes_8 = HIP_vector_type<float, 2>;
|
||||
using bytes_16 = HIP_vector_type<float, 4>;
|
||||
|
||||
/* ---------- SHARED MEMORY UTILS ---------- */
|
||||
|
||||
// namespace ducks {
|
||||
// namespace sb {
|
||||
// struct identifier {};
|
||||
// }
|
||||
// }
|
||||
|
||||
// template<typename Args...>
|
||||
// struct sb {
|
||||
// using identifier = ducks::sb::identifier;
|
||||
// Args... args;
|
||||
// };
|
||||
|
||||
// namespace ducks {
|
||||
// namespace sb {
|
||||
// template<typename T> concept all = requires {
|
||||
// typename T::identifier;
|
||||
// } && std::is_same_v<T::identifier, identifier>;
|
||||
// }
|
||||
// }
|
||||
|
||||
#define KITTENS_ALIGN_AS(n) alignas(n)
|
||||
#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16)
|
||||
|
||||
/**
|
||||
* @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls.
|
||||
*/
|
||||
struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; };
|
||||
|
||||
namespace detail {
|
||||
/// @brief 16B (`int4`) vector types tagged with the address spaces the gfx1250
|
||||
/// `*_load_async_to_lds_b128` builtins require (AS1 = global, AS3 = LDS).
|
||||
using i32x4_vec = int __attribute__((__vector_size__(16)));
|
||||
using i32x4_gvec = int __attribute__((__vector_size__(16))) __attribute__((address_space(1)));
|
||||
using i32x4_lvec = int __attribute__((__vector_size__(16))) __attribute__((address_space(3)));
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief Compile-time tag selecting an LDS segment for tile placement on gfx1250.
|
||||
*
|
||||
* Background. LDS and L1 share one 384 KB SRAM pool per Compute Unit (CU),
|
||||
* partitioned at dispatch into six 64 KB segments (see `MAX_SHARED_MEMORY`
|
||||
* above). Up to five segments (indices 0..4, total 320 KB) are addressable as
|
||||
* LDS scratchpad; at least one segment must remain L1. By convention we leave
|
||||
* segment 5 as L1, so LDS-tile placement uses indices 0..4.
|
||||
*
|
||||
* Why segments matter. The LDS half of the pool is fronted by two read ports
|
||||
* delivering 256 B/cycle each. The two ports can issue in the same cycle only
|
||||
* when they target **different** segments, so placing operand `A` in
|
||||
* `segment<0>` and operand `B` in `segment<1>` lets the hardware satisfy both
|
||||
* reads in parallel and reach the full 512 B/cycle peak. Co-locating `A` and
|
||||
* `B` in the same segment serialises them at 256 B/cycle.
|
||||
*
|
||||
* @tparam IDX 0..4 -- segment index. The allocator aligns the allocation start
|
||||
* to `IDX * 64 KB` so multiple tiles can share a single segment.
|
||||
*/
|
||||
template<int IDX>
|
||||
struct segment {
|
||||
static_assert(IDX >= 0 && IDX < SHARED_MEMORY_NUM_SEGMENTS,
|
||||
"segment index must be in [0, 5)");
|
||||
static constexpr int index = IDX;
|
||||
static constexpr int byte_offset = IDX * MAX_SHARED_MEMORY_PER_SEGMENT;
|
||||
};
|
||||
|
||||
namespace ducks {
|
||||
namespace segment_tag {
|
||||
template<typename T> struct is_segment : std::false_type {};
|
||||
template<int I> struct is_segment<::kittens::segment<I>> : std::true_type {};
|
||||
template<typename T> concept all = is_segment<T>::value;
|
||||
} // namespace segment_tag
|
||||
} // namespace ducks
|
||||
/**
|
||||
* @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments.
|
||||
*
|
||||
* Maintains a bump cursor `ptr` that advances on every `allocate*()` call. On
|
||||
* gfx1250 the allocator also remembers `base` -- the unmoved origin of the
|
||||
* shared-memory region captured at construction -- so segment-aware
|
||||
* allocations (`allocate_in<segment<IDX>>`) can jump to `base + IDX * 64 KB`
|
||||
* regardless of how far the bump cursor has already advanced.
|
||||
*
|
||||
* @tparam default_alignment The default alignment this allocator will enforce. If <=0 (default -1) it will not align.
|
||||
*/
|
||||
template<int default_alignment=16>
|
||||
struct shared_allocator {
|
||||
int *ptr; ///< Bump cursor; advances on every allocate*() call.
|
||||
int *base; ///< Frozen origin captured at construction; never moves.
|
||||
///< Reference point for `allocate_in<segment<IDX>>` segment starts.
|
||||
|
||||
private:
|
||||
// Recursive template to generate N-dimensional array type
|
||||
template<typename A, size_t... dims>
|
||||
struct variadic_array;
|
||||
template<typename A, size_t first_dim, size_t... rest_dims>
|
||||
struct variadic_array<A, first_dim, rest_dims...> {
|
||||
using type = typename variadic_array<A, rest_dims...>::type[first_dim];
|
||||
};
|
||||
template<typename A>
|
||||
struct variadic_array<A> {
|
||||
using type = A;
|
||||
};
|
||||
template<typename A, size_t... dims>
|
||||
using variadic_array_t = typename variadic_array<A, dims...>::type;
|
||||
|
||||
template<int alignment>
|
||||
__device__ inline void align_ptr() {
|
||||
if constexpr (alignment > 0) {
|
||||
uint64_t p = reinterpret_cast<uint64_t>(ptr);
|
||||
if(p % alignment != 0) {
|
||||
ptr = (int*)(p + (alignment-(p%alignment)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new shared allocator using a pointer to extern shared memory.
|
||||
*
|
||||
* `_ptr` is captured into the bump cursor `ptr`; on gfx1250 it is also
|
||||
* stashed into `base` so segment-aware allocations can recover the
|
||||
* original origin regardless of how far the cursor has advanced.
|
||||
*
|
||||
* @param[in] _ptr Pointer to the start of the extern shared memory.
|
||||
*/
|
||||
__device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) {}
|
||||
/**
|
||||
* @brief Allocate shared memory for a single instance or N-dimensional array of type A.
|
||||
* @tparam A The type of the object to allocate.
|
||||
* @tparam dims... A list of dimensions for the N-dimensional array.
|
||||
* @return Reference to the allocated object.
|
||||
*/
|
||||
template<typename A, size_t... dims>
|
||||
__device__ inline variadic_array_t<A, dims...>& allocate() {
|
||||
// static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation");
|
||||
align_ptr<default_alignment>();
|
||||
using at = variadic_array_t<A, dims...>;
|
||||
at*p = reinterpret_cast<at*>(ptr);
|
||||
ptr += sizeof(at)/sizeof(int);
|
||||
return *p;
|
||||
}
|
||||
/**
|
||||
* @brief Allocate shared memory for a single instance or N-dimensional array of type A.
|
||||
* @tparam alignment An alignment to enforce for this particular object.
|
||||
* @tparam A The type of the object to allocate.
|
||||
* @tparam dims... A list of dimensions for the N-dimensional array.
|
||||
* @return Reference to the allocated object.
|
||||
*/
|
||||
template<int alignment, typename A, size_t... dims>
__device__ inline variadic_array_t<A, dims...>& allocate() {
    // Same as the default-alignment overload, except the cursor is aligned to
    // the caller-supplied `alignment` before the object is carved out.
    // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation");
    align_ptr<alignment>();
    using at = variadic_array_t<A, dims...>;
    at *p = reinterpret_cast<at*>(ptr);
    // `ptr` is an int*, so the cursor moves in whole ints. Round the size up:
    // a truncating division under-advances the cursor whenever sizeof(at) is
    // not a multiple of sizeof(int), which lets the next allocation overlap
    // the tail of this one.
    ptr += (sizeof(at) + sizeof(int) - 1) / sizeof(int);
    return *p;
}
|
||||
|
||||
/**
|
||||
* @brief Allocate shared memory inside a specific LDS segment on gfx1250.
|
||||
*
|
||||
* Positions the allocator pointer at `base + IDX * 64KB` (where `base`
|
||||
* is the dynamic-shared-memory pointer this allocator was constructed
|
||||
* with), then allocates the requested type there. Multiple
|
||||
* `allocate_in<segment<IDX>>` calls into the same segment pack tightly.
|
||||
*
|
||||
* @tparam SEG A `kittens::segment<IDX>` tag.
|
||||
* @tparam A The type of the object to allocate.
|
||||
* @tparam dims Optional array dimensions.
|
||||
*/
|
||||
template<typename SEG, typename A, size_t... dims>
requires ducks::segment_tag::all<SEG>
__device__ inline variadic_array_t<A, dims...>& allocate_in() {
    // Start of the requested segment, measured from the origin captured at
    // construction (`base`) rather than from the moving cursor.
    int* target = base + (SEG::byte_offset / sizeof(int));
    // If we've already allocated past the requested segment, keep
    // packing where we are; otherwise jump forward to the segment.
    // NOTE(review): unlike allocate(), no align_ptr<>() is applied here, and
    // nothing checks that the object stays inside the requested segment --
    // confirm both are intended.
    if (ptr < target) ptr = target;
    using at = variadic_array_t<A, dims...>;
    at* p = reinterpret_cast<at*>(ptr);
    // `ptr` is an int*, so the cursor moves in whole ints. Round the size up:
    // a truncating division under-advances the cursor whenever sizeof(at) is
    // not a multiple of sizeof(int), which lets the next allocation overlap
    // the tail of this one.
    ptr += (sizeof(at) + sizeof(int) - 1) / sizeof(int);
    return *p;
}
|
||||
};
|
||||
|
||||
} // namespace kittens
|
||||
10
extra/thunder/amd/include/udna1/includes.cuh
Normal file
10
extra/thunder/amd/include/udna1/includes.cuh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Aggregate header for UDNA1 (gfx1250) ThunderKittens core include tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
32
extra/thunder/amd/include/udna1/ops/group/group.cuh
Normal file
32
extra/thunder/amd/include/udna1/ops/group/group.cuh
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of all group (multi-warp) operations defined by ThunderKittens
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../common/common.cuh"
|
||||
#include "../../types/types.cuh"
|
||||
#include "../warp/warp.cuh" // several group memory ops rely on underlying warp-scope ops
|
||||
|
||||
// A "warpgroup" is a special group of 4 consecutive warps defined by NVIDIA for certain SM_90+ operations.
|
||||
// Compile-time guard: expands to a static_assert that fails unless the `N_WARPS` in scope at the expansion site equals 4.
#define KITTENS_CHECK_WARPGROUP static_assert(N_WARPS==4, "PTX warpgroup (N_WARPS=4) function called from a non-warpgroup group.");
|
||||
|
||||
namespace kittens {
|
||||
/*
|
||||
This is meant to be used with a `using group_N = kittens::group<NUM_WORKERS>;` at the start of every kernel.
|
||||
*/
|
||||
template<int N_WARPS>
struct group {
    // Number of warps that make up one group.
    static constexpr int GROUP_WARPS = N_WARPS;
    // Number of threads that make up one group.
    static constexpr int GROUP_THREADS = GROUP_WARPS * kittens::WARP_THREADS;

    // Index of the calling thread within its group, derived from threadIdx.x.
    __device__ static inline int laneid() {
        return threadIdx.x % GROUP_THREADS;
    }

    // Index of the calling thread's warp within its group.
    __device__ static inline int warpid() {
        const int thread_in_group = laneid();
        return thread_in_group / kittens::WARP_THREADS;
    }

    // Index of the calling thread's group along threadIdx.x.
    __device__ static inline int groupid() {
        return threadIdx.x / GROUP_THREADS;
    }

    // The group-scope memory ops are textually included here so that they
    // become static members of this struct and can see GROUP_THREADS.
    #include "memory/memory.cuh"
};
|
||||
|
||||
using warpgroup = group<4>; // convenience alias for a group of four warps; special scope commonly used by SM_90 and later.
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of collaborative group memory movement operations
|
||||
*/
|
||||
|
||||
#include "tile/tile.cuh"
|
||||
#include "vec/vec.cuh"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory.
|
||||
*/
|
||||
template<int axis, bool assume_aligned, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
    // Group-scope store of a shared tile: forwards to kittens::store with the
    // caller's axis / alignment choice and this group's thread count as the
    // final template argument.
    constexpr int n_threads = GROUP_THREADS;
    kittens::store<axis, assume_aligned, ST, GL, COORD, n_threads>(dst, src, idx);
}
|
||||
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>> // default case
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
    // Default overload: forwards to kittens::store with axis fixed to 2 and
    // no alignment assumption, using this group's thread count.
    constexpr int default_axis = 2;
    constexpr bool default_assume_aligned = false;
    kittens::store<default_axis, default_assume_aligned, ST, GL, COORD, GROUP_THREADS>(dst, src, idx);
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of group memory operations on tiles.
|
||||
*/
|
||||
|
||||
#include "global_to_shared.cuh"
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
|
||||
*/
|
||||
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> // default case
__device__ static inline void store(const GL &dst, const SV &src, const COORD &idx) {
    // Group-scope store of a shared vector: forwards to kittens::store with
    // this group's thread count as the final template argument.
    constexpr int n_threads = GROUP_THREADS;
    kittens::store<SV, GL, COORD, n_threads>(dst, src, idx);
}
|
||||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of group memory operations on vectors.
|
||||
*/
|
||||
|
||||
#include "global_to_shared.cuh"
|
||||
9
extra/thunder/amd/include/udna1/ops/ops.cuh
Normal file
9
extra/thunder/amd/include/udna1/ops/ops.cuh
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief A collection of all of the operations that ThunderKittens defines.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp/warp.cuh"
|
||||
#include "group/group.cuh"
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue