Merge branch 'master' into x86_cleanups

This commit is contained in:
George Hotz 2026-06-19 08:45:16 -07:00 committed by GitHub
commit fc45e15afb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
226 changed files with 14227 additions and 909 deletions

View file

@ -99,7 +99,6 @@ jobs:
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
@ -134,32 +133,10 @@ jobs:
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
- name: Fuzz Padded Tensor Core GEMM
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run quantized LLaMA
run: |
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
- name: Run quantized LLaMA3
run: |
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run OLMoE
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
- name: Run llama3.2
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run olmoe
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
@ -235,9 +212,6 @@ jobs:
- name: Symlink models and datasets
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -279,36 +253,16 @@ jobs:
# TODO: too slow
# - name: Run SDXL
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run quantized LLaMA3
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
# - name: Run LLaMA-3 8B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-2 70B
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- uses: actions/upload-artifact@v7
with:
name: Speed (NVIDIA)
@ -402,10 +356,7 @@ jobs:
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -458,18 +409,10 @@ jobs:
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA 7B
run: |
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA 7B with BEAM
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-3 8B on 6 GPUs
@ -478,16 +421,6 @@ jobs:
# run: sudo modprobe amdgpu
# - name: Run LLaMA-2 70B
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run process replay tests
uses: ./.github/actions/process-replay
@ -594,10 +527,10 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testqualcommbenchmark:
name: comma Benchmark
testcommalatest:
name: comma Benchmark (0.11.0)
runs-on: [self-hosted, Linux, comma]
timeout-minutes: 30
timeout-minutes: 10
defaults:
run:
shell: bash -e -o pipefail {0}
@ -622,6 +555,27 @@ jobs:
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.11.0 dmonitoring
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay
testcommaold:
name: comma Benchmark (0.10.1)
runs-on: [self-hosted, Linux, comma]
timeout-minutes: 10
defaults:
run:
shell: bash -e -o pipefail {0}
if: github.repository_owner == 'tinygrad'
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.10.1 driving_vision
@ -630,14 +584,44 @@ jobs:
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.1 dmonitoring
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay
testqualcommdsp:
name: DSP Benchmark
runs-on: [self-hosted, Linux, comma4]
timeout-minutes: 5
defaults:
run:
shell: bash -e -o pipefail {0}
if: github.repository_owner == 'tinygrad'
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: benchmark MobileNetV2 on DSP
run: |
# generate quantized weights
ln -s /data/home/tiny/tinygrad/extra/datasets/imagenet extra/datasets/imagenet
ln -s /data/home/tiny/tinygrad/testsig-*.so .
PYTHONPATH=. CC=clang-19 DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
PYTHONPATH=. DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
# benchmark on DSP with NOOPT=1, the devectorizer has issues
PYTHONPATH=. CC=clang-19 DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
PYTHONPATH=. DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay

View file

@ -218,14 +218,15 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: unittest-13
pydeps: "pillow ftfy regex pre-commit"
deps: testing_unit
llvm: 'true'
amd: 'true'
- name: Run NULL backend tests
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
- name: Run targeted tests on NULL backend
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
run: |
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
# TODO: too slow
# - name: Run SDXL on NULL backend
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
@ -249,7 +250,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: unittest-13
pydeps: "pillow ftfy regex pre-commit"
pydeps: "pre-commit"
deps: testing_unit
llvm: 'true'
- name: Run pre-commit test hooks
@ -378,7 +379,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx

View file

@ -37,7 +37,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8:
if ASM_GEMM:
@ -58,24 +58,25 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
if x_fp8 is None:
if FUSED_INPUT_QUANTIZE and amax_x is not None:
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
x_fp8, x_scale, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
else:
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
x_fp8, _, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w.T):
assert amax_x is not None
if COLUMNWISE_WEIGHT_SCALE:
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
else:
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
return out, x_new_amax, x_fp8
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_x.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
if FUSED_ADD_NORM_MUL_QUANTIZE:
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_rmsnorm_mul_quantize_fp8
x_fp8, x_inv_scale, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
x_fp8, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
return out, x_normed, rrms, ret
x_normed, rrms = rmsnorm(x, eps)
out, *ret = matmul(x_normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale, grad_amax_state=grad_amax_state)
@ -85,8 +86,8 @@ def add_norm_quantize_matmul(x:Tensor, residual:Tensor, norm:Tensor, w:Tensor, w
grad_amax_state:Tensor|None=None):
if FUSED_ADD_NORM_MUL_QUANTIZE:
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_add_rmsnorm_mul_quantize_fp8
x_fp8, x_inv_scale, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
x_fp8, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
return out, h, x_normed, rrms, ret
h = x + residual
x_normed, rrms = rmsnorm(h, eps)
@ -98,8 +99,8 @@ def silu_w13_quantize_matmul(x_w13:Tensor, w2:Tensor, s_2:Tensor,
grad_amax_xw13:Tensor, grad_amax_xout:Tensor):
if FUSED_SILU_W13:
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
x2_fp8, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, amax_x=amax_x2, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
return out, ret
hidden = x_w13.shape[-1] // 2
x_w1, x_w3 = x_w13[..., :hidden], x_w13[..., hidden:]
@ -125,8 +126,10 @@ class FlatTransformer:
# FeedForward
if SPLIT_W13:
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@ -154,9 +157,10 @@ class FlatTransformer:
self._fp8_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
self._fp8_next_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))

View file

@ -1,5 +1,5 @@
from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType, dtypes
from tinygrad.dtype import DType, dtypes, AddrSpace
from tinygrad.tensor import Tensor
from tinygrad.device import Device, Buffer
from tinygrad.engine.jit import TinyJit
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
info = prg.arg
functions[info.function_name] = prg.src[3].arg
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
statements.append((info.function_name, cargs, info.global_size, info.local_size))
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
symbolic_vars[var] = var.expr
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]
if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
global_size[j] = f"_{name.expr}[0] + {val.arg}"
prg = ""
if target == "clang":

View file

@ -2630,7 +2630,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
@functools.cache
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
scales, extra = args[:n_scales], args[n_scales:]
M, K = A.shape[0]*A.shape[1], A.shape[2]
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
@ -2752,7 +2752,8 @@ def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
threads = UOp.special(64 * num_warps, "lidx0")
workgroups = UOp.special((M // block_m) * (N // block_n), "gidx0")
sink = UOp.sink(C.base, A.base, B.base, threads, workgroups,
b_extra = args[0].base if len(args) >= 1 else B.base
sink = UOp.sink(C.base, A.base, B.base, b_extra, threads, workgroups,
arg=KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize)))
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
src = (kittens_path/"gemm_bf16.cpp").read_text()
@ -2787,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
is_multi = isinstance(a.device, tuple)
reduce_out = False
if is_multi:
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
ndev = len(a.device)
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
out = Tensor(inv.uop.multi(out_axis), device=a.device)
dname = a.device[0]
else:
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
dname = a.device
dname = dname.split(":")[0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
if is_multi: out = out.sum(0)
if reduce_out: out = out.sum(0)
return out.squeeze(0) if out.ndim == 3 else out
@ -2807,13 +2814,15 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
out, a, b = inputs[:3]
i = 3
s_x = inputs[i]; i += 1
has_w = n_scales == 2
has_w = n_scales >= 2
s_w = inputs[i] if has_w else None; i += has_w
s_g = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
w_post = inputs[i] if has_w_post else None
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
s_x_t = Tensor(s_x, device=a.device)
s_w_t = Tensor(s_w, device=a.device) if has_w else None
s_g_t = Tensor(s_g, device=a.device) if s_g is not None else None
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
g_t = g_t[:a.shape[0]]
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
@ -2838,7 +2847,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
store_effect = grad_amax_state.store(new_grad_amax.uop)
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
if s_g_t is not None: g_scale = g_scale * s_g_t
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
# wgrad: no w_scale
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
@ -2846,7 +2856,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
# wgrad: rescale if not scalar
if w_post_t is not None:
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
@ -2899,7 +2909,7 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=F
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False, g_scale:Tensor|None=None) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
@ -2946,8 +2956,8 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
elif a.dtype == FP8_DTYPE:
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)

View file

@ -1,145 +0,0 @@
from __future__ import annotations
import time
from typing import cast
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
from tinygrad.dtype import dtypes
from tinygrad.engine.jit import GraphRunner
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
from tinygrad.helpers import round_up, ceildiv
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
# **************** insert deps ****************
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
src = []
for j, call in enumerate(linear.src):
call = call.replace(tag=j)
_, _, bufs, _ = ctx.calls[j]
outs, ins = get_call_outs_ins(call)
deps = ctx._access_resources([bufs[i] for i in outs + ins], list(range(len(outs))), call)
src.append(UOp(Ops.AFTER, call.dtype, (call, *deps), tag=call.tag))
return linear.replace(src=tuple(src))
pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
pm_replace_params = PatternMatcher([
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
])
# **************** graph-only passes ****************
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
if q.arg in ctx.queue_sigs: return None
dev = q.arg[0][0] # TODO: multi device
buf = Buffer(dev, 0x100, dtypes.uint8, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
ctx.queue_sig_bufs.append(buf)
ctx.queue_sigs[q.arg] = UOp.from_buffer(buf, dev)
return None
pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
wrapper, deps, call_idx = after.src[0], after.src[1:], after.tag
def store(q_arg, v): return ctx.queue_sigs[q_arg].store(UOp.const(dtypes.uint32, v))
waits = tuple(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep.src[0].arg], UOp.const(dtypes.uint32, dep.tag),
store(dep.src[0].arg, dep.tag))) for dep in deps)
return wrapper.replace(src=tuple(q.replace(src=(*waits, *q.src, store(q.arg, call_idx))) for q in wrapper.src))
pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
src, seen, pending, queue_sig = [], {}, {}, ctx.queue_sigs[queue.arg]
for x in queue.src:
if x.op is Ops.WAIT:
sig, val = x.src[0], x.src[1]
if sig is queue_sig or seen.get(sig, -1) >= val.arg: continue
if (old:=pending.get(sig)) is None or old.src[1].arg < val.arg: pending[sig] = x
continue
for wait in pending.values():
src.append(wait)
seen[wait.src[0]] = wait.src[1].arg
pending.clear()
src.append(x)
src += pending.values()
return queue.replace(src=tuple(src)) if tuple(src) != queue.src else None
pm_optimize_queue_deps = PatternMatcher([
(UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
])
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
live = {u.src[2] for u in outer.toposort() if u.op is Ops.WAIT}
return outer.replace(src=tuple(q.replace(src=tuple(x for x in q.src if x.op is not Ops.STORE or x in live)) for q in outer.src))
pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
if not ctx.queue_sig_bufs or cmdbuf.tag not in ("compute", "copy"): return None
resets = tuple((b:=UOp.from_buffer(sig)).index(UOp.const(dtypes.int, 0), dtype=b.dtype.ptr())
.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0)) for sig in ctx.queue_sig_bufs)
return x.replace(src=x.src + resets)
pm_add_queue_sig_resets = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"),
add_queue_sig_resets)])
# **************** Graph ****************
class HCQ2Graph(GraphRunner):
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
super().__init__(linear, input_uops)
self.dev = cast(HCQ2Compiled, Device[self.device])
self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")
self.input_addrs = Buffer("CPU", max(len(input_uops), 1), dtypes.uint64, preallocate=True)
self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")
self.linear = graph_rewrite(self.linear, pm_insert_deps, ctx=self, name="hcq: insert deps", walk=True)
self.linear = graph_rewrite(self.linear, pm_replace_params, ctx=self, name="hcq: replace params", walk=True)
self.linear = graph_rewrite(self.linear, pm_prep_runtime, ctx=self.hcq_ctx, name="hcq: prepare runtime")
self.linear = graph_rewrite(self.linear, pm_lower_ops, ctx=self.hcq_ctx, name="hcq: lower ops")
# per-queue signal state — populated as a side-effect by pm_alloc_queue_sigs walking the lowered linear.
self.queue_sig_bufs:list[Buffer] = []
self.queue_sigs:dict[tuple[str, str], UOp] = {}
graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)
self.linear = graph_rewrite(self.linear, pm_lower_queue_deps, ctx=self, name="hcq: lower queue deps")
self.linear = graph_rewrite(self.linear, pm_split_into_queues, ctx=self.hcq_ctx, name="hcq: split into queues")
self.linear = graph_rewrite(self.linear, pm_add_barriers, ctx=self.hcq_ctx, name="hcq: add barriers", walk=True)
self.linear = graph_rewrite(self.linear, pm_optimize_queue_deps, ctx=self, name="hcq: optimize queue deps", walk=True)
self.linear = graph_rewrite(self.linear, pm_drop_dead_stores, ctx=self, name="hcq: drop dead stores")
self.linear = graph_rewrite(self.linear, pm_add_signals, ctx=self.hcq_ctx, name="hcq: add signals", walk=True)
self.linear = graph_rewrite(self.linear, pm_add_timeline_inc, ctx=self.hcq_ctx, name="hcq: add submit", walk=True)
self.linear = graph_rewrite(self.linear, self.dev.pm_lower, ctx=self.hcq_ctx, name=f"hcq: encode cmdbuf {self.dev.device}", walk=True)
graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=(sizes:={}), name=None)
for dev_name, sz in sizes.items():
buf = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
self.hcq_ctx.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(buf, dev_name), UOp.const(dtypes.uint64, buf._buf.va_addr))
self.linear = graph_rewrite(self.linear, pm_bufferize, ctx=self.hcq_ctx, bottom_up=True, name="realize binaries")
self.linear = graph_rewrite(self.linear, pm_lift_patches_to_cmdbuf, ctx=self.hcq_ctx, bottom_up=False, name="lift patches to cmdbuf")
self.linear = graph_rewrite(self.linear, pm_resolve_patches, ctx=self.hcq_ctx, bottom_up=False, name="simplify patches")
self.linear = graph_rewrite(self.linear, pm_add_queue_sig_resets, ctx=self, name="hcq: add queue sig resets", walk=True)
self.linear = graph_rewrite(self.linear, pm_parametrize_host_buffers, ctx=self.hcq_ctx, bottom_up=True, name="parametrize host buffers")
self.host_call = graph_rewrite(self.linear, pm_callify, ctx=self.hcq_ctx, name="hcq: callify")
self.host_rt, self.host_globals = get_runtime("CPU", self.host_call.src[0]), self.host_call.src[0].arg.globals
def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
addrs = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
for i, u in enumerate(input_uops):
buf = next(b for b in u.buffer.bufs if b.device == self.dev.device) if isinstance(u.buffer, MultiBuffer) else u.buffer
addrs[i] = buf._buf.va_addr
self.host_rt(*[self.hcq_ctx.inputs[i].get_buf("CPU") for i in self.host_globals], vals=self.host_call.src[0].arg.vals(var_vals), wait=True)
if wait:
st = time.perf_counter()
self.dev.synchronize()
return time.perf_counter() - st
return None
@staticmethod
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
all_devs = GraphRunner._all_devs(batch_devs, new_call)
return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(all_devs) == 1 and isinstance(all_devs[0], HCQ2Compiled)

View file

@ -1,14 +1,13 @@
from __future__ import annotations
from typing import cast, Callable, TypeVar, Generic, Any
import struct, functools, time, collections, importlib, itertools, weakref
from dataclasses import replace
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
from dataclasses import replace, dataclass, field
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
from tinygrad.helpers import to_tuple, round_up
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
from tinygrad.uop.symbolic import symbolic_simple, symbolic
from tinygrad.dtype import dtypes, DType
from dataclasses import dataclass, field
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
@ -28,7 +27,7 @@ class HCQ2Compiled(Compiled):
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
])
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
@ -137,9 +136,9 @@ def unwrap_after(uop):
while uop.op is Ops.AFTER: uop = uop.src[0]
return uop
def make_getaddr(u, dev=None):
def make_getaddr(u, device=None):
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=device or to_tuple(u.device)[0])))
def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
@ -149,7 +148,7 @@ def make_cmdbuf(lin, devs, tag):
for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
@ -172,6 +171,10 @@ class HCQInfo:
name:str = ""
estimates:Estimates = Estimates()
outs:tuple[int, ...] = ()
devs:tuple[str, ...] = ()
params:tuple[int, ...] = ()
inputs:int|None = None
@staticmethod
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
@ -269,6 +272,14 @@ class DepsCtx:
deps:DepsTracker = field(default_factory=DepsTracker)
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
# TODO: should this be a part of DepsTracker?
if u.op is Ops.PARAM: return ctx.params.setdefault((u.arg.slot, lane), Buffer("NULL", u.max_numel(), u.dtype.base))
if u.op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
if u.op in (Ops.SLICE, Ops.MSELECT): return get_dep_buf(ctx, u.src[0], u.arg if u.op is Ops.MSELECT else lane)
return b.bufs[lane] if isinstance(b:=u.buffer, MultiBuffer) else b
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src = []
@ -277,15 +288,25 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src.append(call)
continue
q = get_submit(call.src[0]).src[0]
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
deps = []
refs = [b.buffer for b in get_call_arg_uops(call)]
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
# per-lane deps, tracked per (device, queue). skip self
dep_lanes:list[tuple[UOp, int]] = []
for lane, d in enumerate(qdevs):
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
# drop self-queue waits, queue self-orders
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
# keep latest dep per lane, group lanes
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -298,11 +319,16 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
zero = UOp.const(dtypes.int, 0)
tl = make_signal_value(devs)
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
# queue is inc with deps
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
parts:dict[str, list[UOp]] = collections.defaultdict(list)
@ -319,21 +345,22 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
waits = []
for dep in deps.src[1:]:
devs, queue = dep.arg
for lanes, dep in zip(deps.arg, deps.src[1:]):
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
ctx.add(dep.tag) # mark op to update signal.
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
lanes = set(lanes)
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
waits.append(sig.wait(val + dep.tag))
return cur.replace(src=tuple(waits) + cur.src)
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
src = q.src
if q.tag in ctx:
devs, queue = q.arg
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
if q.tag not in ctx: return None
devs, queue = q.arg
src = q.src + (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
return submit.replace(src=(q.replace(src=src, tag=None),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
@ -342,34 +369,36 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
return new_sink.substitute({new_submit: new_anchor})
def merge_sink(sinks:list[UOp]) -> UOp:
if len(sinks) == 1: return sinks[0]
submits = [get_submit(sink) for sink in sinks]
queues = [submit.src[0] for submit in submits]
anchor = submits[-1].replace(src=(queues[-1].replace(src=tuple(x for q in queues for x in q.src)),))
for sink, submit in zip(sinks[:-1], submits[:-1]):
if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
return sinks[-1].substitute({submits[-1]: anchor}, walk=True)
def merge_queues(linear:UOp) -> UOp:
new_src:list[UOp] = []
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
for call in linear.src:
if call.tag != "hcq":
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
continue
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
aux = call.arg.aux
new_rec = ([new_sink], call.arg.aux)
if (old:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
else:
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = (new_sink, aux)
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
opened_qs[(devs, queue)] = new_rec
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
# *****************
@ -384,6 +413,29 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
# *****************
# 4.3. annotate exec devs
pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
lambda call: call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=get_submit(call.src[0]).src[0].arg[0]))))])
# *****************
# 4.4. replace params with per-submit input address loads
def replace_params(call:UOp) -> UOp|None:
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
# fill new info
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
inputs = UOp.new_buffer(get_submit(call.src[0]).src[0].arg[0], len(hcqinfo.params), dtypes.uint64).rtag("inputs")
slot2idx = {s:i for i,s in enumerate(hcqinfo.params)}
body = call.src[0].substitute({u:inputs.index(UOp.const(dtypes.int, slot2idx[s])).load() for u,s in params.items()})
return call.replace(src=(body, *call.src[1:], inputs), arg=replace(call.arg, aux=hcqinfo))
pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
# *****************
# 5.1. encode cmdbufs
@ -411,7 +463,28 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
])
# *****************
# 5.3. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
# 5.3. pack placeholders buffers
def pack_hcq_placeholders(call:UOp) -> UOp|None:
bufs = [b for b in call.src[0].toposort() if b.op is Ops.BUFFER and b.tag in (maxtags:={"scratch"}) | (sumtags:={"program", "kernargs"})]
off_per_buf:dict[UOp, int] = {}
size_per_tag:dict[str, int] = {}
for b in bufs:
if b.tag in maxtags: size_per_tag[b.tag] = max(size_per_tag.get(b.tag, 0), b.arg)
elif b.tag in sumtags:
off_per_buf[b] = round_up(size_per_tag.get(b.tag, 0), {"program": 0x1000}.get(b.tag, 128))
size_per_tag[b.tag] = off_per_buf[b] + b.arg
count_per_tag = collections.Counter(b.tag for b in bufs)
ref_bufs = {b.tag:b for b in bufs if count_per_tag[b.tag] > 1}
bases = {tag:UOp.new_buffer(b.src[1].arg, size_per_tag[tag], b.dtype).rtag(tag) for tag,b in ref_bufs.items()}
subs = {b:UOp(Ops.SLICE, b.dtype, (bases[b.tag], UOp.const(dtypes.weakint, off_per_buf.get(b, 0))), b.arg) for b in bufs if b.tag in bases}
return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:])) if subs else None
pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
# *****************
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
def hold_call_buffers(call:UOp) -> UOp|None:
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
@ -423,7 +496,7 @@ pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
def bufferize_buf(buf:UOp) -> UOp|None:
if buf.tag is None: return None
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), "CPU") for dev in to_tuple(buf.src[1].arg))
return make_mstack(uops)
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
@ -434,31 +507,42 @@ def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
for b in (buf.src if buf.op is Ops.MSTACK else (buf,)): b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
for b in (mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)): b.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
return UOp(Ops.NOOP)
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
for b, v in zip((buf.src if buf.op is Ops.MSTACK else (buf,)), (val.src if val.op is Ops.STACK else (val,))):
struct.pack_into(f'<{v.dtype.fmt}', b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * b.dtype.base.itemsize, v.arg)
for b, v in zip((bs:=mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)), val.src if val.op is Ops.STACK else (val,)*len(bs)):
struct.pack_into(f'<{v.dtype.fmt}', b.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * buf.dtype.base.itemsize, v.arg)
return UOp(Ops.NOOP)
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
devs, b = to_tuple(g.src[1].arg), buf.buffer
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
return UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
pm_resolve_patches = PatternMatcher([
# multi
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# index on slice is index
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
# getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
# folders
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
fold_const_store),
]) + symbolic_simple
@ -466,19 +550,25 @@ pm_resolve_patches = PatternMatcher([
# 8. callify hcq programs
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
bufs.append(ref)
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
if ref not in bufs: bufs.append(ref)
return UOp.placeholder((ref.buffer.size,), ref.dtype, bufs.index(ref))
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
def parametrize_host_buffers(call:UOp) -> UOp:
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
# preserve original order of args
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=list(get_call_arg_uops(call))), bottom_up=True, name="parametrize host buffers")
# move vars to new slots
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
return call.replace(src=(body, *bufs) + tuple(x for x in call.src[1:] if x.op is Ops.BIND))
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
def callify_hcq(call:UOp) -> UOp:
prg = to_program(call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1), Device["CPU"].renderer)
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
def hcq_schedule(linear:UOp) -> UOp:
@ -492,14 +582,17 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_annotate_devs, name="annotate devs")
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
linear = graph_rewrite(linear, pm_pack_placeholders, walk=True, name="pack placeholders")
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
# realize starts from here
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
return linear

View file

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import cast
from typing import cast, Any, Callable
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
@ -89,25 +89,25 @@ def memory_barrier(ctx):
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
acquire_mem(ctx)))
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
def pm4_barrier(ctx): return memory_barrier(ctx)
def pm4_store(ctx, dst, val):
if val.op is Ops.BINARY: return None
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
def pm4_timestamp(ctx, dst):
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
ctx.pm4.int_sel__mec_release_mem__none)
def pm4_program(ctx, prg):
data, info = prg.arg
lib_gpu, args = prg.src
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
args_addr = make_getaddr(args, ctx.device)
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
args_addr = make_getaddr(args, ctx.devs)
user_regs = []
if data.enable_private_segment_sgpr:
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
def sdma_copy(ctx, dst, src, copy):
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
def sdma_wait(ctx, dst, val):
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
def sdma_store(ctx, dst, val):
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
return UOp(Ops.LINEAR, dtypes.void, (
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
def sdma_timestamp(ctx, dst):
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
pm_sdma_opsel = PatternMatcher([
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
@dataclass(frozen=True)
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
def encode_queue(q:UOp) -> UOp|None:
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
devs = to_tuple(q.arg[0])
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
pm_lower = PatternMatcher([
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),

View file

@ -59,8 +59,8 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
return (None, None, grad_xw13_uop, None, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, new_amax)
# grad_amax_state: delayed amax for grad_xw13 fp8 quantization in the backward.
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
MBS, SEQ, H2 = xw13.shape
@ -72,5 +72,4 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname_of(xw13.device))
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, grad_amax_state,
fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf)
return fp8_out, scalar_amax(amax_buf)

View file

@ -112,8 +112,8 @@ def _fused_add_bwd(*args, **kwargs):
grad_h, grad_w = _bwd_common(fp8_grad_u, h_grad_u, x_u, x_normed_u, rrms_u, weight_u, amax_state_u, kernel)
return (None, None, None, None, None, grad_h, grad_h, grad_w, None)
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, inv_scale, new_amax, x_normed, rrms).
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor]:
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, new_amax, x_normed, rrms).
# x_normed + rrms are saved for the rmsnorm backward (also recomputed here from x regs).
assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
@ -127,13 +127,12 @@ def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, e
fxn = functools.partial(_custom_fwd, dname=dname_of(x.device), eps_val=eps)
fp8_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
fp8_out, x_normed_out, rrms_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_bwd)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf), x_normed_out, rrms_out
return fp8_out, scalar_amax(amax_buf), x_normed_out, rrms_out
def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor, amax_state:Tensor,
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# NOTE: h = x + residual; y_normed = rmsnorm(h); fp8 = quantize(y_normed * weight).
# Returns (fp8, inv_scale, new_amax, h, x_normed, rrms). h is also written so downstream can
# Returns (fp8, new_amax, h, x_normed, rrms). h is also written so downstream can
# reuse it without recomputing x+residual — eliminates the separate residual-add kernel.
assert x.dtype == dtypes.bfloat16 and residual.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
assert x.shape == residual.shape
@ -149,5 +148,4 @@ def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor,
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, x, residual, weight, amax_state,
fxn=fxn, grad_fxn=_fused_add_bwd)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
return fp8_out, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out

View file

@ -0,0 +1,71 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.helpers import prod
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
@functools.cache
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
rows, K = x.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x = x.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
x_f = x[idx].cast(dtypes.float)
abs_x = (x_f < 0.0).where(-x_f, x_f)
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x = kernel.src[1:]
device = x.device
rows, K = x.shape
scale_K = K // BLK
e8 = Tensor(e8_out, device=device).reshape(rows, scale_K)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, BLK).reshape(rows, K)
grad_x = (Tensor(gradient, device=device).float() * qscale).cast(dtypes.bfloat16)
return (None, None, None, grad_x.uop)
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x.shape
scale_K = K // BLK
axis = x.uop.axis if isinstance(x.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
return fp8_out, e8_out, si_out

View file

@ -255,7 +255,7 @@ def _sign(serial_num, out_dir):
final = _build_elf(ehdr, allph, segs)
os.makedirs(out_dir, exist_ok=True)
out = os.path.join(out_dir, "testsig-0x%08X.so" % serial_num)
out = os.path.join(out_dir, "testsig-0x%08x.so" % serial_num)
with open(out, 'wb') as f:
f.write(final)
print("Signing complete! Output saved at %s" % out)

View file

@ -26,7 +26,7 @@ constexpr int NUM_THREADS = WARP_THREADS * NUM_WARPS;
using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr) {
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr, bf16 *b_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
static_assert(M % BLOCK_SIZE == 0 && N % BLOCK_SIZE == 0 && K % K_STEP == 0);

View file

@ -99,12 +99,14 @@ using G = kittens::group<NUM_WARPS>;
#endif
__global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr
#if SCALE_MODE == 1
#if SCALE_MODE & 1
, float *x_scale_ptr
#elif SCALE_MODE == 2
#endif
#if SCALE_MODE & 2
, float *w_scale_ptr
#elif SCALE_MODE == 3
, float *x_scale_ptr, float *w_scale_ptr
#endif
#if SCALE_MODE & 4
, float *g_scale_ptr
#endif
) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
@ -346,20 +348,19 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
}
// apply x_scale * w_scale before bf16 store to prevent overflow
#if SCALE_MODE == 1
float scale = *x_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 2
float scale = *w_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 3
float scale = *x_scale_ptr * *w_scale_ptr;
#if SCALE_MODE != 0
float scale = 1.0f;
#if SCALE_MODE & 1
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
scale *= x_scale;
#endif
#if SCALE_MODE & 2
scale *= *w_scale_ptr;
#endif
#if SCALE_MODE & 4
scale *= *g_scale_ptr;
#endif
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);

View file

@ -26,9 +26,7 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter,
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
const uint8_t *__restrict__ a_e8_unused,
const uint8_t *__restrict__ b_e8_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
@ -40,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
constexpr int k_iters = K / BLOCK_K;
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
__shared__ ST_A As[2][2];
__shared__ ST_B Bs[2][2];
__shared__ uint8_t smem_scales[2048];
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
RT_A a;
RT_B b0, b1;
@ -81,54 +83,84 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
const int a_row_stride = A.template stride<2>() * sizeof(T);
const int b_row_stride = B.template stride<2>() * sizeof(T);
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
const int wid = warpid() % NUM_WARPS;
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
int a_row_h0 = warp_m * REG_M;
int a_row_h1 = HALF_ROW + warp_m * REG_M;
int b_row_h0 = warp_n * REG_N;
int b_row_h1 = HALF_COL + warp_n * REG_N;
int tic = 0, toc = 1;
int tic_scales = 0, toc_scales = 1;
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
if (warp_m == 1) __builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt vmcnt(4)");
__builtin_amdgcn_s_barrier();
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
asm volatile("s_waitcnt lgkmcnt(8)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
@ -137,10 +169,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
@ -150,7 +181,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
@ -162,19 +193,23 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
if (k + 1 < k_iters) {
G::load(scale_A_smem[toc_scales], scale_A_gl, {(k + 1) * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {(k + 1) * tiles_N + block_col, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
@ -217,17 +252,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_sched_barrier(0);
tic ^= 1; toc ^= 1;
tic_scales ^= 1; toc_scales ^= 1;
}
{ // Epilogue k = k_iters - 1 (final)
int k = k_iters - 1;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
@ -263,6 +298,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier();
}
} else {
#pragma unroll 1
for (int kk = 0; kk < k_iters; kk++) {
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_barrier();
}
}
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});

View file

@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
return result;
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) {
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
} else {
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
template<typename T = u32x3>
__device__ __forceinline__ T buffer_load_dwordx3(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
T result;
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) {
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ubyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sbyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ushort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sshort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
// of the destination register, leaving the other half untouched: read-modify-write.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR>
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<typename T = u32x3>
__device__ __forceinline__ void buffer_store_dwordx3(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
: "memory");
}
// Stores the low 8 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits' low 8 bits (bits [23:16]) of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the low 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR>
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR >= 256) {

View file

@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));

View file

@ -0,0 +1,10 @@
/**
* @file
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
*/
#pragma once
#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"

View file

@ -248,7 +248,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
"ds_read_b64_tr_b16 %0, %2 offset:%3\n"
"ds_read_b64_tr_b16 %1, %2 offset:%4\n"
// "s_waitcnt lgkmcnt(0)\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
: "v"(addr), "i"(offset), "i"(offset + 4 * ST::underlying_subtile_row_bytes)
: "memory"
@ -259,7 +259,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
"ds_read_b64_tr_b16 %0, %2 offset:%4\n"
"ds_read_b64_tr_b16 %1, %3 offset:%4\n"
// "s_waitcnt lgkmcnt(0)\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
: "v"(addr), "v"(next_addr), "i"(offset)
: "memory"
@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
}
}
template<ducks::rt::col_layout RT, ducks::st::all ST>
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
static_assert(RT::cols == ST::rows,
"col_layout load with col_offset: ST.rows must equal RT::cols");
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
using T2 = RT::dtype;
using U = ST::dtype;
using U2 = base_types::packing<U>::packed_type;
const int laneid = kittens::laneid();
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
const int block_id = laneid / 16;
const int l_within = laneid % 16;
const int tr_k_grp = l_within / 2;
const int m_half = l_within & 1;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
const int subtile_base = block_id;
uint32_t addr = src_ptr
+ (uint32_t)(subtile_base * subtile_bytes)
+ src.swizzle({tr_k_grp, m_col});
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
addr ^= 1088u;
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
}
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
const int col_offset_lane = (laneid % 4) * 4;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
if constexpr (RT::base_tile_stride == 8) {
asm volatile(
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
: "v"(addr), "v"(next_addr)
: "memory"
);
} else if constexpr (RT::base_tile_stride == 4) {
asm volatile(
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
: "v"(addr)
: "memory"
);
} else {
static_assert(false, "Unsupported stride for col_layout load with col_offset");
}
}
} else {
static_assert(false, "Unsupported type for col_layout load with col_offset");
}
}
} // namespace kittens

View file

@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
*/
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
}
}
/**
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
*
* @param smem_scales LDS pointer to scale tile data.
* @param row_offset Starting row within the scale region (warp's tile offset).
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
*/
__device__ __forceinline__ fp8e8m0_4 pack_scales(
const fp8e8m0 *smem_scales, int row_offset) {
int lid = laneid();
int r16 = lid % 16;
int k_sub = lid / 16;
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
return lo | (hi << 16);
}
/**
* @brief Base dot product operation for row layout.
*
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
*/
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
const fp8e8m0_4 *scale_a,
const fp8e8m0_4 *scale_b) {
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
*/
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt(D &d,
const A &a,
const B &b, // notice row and (M, K) instead of col and (K, M)
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
*/
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt_scaled(D &d,
const A &a,
const B &b,

View file

@ -9,6 +9,5 @@
#include "maps.cuh"
#include "reductions.cuh"
#include "mma.cuh"
#include "scales.cuh"
#include "assembly/tile.cuh"

View file

@ -41,7 +41,7 @@ struct st_16x16 {
} else {
static_assert(false, "Unsupported type");
}
}
};
@ -205,6 +205,23 @@ struct st_8x32 {
}
};
struct st_16x64 {
static constexpr int rows = 16;
static constexpr int cols = 64;
template<typename _T>
static constexpr int bytes_per_thread() {
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
return 16;
}
template<typename _T>
__device__ __forceinline__ static const uint32_t swizzle (int2 coord) {
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
return sizeof(_T) * (coord.x * cols + coord.y);
}
};
struct st_16x128 {
static constexpr int rows = 16;
static constexpr int cols = 128;
@ -236,12 +253,13 @@ struct st_16x128 {
};
template<typename T>
concept all = std::is_same_v<T, st_16x16> ||
std::is_same_v<T, st_16x16_swizzled> ||
std::is_same_v<T, st_32x32> ||
std::is_same_v<T, st_16x32> ||
std::is_same_v<T, st_32x16> ||
concept all = std::is_same_v<T, st_16x16> ||
std::is_same_v<T, st_16x16_swizzled> ||
std::is_same_v<T, st_32x32> ||
std::is_same_v<T, st_16x32> ||
std::is_same_v<T, st_32x16> ||
std::is_same_v<T, st_8x32> ||
std::is_same_v<T, st_16x64> ||
std::is_same_v<T, st_16x128>;

View file

@ -75,6 +75,7 @@ using st_32x32_s = ducks::st_shape::st_32x32;
using st_16x32_s = ducks::st_shape::st_16x32;
using st_32x16_s = ducks::st_shape::st_32x16;
using st_8x32_s = ducks::st_shape::st_8x32;
using st_16x64_s = ducks::st_shape::st_16x64;
using st_16x128_s = ducks::st_shape::st_16x128;
}

View file

@ -5,9 +5,12 @@
#pragma once
#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"
#if defined(KITTENS_CDNA4)
#include "cdna4/includes.cuh"
#elif defined(KITTENS_UDNA1)
#include "udna1/includes.cuh"
#endif
#include "pyutils/util.cuh"

View file

@ -1,74 +0,0 @@
/**
* @file
* @brief MXFP8 block scale loading and packing utilities.
*
* Provides functions for staging E8M0 block scales in LDS and packing them
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
*/
#pragma once
#include "../../../../common/common.cuh"
namespace kittens {
/**
* @brief Load iteration-major packed E8M0 scales from global memory into LDS.
*
* First 256 threads each load one uint32 (4 packed E8M0 bytes) for A and B.
* A scales are placed at smem[0..1023], B scales at smem[1024..2047].
*
* @param smem_scales LDS buffer, must be >= 2048 bytes.
* @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
* @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
* @param block_m Starting row offset for A within the current block.
* @param block_n Starting row offset for B within the current block.
* @param k_iter Current K iteration index.
* @param M_dim M dimension of the matrix.
* @param N_dim N dimension of the matrix.
*/
__device__ __forceinline__ void load_scales_to_lds(
uint8_t *smem_scales,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter,
int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
int tid = threadIdx.x;
if (tid < 256) {
uint32_t sa = scale_A_iter[k_iter * M_dim + block_m + tid];
uint32_t sb = scale_B_iter[k_iter * N_dim + block_n + tid];
*(uint32_t *)&smem_scales[tid * 4] = sa;
*(uint32_t *)&smem_scales[1024 + tid * 4] = sb;
}
}
/**
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
*
* Each lane (r16 = laneid%16, k_sub = laneid/16) loads 4 dwords from
* consecutive 16-row groups, then uses v_perm_b32 to extract byte k_sub
* from each, producing the packed scale register for scaled MFMA.
*
* @param smem_scales LDS pointer to scale region.
* @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
* @param row_offset Starting row within the scale region (warp's tile offset).
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
*/
__device__ __forceinline__ fp8e8m0_4 pack_scales(
const uint8_t *smem_scales, int lds_base, int row_offset) {
int lid = laneid();
int r16 = lid % 16;
int k_sub = lid / 16;
const uint32_t *s32 = (const uint32_t *)(smem_scales + lds_base);
uint32_t w0 = s32[row_offset + 0 * 16 + r16];
uint32_t w1 = s32[row_offset + 1 * 16 + r16];
uint32_t w2 = s32[row_offset + 2 * 16 + r16];
uint32_t w3 = s32[row_offset + 3 * 16 + r16];
uint32_t sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
uint32_t lo = __builtin_amdgcn_perm(w0, w1, sel);
uint32_t hi = __builtin_amdgcn_perm(w2, w3, sel);
return (fp8e8m0_4)(lo | (hi << 16));
}
} // namespace kittens

View file

@ -1,6 +1,5 @@
#pragma once
#include "../ops/ops.cuh"
#include <iostream>
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)

View file

@ -0,0 +1,449 @@
/**
* @file
* @brief Basic operations on generic types.
*/
#pragma once
#include <limits>
#include "base_types.cuh"
namespace kittens {
/**
* @namespace base_ops
*
* @brief A namespace for operations on basic data types.
*/
namespace base_ops {
/* ---------- CONST OPS ---------- */
/**
* @brief Represents the zero constant operation.
*
* This operation returns the zero value of the specified type.
*
* @tparam T The data type for which to return the zero value.
* @return The zero value of type T.
*/
struct zero {
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::zero(); }
};
/**
* @brief Represents the ones constant operation.
*
* This operation returns the ones value of the specified type.
*
* @tparam T The data type for which to return the ones value.
* @return The ones value of type T.
*/
struct ones {
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::ones(); }
};
/**
* @brief Represents the positive infinity constant operation.
*
* This operation returns the positive infinity value of the specified type.
*
* @tparam T The data type for which to return the positive infinity value.
* @return The positive infinity value of type T.
*/
struct pos_infty {
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
};
/**
* @brief Represents the negative infinity constant operation.
*
* This operation returns the negative infinity value of the specified type.
*
* @tparam T The data type for which to return the negative infinity value.
* @return The negative infinity value of type T.
*/
struct neg_infty {
template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
};
/* ---------- UNARY OPS ---------- */
/**
* @brief Exponential function operation.
*
* This operation calculates the exponential of the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The exponential of the input value.
*/
struct exp {
template<typename T> static __device__ inline T op(const T &x) { return exp(x); }
};
template<> __device__ inline float exp::op<float> (const float &x ) { return __expf(x); }
template<> __device__ inline float2 exp::op<float2>(const float2 &x) { return float2{__expf(x.x), __expf(x.y)}; }
template<> __device__ inline bf16 exp::op<bf16> (const bf16 &x ) { return hexp(x); }
template<> __device__ inline bf16_2 exp::op<bf16_2>(const bf16_2 &x) { return h2exp(x); }
template<> __device__ inline half exp::op<half> (const half &x ) { return hexp(x); }
template<> __device__ inline half_2 exp::op<half_2>(const half_2 &x) { return h2exp(x); }
// /**
// * @brief Exponential function operation, in base 2
// *
// * This operation calculates the exponential of the input value, in base 2.
// *
// * @tparam T The data type of the input and output values.
// * @param x[in] The input value.
// * @return The exponential of the input value.
// */
// struct exp2 {
// template<typename T> static __device__ inline T op(const T &x) { return exp2f(x); }
// };
// template<> __device__ inline float exp2::op<float> (const float &x ) { return exp2f(x); }
// template<> __device__ inline float2 exp2::op<float2>(const float2 &x) { return float2{exp2f(x.x), exp2f(x.y)}; }
// template<> __device__ inline bf16 exp2::op<bf16> (const bf16 &x ) { return hexp2(x); }
// template<> __device__ inline bf16_2 exp2::op<bf16_2>(const bf16_2 &x) { return h2exp2(x); }
// template<> __device__ inline half exp2::op<half> (const half &x ) { return hexp2(x); }
// template<> __device__ inline half_2 exp2::op<half_2>(const half_2 &x) { return h2exp2(x); }
/**
* @brief Base-2 exponential operation using `__builtin_amdgcn_exp2_f32`
*
* Maps directly to `v_exp_f32_e32` on AMD, for highest performance.
* Expects `x` to be in a safe numerical range (e.g., [-64, 88]).
*/
struct exp2 {
template <typename T>
static __device__ inline T op(const T &x) {
return exp2f(x); // fallback
}
};
// Force hardware v_exp_f32 for float
template<>
__device__ inline float exp2::op<float>(const float &x) {
return __builtin_amdgcn_exp2f(x); // Emits v_exp_f32_e32
}
// Force hardware v_exp_f32 for float2
template<>
__device__ inline float2 exp2::op<float2>(const float2 &x) {
return {
__builtin_amdgcn_exp2f(x.x),
__builtin_amdgcn_exp2f(x.y)
};
}
// Delegate to low-precision approximations
template<> __device__ inline half exp2::op<half>(const half &x) { return hexp2(x); }
template<> __device__ inline half_2 exp2::op<half_2>(const half_2 &x) { return h2exp2(x); }
template<> __device__ inline bf16 exp2::op<bf16>(const bf16 &x) { return hexp2(x); }
template<> __device__ inline bf16_2 exp2::op<bf16_2>(const bf16_2 &x) { return h2exp2(x); }
/**
* @brief Natural log function operation.
*
* This operation calculates the natural logarithm of the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The natural logarithm of the input value.
*/
struct log {
template<typename T> static __device__ inline T op(const T &x) { return log(x); }
};
template<> __device__ inline float log::op<float> (const float &x ) { return __logf(x); }
template<> __device__ inline float2 log::op<float2>(const float2 &x) { return float2{__logf(x.x), __logf(x.y)}; }
template<> __device__ inline bf16 log::op<bf16> (const bf16 &x ) { return hlog(x); }
template<> __device__ inline bf16_2 log::op<bf16_2>(const bf16_2 &x) { return h2log(x); }
template<> __device__ inline half log::op<half> (const half &x ) { return hlog(x); }
template<> __device__ inline half_2 log::op<half_2>(const half_2 &x) { return h2log(x); }
/**
* @brief Logarithm base 2 operation.
*
* This operation calculates the logarithm base 2 of the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The logarithm base 2 of the input value.
*/
struct log2 {
template<typename T> static __device__ inline T op(const T &x) { return log2(x); }
};
template<> __device__ inline float log2::op<float> (const float &x ) { return __log2f(x); }
template<> __device__ inline float2 log2::op<float2>(const float2 &x) { return float2{__log2f(x.x), __log2f(x.y)}; }
template<> __device__ inline bf16 log2::op<bf16> (const bf16 &x ) { return hlog2(x); }
template<> __device__ inline bf16_2 log2::op<bf16_2>(const bf16_2 &x) { return h2log2(x); }
template<> __device__ inline half log2::op<half> (const half &x ) { return hlog2(x); }
template<> __device__ inline half_2 log2::op<half_2>(const half_2 &x) { return h2log2(x); }
/**
* @brief Absolute value operation.
*
* This operation calculates the absolute value of the input.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The absolute value of the input.
*/
struct abs {
template<typename T> static __device__ inline T op(const T &x) { return abs(x); }
};
template<> __device__ inline float abs::op<float> (const float &x ) { return fabsf(x); }
template<> __device__ inline float2 abs::op<float2>(const float2 &x) { return float2{fabsf(x.x), fabsf(x.y)}; }
template<> __device__ inline bf16 abs::op<bf16> (const bf16 &x ) { return __habs(x); }
template<> __device__ inline bf16_2 abs::op<bf16_2>(const bf16_2 &x) { return __habs2(x); }
template<> __device__ inline half abs::op<half> (const half &x ) { return __habs(x); }
template<> __device__ inline half_2 abs::op<half_2>(const half_2 &x) { return __habs2(x); }
/**
* @brief Rectified Linear Unit (ReLU) operation.
*
* This operation applies the ReLU function to the input, which is the
* maximum of zero and the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The result of ReLU function applied to the input.
*/
struct relu {
template<typename T> static __device__ inline T op(const T &x) { return max(x, base_types::constants<T>::zero()); }
};
template<> __device__ inline float relu::op<float> (const float &x ) { return max(x, 0.f); }
template<> __device__ inline float2 relu::op<float2>(const float2 &x) { return float2{max(x.x, 0.f), max(x.y, 0.f)}; }
template<> __device__ inline bf16 relu::op<bf16> (const bf16 &x ) { return __hmax(x, base_types::constants<bf16>::zero()); }
template<> __device__ inline bf16_2 relu::op<bf16_2>(const bf16_2 &x) { return __hmax2(x, base_types::constants<bf16_2>::zero()); }
template<> __device__ inline half relu::op<half> (const half &x ) { return __hmax(x, base_types::constants<half>::zero()); }
template<> __device__ inline half_2 relu::op<half_2>(const half_2 &x) { return half_2{__hmax(x.x, base_types::constants<half>::zero()),
__hmax(x.y, base_types::constants<half>::zero())}; }
constexpr float SQRT_2_OVER_PI = 0.7978845608028654f;
constexpr float GELU_COEFF = 0.044715f;
constexpr float GELU_INNER_COEFF = GELU_COEFF * SQRT_2_OVER_PI;
constexpr float DGELU_COEFF = 3.0f * GELU_COEFF * SQRT_2_OVER_PI;
static __device__ inline float fast_tanh(float x) {
x = fmaxf(fminf(x, 20.f), -20.f);
float e2x = __builtin_amdgcn_exp2f(x * 2.8853900817779268f);
return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f);
}
/**
* @brief Gaussian Error Linear Unit (GELU) activation.
*
* Computes the GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The GELU activation applied to the input.
*/
struct gelu {
template<typename T> static __device__ inline T op(const T &x);
};
template<> __device__ inline float gelu::op<float>(const float &x) {
return x * (0.5f + 0.5f * fast_tanh(x * (SQRT_2_OVER_PI + GELU_INNER_COEFF * x * x)));
}
template<> __device__ inline float2 gelu::op<float2>(const float2 &x) {
return float2{gelu::op<float>(x.x), gelu::op<float>(x.y)};
}
/**
* @brief Derivative of the GELU activation.
*
* Computes the derivative of the GELU approximation with respect to the input.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The derivative of GELU evaluated at the input.
*/
struct dgelu {
template<typename T> static __device__ inline T op(const T &x);
};
template<> __device__ inline float dgelu::op<float>(const float &x) {
float tanh_out = fast_tanh(SQRT_2_OVER_PI * x * (1.f + GELU_COEFF * x * x));
return 0.5f * x * ((1.f - tanh_out * tanh_out) * (SQRT_2_OVER_PI + DGELU_COEFF * x * x)) +
0.5f * (1.f + tanh_out);
}
template<> __device__ inline float2 dgelu::op<float2>(const float2 &x) {
return float2{dgelu::op<float>(x.x), dgelu::op<float>(x.y)};
}
/**
* @brief Copy operation.
*
* This operation returns the input value unchanged.
*
* @tparam T The data type of the input and output values.
* @param a[in] The input value.
* @return The same value as the input.
*/
struct copy { // for non-compile-time setters.
template<typename T> static __device__ inline T op(const T &a) { return a; }
};
/* ---------- BINARY OPS ---------- */
/**
* @brief Copy2 operation.
*
* This operation returns the second input value unchanged.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value (ignored).
* @param b[in] The second input value.
* @return The same value as the second input.
*/
struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/
template<typename T> static __device__ inline T op(const T &a, const T &b) { return b; }
};
/**
* @brief Sum operation.
*
* This operation calculates the sum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The sum of the input values.
*/
struct sum {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a+b; }
};
template<> __device__ inline float2 sum::op<float2>(const float2 &a, const float2 &b) { return float2{a.x+b.x, a.y+b.y}; }
template<> __device__ inline bf16 sum::op<bf16> (const bf16 &a, const bf16 &b) { return __hadd(a, b); }
template<> __device__ inline bf16_2 sum::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hadd2(a, b); }
template<> __device__ inline half sum::op<half> (const half &a, const half &b) { return __hadd(a, b); }
template<> __device__ inline half_2 sum::op<half_2>(const half_2 &a, const half_2 &b) { return __hadd2(a, b); }
/**
* @brief Subtraction operation.
*
* This operation calculates the difference between two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The difference between the input values.
*/
struct sub {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a-b; }
};
template<> __device__ inline float2 sub::op<float2>(const float2 &a, const float2 &b) { return float2{a.x-b.x, a.y-b.y}; }
template<> __device__ inline bf16 sub::op<bf16> (const bf16 &a, const bf16 &b) { return __hsub(a, b); }
template<> __device__ inline bf16_2 sub::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hsub2(a, b); }
template<> __device__ inline half sub::op<half> (const half &a, const half &b) { return __hsub(a, b); }
template<> __device__ inline half_2 sub::op<half_2>(const half_2 &a, const half_2 &b) { return __hsub2(a, b); }
/**
* @brief Multiplication operation.
*
* This operation calculates the product of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The product of the input values.
*/
struct mul {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a*b; }
};
template<> __device__ inline float2 mul::op<float2>(const float2 &a, const float2 &b) { return float2{a.x*b.x, a.y*b.y}; }
template<> __device__ inline bf16 mul::op<bf16> (const bf16 &a, const bf16 &b) { return __hmul(a, b); }
template<> __device__ inline bf16_2 mul::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmul2(a, b); }
template<> __device__ inline half mul::op<half> (const half &a, const half &b) { return __hmul(a, b); }
template<> __device__ inline half_2 mul::op<half_2>(const half_2 &a, const half_2 &b) { return __hmul2(a, b); }
/**
* @brief Division operation.
*
* This operation calculates the quotient of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The quotient of the input values.
*/
struct div {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return a/b; }
};
template<> __device__ inline float2 div::op<float2>(const float2 &a, const float2 &b) { return float2{a.x/b.x, a.y/b.y}; }
template<> __device__ inline bf16 div::op<bf16> (const bf16 &a, const bf16 &b) { return __hdiv(a, b); }
template<> __device__ inline bf16_2 div::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __h2div(a, b); } // this op is a special snowflake
template<> __device__ inline half div::op<half> (const half &a, const half &b) { return __hdiv(a, b); }
template<> __device__ inline half_2 div::op<half_2>(const half_2 &a, const half_2 &b) { return __h2div(a, b); }
/**
* @brief Maximum operation.
*
* This operation calculates the maximum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The maximum of the input values.
*/
struct max {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::max(a, b); }
};
template<> __device__ inline float2 max::op<float2>(const float2 &a, const float2 &b) { return float2{::max(a.x, b.x), ::max(a.y, b.y)}; }
template<> __device__ inline bf16 max::op<bf16> (const bf16 &a, const bf16 &b) { return __hmax(a, b); }
template<> __device__ inline bf16_2 max::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmax2(a, b); }
template<> __device__ inline half max::op<half> (const half &a, const half &b) { return __hmax(a, b); }
template<> __device__ inline half_2 max::op<half_2>(const half_2 &a, const half_2 &b) { return half_2{__hmax(a.x, b.x), __hmax(a.y, b.y)}; }
/**
* @brief Minimum operation.
*
* This operation calculates the minimum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The minimum of the input values.
*/
struct min {
template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::min(a, b); }
};
template<> __device__ inline float2 min::op<float2>(const float2 &a, const float2 &b) { return float2{::min(a.x, b.x), ::min(a.y, b.y)}; }
template<> __device__ inline bf16 min::op<bf16> (const bf16 &a, const bf16 &b) { return __hmin(a, b); }
template<> __device__ inline bf16_2 min::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmin2(a, b); }
template<> __device__ inline half min::op<half> (const half &a, const half &b) { return __hmin(a, b); }
template<> __device__ inline half_2 min::op<half_2>(const half_2 &a, const half_2 &b) { return half_2{__hmin(a.x, b.x), __hmin(a.y, b.y)}; }
/* ---------- TERNARY OPS ---------- */
/**
* @brief Fused multiply-add operation A * B + C.
*
* This operation performs a fused multiply-add, computing (A * B) + C with only one rounding.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @param c[in] The third input value to be added.
* @return The result of the fused multiply-add operation.
*/
struct fma_AxBtC {
template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
return sum::op<T>(mul::op<T>(a, b), c);
}
};
/**
* @brief Fused multiply-add operation A * C + B.
*
* This operation performs a fused multiply-add, computing (A * C) + B with only one rounding.
* This is particularly useful for attention mechanisms in neural networks.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The third input value to be added.
* @param c[in] The second input value.
* @return The result of the fused multiply-add operation.
*/
struct fma_AxCtB { // this is the one needed for attention
template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
return sum::op<T>(mul::op<T>(a, c), b);
}
};
} // namespace base_ops
} // namespace kittens

View file

@ -0,0 +1,454 @@
/**
* @file
* @brief Declarations, manipulations, and wrappers for basic types.
*
* This file is a bunch of utilities for going back and forth between different types.
*
* Many of them are for the compiler, so as to clean up the code. It unfortunately
* seems necessary when we have types we really care about that are less than word width.
*/
#pragma once
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_fp4.h>
#include <hip/amd_detail/amd_hip_ocp_types.h>
#include <hip/hip_runtime.h>
#include <string>
#include <bit>
typedef uint32_t __amd_fp8x4_storage_t;
namespace kittens {
// /**
// * @brief Bfloat16 floating-point type.
// */
using bf16 = __hip_bfloat16;
/**
* @brief Half-precision floating-point type.
*/
using half = __half;
// /**
// * @brief Packed word of two bfloat16 floating-point values.
// */
using bf16_2 = __hip_bfloat162;
/**
* @brief Packed word of two half-precision floating-point values.
*/
using half_2 = __half2;
/**
* @brief float8 floating-point type.
*/
using fp8e4m3 = __hip_fp8_e4m3;
/**
* @brief Packed word of two float8 floating-point values.
*/
using fp8e4m3_2 = __hip_fp8x2_e4m3;
/**
* @brief Packed word of four float8 floating-point values.
*/
using fp8e4m3_4 = __hip_fp8x4_e4m3;
/**
* @brief 8-bit exponent-only block-scaling scale type.
*/
using fp8e8m0 = __amd_scale_t;
/**
* @brief Packed word of two 8-bit exponent-only block-scaling scale values.
*/
using fp8e8m0_2 = __amd_fp8x2_storage_t;
/**
* @brief Packed word of four 8-bit exponent-only block-scaling scale values.
*/
using fp8e8m0_4 = __amd_fp8x4_storage_t;
/**
* @brief FP4 E2M1 floating-point type.
*/
using fp4e2m1 = __hip_fp4_e2m1;
/**
* @brief Packed word of two FP4 E2M1 floating-point values.
*/
using fp4e2m1_2 = __hip_fp4x2_e2m1;
/**
* @brief Packed word of four FP4 E2M1 floating-point values.
*/
using fp4e2m1_4 = __hip_fp4x4_e2m1;
namespace ducks {
/**
* @namespace base_types
*
* @brief A namespace for concepts for basic data types.
*/
namespace base_types {
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>
|| std::is_same_v<T, fp4e2m1_4>;
template<typename T>
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>
|| std::is_same_v<T, fp4e2m1>;
} // namespace base_types
} // namespace ducks
/**
* @namespace base_types
*
* @brief A namespace for ThunderKittens basic data types.
*/
namespace base_types {
/**
* @brief Provides compile-time constants for different types.
*
* @tparam T The type for which to provide constants.
*/
template<typename T> struct constants {
/**
* @brief Zero
* @return Constexpr zero with type T
*/
static __device__ inline constexpr T zero() { return T{0}; }
/**
* @brief Ones
* @return Constexpr ones with type T
*/
static __device__ inline constexpr T ones() { return T{1}; }
/**
* @brief Positive infinity. Particularly useful for initializing before a min op.
* @return Constexpr positive infinity with type T
*/
static __device__ inline constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work.
/**
* @brief Negative infinity. Particularly useful for initializing before a max op.
* @return Constexpr negative infinity with type T
*/
static __device__ inline constexpr T neg_infty() { return T{-INFINITY}; }
};
template<> struct constants<float2> {
static __device__ inline constexpr float2 zero() { return float2{0.f, 0.f}; }
static __device__ inline constexpr float2 ones() { return float2{1.f, 1.f}; }
static __device__ inline constexpr float2 pos_infty() { return float2{constants<float>::pos_infty(), constants<float>::pos_infty()}; }
static __device__ inline constexpr float2 neg_infty() { return float2{constants<float>::neg_infty(), constants<float>::neg_infty()}; }
};
template<> struct constants<bf16> {
static __device__ inline constexpr bf16 zero() { return std::bit_cast<bf16>(uint16_t(0x0000)); } // unfortunately __float2bf16_rn is not constexpr
static __device__ inline constexpr bf16 ones() { return std::bit_cast<bf16>(uint16_t(0x3F80)); }
static __device__ inline constexpr bf16 pos_infty() { return std::bit_cast<bf16>(uint16_t(0x7F80)); }
static __device__ inline constexpr bf16 neg_infty() { return std::bit_cast<bf16>(uint16_t(0xFF80)); }
};
template<> struct constants<bf16_2> {
static __device__ inline bf16_2 zero() { return bf16_2{constants<bf16>::zero(), constants<bf16>::zero()}; }
static __device__ inline bf16_2 ones() { return bf16_2{constants<bf16>::ones(), constants<bf16>::ones()}; }
static __device__ inline bf16_2 pos_infty() { return bf16_2{constants<bf16>::pos_infty(), constants<bf16>::pos_infty()}; }
static __device__ inline bf16_2 neg_infty() { return bf16_2{constants<bf16>::neg_infty(), constants<bf16>::neg_infty()}; }
};
template<> struct constants<half> {
static __device__ inline constexpr half zero() { return std::bit_cast<half>(uint16_t(0x0000)); }
static __device__ inline constexpr half ones() { return std::bit_cast<half>(uint16_t(0x3C00)); }
static __device__ inline constexpr half pos_infty() { return std::bit_cast<half>(uint16_t(0x7C00)); }
static __device__ inline constexpr half neg_infty() { return std::bit_cast<half>(uint16_t(0xFC00)); }
};
template<> struct constants<half_2> {
static __device__ inline constexpr half_2 zero() { return std::bit_cast<half_2>(uint32_t(0x00000000)); }
static __device__ inline constexpr half_2 ones() { return std::bit_cast<half_2>(uint32_t(0x3C003C00)); }
static __device__ inline constexpr half_2 pos_infty() { return std::bit_cast<half_2>(uint32_t(0x7C007C00)); }
static __device__ inline constexpr half_2 neg_infty() { return std::bit_cast<half_2>(uint32_t(0xFC00FC00)); }
};
template<> struct constants<fp8e4m3> {
static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<fp8e4m3>(uint8_t(0x00)); }
static __device__ inline constexpr fp8e4m3 one() { return std::bit_cast<fp8e4m3>(uint8_t(0x38)); }
};
template<> struct constants<fp8e4m3_2> {
static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x0000)); }
static __device__ inline constexpr fp8e4m3_2 one() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x3838)); }
};
template<> struct constants<fp8e4m3_4> {
static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x38383838)); }
};
template<> struct constants<fp8e8m0> {
static __device__ inline constexpr fp8e8m0 zero() { return std::bit_cast<fp8e8m0>(uint8_t(0x00)); } // not actually 0
static __device__ inline constexpr fp8e8m0 one() { return std::bit_cast<fp8e8m0>(uint8_t(0x7F)); }
};
template<> struct constants<fp8e8m0_2> {
static __device__ inline constexpr fp8e8m0_2 zero() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x0000)); } // not actually 0
static __device__ inline constexpr fp8e8m0_2 one() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x7F7F)); }
};
template<> struct constants<fp8e8m0_4> {
static __device__ inline constexpr fp8e8m0_4 zero() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x00000000)); } // not actually 0
static __device__ inline constexpr fp8e8m0_4 one() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x7F7F7F7F)); }
};
template<> struct constants<fp4e2m1> {
static __device__ inline constexpr fp4e2m1 zero() { return std::bit_cast<fp4e2m1>(uint8_t(0x00)); }
static __device__ inline constexpr fp4e2m1 one() { return std::bit_cast<fp4e2m1>(uint8_t(0x02)); }
};
template<> struct constants<fp4e2m1_4> {
static __device__ inline constexpr fp4e2m1_4 zero() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x0000)); }
static __device__ inline constexpr fp4e2m1_4 one() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x2222)); }
};
template<> struct constants<int> {
static __device__ inline constexpr int zero() { return 0; }
static __device__ inline constexpr int ones() { return 1; }
};
template<> struct constants<int2> {
static __device__ inline constexpr int2 zero() { return int2{0, 0}; }
static __device__ inline constexpr int2 ones() { return int2{1, 1}; }
};
/**
* @brief Provides information about packing of elements for a given type.
*
* @tparam T The type for which to provide packing information.
*/
template<typename T> struct packing {
/**
* @brief The number of elements packed together.
*
* @return constexpr int representing number of elements within the type.
*/
static __host__ __device__ inline constexpr int num() { return 1; }
/**
* @brief Packs a single T element twice (replicated) into its packed type.
*
* @param i[in] The element to pack.
* @return The packed type.
*/
static __device__ inline constexpr T pack(const auto &i);
};
template<> struct packing<bf16> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = bf16;
using packed_type = bf16_2;
static __device__ inline bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; }
};
template<> struct packing<bf16_2> {
static __host__ __device__ inline constexpr int num() { return 2; }
using unpacked_type = bf16;
using packed_type = bf16_2;
static __device__ inline bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<half> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = half;
using packed_type = half_2;
static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; }
};
template<> struct packing<half_2> {
static __host__ __device__ inline constexpr int num() { return 2; }
using unpacked_type = half;
using packed_type = half_2;
static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<float> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = float;
using packed_type = float2;
static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; }
};
template<> struct packing<float2> {
static __host__ __device__ inline constexpr int num() { return 2; }
using unpacked_type = float;
using packed_type = float2;
static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<int> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = int;
using packed_type = int2;
static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<int2> {
static __host__ __device__ inline constexpr int num() { return 2; }
using unpacked_type = int;
using packed_type = int2;
static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<float4> {
static __host__ __device__ inline constexpr int num() { return 4; }
};
template<> struct packing<int4> {
static __host__ __device__ inline constexpr int num() { return 4; }
};
template<> struct packing<fp8e4m3> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = fp8e4m3;
using packed_type = fp8e4m3_4;
};
template<> struct packing<fp8e4m3_4> {
static __host__ __device__ inline constexpr int num() { return 4; }
using unpacked_type = fp8e4m3;
using packed_type = fp8e4m3_4;
};
template<> struct packing<fp8e8m0> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = fp8e8m0;
using packed_type = fp8e8m0_4;
};
template<> struct packing<fp8e8m0_4> {
static __host__ __device__ inline constexpr int num() { return 4; }
using unpacked_type = fp8e8m0;
using packed_type = fp8e8m0_4;
};
template<> struct packing<fp4e2m1> {
static __host__ __device__ inline constexpr int num() { return 1; }
using unpacked_type = fp4e2m1;
using packed_type = fp4e2m1_4;
};
template<> struct packing<fp4e2m1_4> {
static __host__ __device__ inline constexpr int num() { return 4; }
using unpacked_type = fp4e2m1;
using packed_type = fp4e2m1_4;
};
/**
* @brief Provides templated functionality to convert between different types.
*
* @tparam T The target type for conversion.
* @tparam U The source type for conversion.
*/
template<typename T, typename U> struct convertor {
/**
* @brief Converts a value of type U to type T.
*
* @param u[in] The value of type U to convert.
* @return T The converted value of type T.
*/
static __host__ __device__ inline T convert(const U & u) {
return (T)u;
}
};
template<> struct convertor<float, bf16> {
static __host__ __device__ inline float convert(const bf16 & u) {
return __bfloat162float(u);
}
};
// template<> struct convertor<bf16, float> {
// static __host__ __device__ inline bf16 convert(const float & u) {
// return __float2bfloat16(u);
// }
// };
template<> struct convertor<bf16, float> {
static __host__ __device__ inline bf16 convert(const float &u) {
// Fast unsafe conversion (truncation only)
return std::bit_cast<bf16>(
static_cast<uint16_t>(
std::bit_cast<uint32_t>(u) >> 16
)
);
}
};
template<> struct convertor<float2, bf16_2> {
static __host__ __device__ inline float2 convert(const bf16_2 & u) {
return __bfloat1622float2(u);
}
};
template<> struct convertor<bf16_2, float2> {
static __host__ __device__ inline bf16_2 convert(const float2 &u) {
uint32_t result;
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
: "=v"(result)
: "v"(u.x), "v"(u.y));
return *reinterpret_cast<bf16_2*>(&result);
}
};
template<> struct convertor<float, half> {
static __host__ __device__ inline float convert(const half & u) {
return __half2float(u);
}
};
template<> struct convertor<half, float> {
static __host__ __device__ inline half convert(const float & u) {
return __float2half(u);
}
};
template<> struct convertor<float2, half_2> {
static __host__ __device__ inline float2 convert(const half_2 & u) {
return __half22float2(u);
}
};
template<> struct convertor<half_2, float2> {
static __host__ __device__ inline half_2 convert(const float2 & u) {
return __float22half2_rn(u);
}
};
template<> struct convertor<bf16, half> {
static __host__ __device__ inline bf16 convert(const half & u) {
return __float2bfloat16(__half2float(u));
}
};
template<> struct convertor<half, bf16> {
static __host__ __device__ inline half convert(const bf16 & u) {
return __float2half(__bfloat162float(u));
}
};
template<> struct convertor<bf16_2, half_2> {
static __host__ __device__ inline bf16_2 convert(const half_2 & u) {
return __float22bfloat162_rn(__half22float2(u));
}
};
template<> struct convertor<half_2, bf16_2> {
static __host__ __device__ inline half_2 convert(const bf16_2 & u) {
return __float22half2_rn(__bfloat1622float2(u));
}
};
template<> struct convertor<fp8e4m3_4, float4> {
static __host__ __device__ inline fp8e4m3_4 convert(const float4& u) {
return fp8e4m3_4(u);
}
};
template<> struct convertor<float4, fp8e4m3_4> {
static __host__ __device__ inline float4 convert(const fp8e4m3_4& u) {
fp8e4m3 *vals = reinterpret_cast<fp8e4m3*>(const_cast<fp8e4m3_4*>(&u));
return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3]));
}
};
template<> struct convertor<fp8e4m3_2, float2> {
static __host__ __device__ inline fp8e4m3_2 convert(const float2& u) {
return fp8e4m3_2(u);
}
};
template<> struct convertor<float2, fp8e4m3_2> {
static __host__ __device__ inline float2 convert(const fp8e4m3_2& u) {
fp8e4m3 *vals = reinterpret_cast<fp8e4m3*>(const_cast<fp8e4m3_2*>(&u));
return make_float2(float(vals[0]), float(vals[1]));
}
};
template<> struct convertor<fp8e4m3, float> {
static __host__ __device__ inline fp8e4m3 convert(const float & u) {
return fp8e4m3(u);
}
};
template<> struct convertor<float, fp8e4m3> {
static __host__ __device__ inline float convert(const fp8e4m3 & u) {
return float(u);
}
};
template<> struct convertor<fp4e2m1, float> {
static __host__ __device__ inline fp4e2m1 convert(const float & u) {
return fp4e2m1(u);
}
};
template<> struct convertor<float, fp4e2m1> {
static __host__ __device__ inline float convert(const fp4e2m1 & u) {
return float(u);
}
};
template<> struct convertor<fp4e2m1_4, float4> {
static __host__ __device__ inline fp4e2m1_4 convert(const float4& u) {
return fp4e2m1_4(u);
}
};
template<> struct convertor<float4, fp4e2m1_4> {
static __host__ __device__ inline float4 convert(const fp4e2m1_4& u) {
return float4(u);
}
};
}
}

View file

@ -0,0 +1,13 @@
/**
* @file
* @brief A collection of common resources on which HipKittens depends.
*/
#pragma once
#include "util.cuh"
#include "base_types.cuh"
#include "base_ops.cuh"
#include "macros.cuh"

View file

@ -0,0 +1,56 @@
#pragma once
// Reset
#define TK_RESET "\033[0m"
// Foreground colors
#define TK_FG_BLACK "\033[30m"
#define TK_FG_RED "\033[31m"
#define TK_FG_GREEN "\033[32m"
#define TK_FG_YELLOW "\033[33m"
#define TK_FG_BLUE "\033[34m"
#define TK_FG_MAGENTA "\033[35m"
#define TK_FG_CYAN "\033[36m"
#define TK_FG_WHITE "\033[37m"
// Background colors
#define TK_BG_BLACK "\033[40m"
#define TK_BG_RED "\033[41m"
#define TK_BG_GREEN "\033[42m"
#define TK_BG_YELLOW "\033[43m"
#define TK_BG_BLUE "\033[44m"
#define TK_BG_MAGENTA "\033[45m"
#define TK_BG_CYAN "\033[46m"
#define TK_BG_WHITE "\033[47m"
// Bright foreground colors
#define TK_FG_BRIGHT_BLACK "\033[90m"
#define TK_FG_BRIGHT_RED "\033[91m"
#define TK_FG_BRIGHT_GREEN "\033[92m"
#define TK_FG_BRIGHT_YELLOW "\033[93m"
#define TK_FG_BRIGHT_BLUE "\033[94m"
#define TK_FG_BRIGHT_MAGENTA "\033[95m"
#define TK_FG_BRIGHT_CYAN "\033[96m"
#define TK_FG_BRIGHT_WHITE "\033[97m"
// Bright background colors
#define TK_BG_BRIGHT_BLACK "\033[100m"
#define TK_BG_BRIGHT_RED "\033[101m"
#define TK_BG_BRIGHT_GREEN "\033[102m"
#define TK_BG_BRIGHT_YELLOW "\033[103m"
#define TK_BG_BRIGHT_BLUE "\033[104m"
#define TK_BG_BRIGHT_MAGENTA "\033[105m"
#define TK_BG_BRIGHT_CYAN "\033[106m"
#define TK_BG_BRIGHT_WHITE "\033[107m"
// Text styles
#define TK_BOLD "\033[1m"
#define TK_DIM "\033[2m"
#define TK_ITALIC "\033[3m"
#define TK_UNDERLINE "\033[4m"
#define TK_BLINK "\033[5m"
#define TK_REVERSE "\033[7m"
#define TK_HIDDEN "\033[8m"
// Macro to combine styles
#define TK_STYLE(...) "\033[" #__VA_ARGS__ "m"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,437 @@
/**
* @file
* @brief General utilities for ThunderKittens.
*/
#pragma once
#include <stdint.h>
#include <type_traits>
#include <concepts>
#include <memory>
#include <hip/hip_runtime.h>
#include "base_types.cuh"
#ifndef __forceinline__
#define __forceinline__ __attribute__((always_inline))
#endif
/**
* @namespace kittens
*
* @brief The main namespace of ThunderKittens.
*/
namespace kittens {
/* ---------- GENERAL CONSTANTS FOR KITTENS ---------- */
/**
* @brief Constant representing number of threads in a warp.
*
* gfx1250 (UDNA1) is wave-32.
*/
constexpr int WARP_THREADS{32};
/**
* @brief Get the warp ID of the current thread.
* @return The warp ID.
*/
__device__ __forceinline__ int warpid() { return threadIdx.x >> 5; }
/**
* @brief Get the number of warps in the threadblock.
* @return The number of warps in the threadblock.
*/
__device__ __forceinline__ int num_warps() { return blockDim.x / WARP_THREADS; }
/**
* @brief Get the lane ID of the current thread within its warp.
* @return The lane ID.
*/
__device__ __forceinline__ int laneid() { return threadIdx.x & 0x1f; }
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
struct buffer_resource {
uint64_t ptr;
uint32_t range;
uint32_t config;
};
/**
* @brief Compute the ceiling division of two integers.
* @param a The dividend.
* @param b The divisor.
* @return The ceiling division result.
*/
__host__ __device__ inline int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
/**
* @brief Transform a workgroup ID to a new workgroup ID based on the chunk size and number of XCDs.
* @param workgroup_id The original workgroup ID.
* @param num_workgroups The total number of workgroups.
* @param num_xcds The number of XCDs.
* @param chunk_size The chunk size.
* @return The new workgroup ID.
*/
__host__ __device__ inline int chiplet_transform_chunked(
int workgroup_id,
int num_workgroups,
int num_xcds,
int chunk_size
) {
// Current XCD
int xcd = workgroup_id % num_xcds;
// Largest full (NUM_XCDS*CHUNK_SIZE)-aligned block
int block = num_xcds * chunk_size;
int limit = (num_workgroups / block) * block;
// If pid beyond the last full block, leave unchanged
if (workgroup_id > limit) return workgroup_id;
// Local PID (within round-robin assignment)
int local_pid = workgroup_id / num_xcds;
int chunk_idx = local_pid / chunk_size;
int pos_in_chunk = local_pid % chunk_size;
// New PID
return chunk_idx * block + xcd * chunk_size + pos_in_chunk;
}
/**
* @brief gfx1250 LDS capacity constants.
*
* On gfx1250, the **LDS scratchpad and the L1 data cache are one 384 KB SRAM
* pool per Compute Unit (CU)**, partitioned into six 64 KB segments.
* At least one segment must remain L1, leaving up to five segments
* (320 KB) addressable as LDS.
*
* `MAX_SHARED_MEMORY_PER_SEGMENT` is one 64 KB segment; `MAX_SHARED_MEMORY` is
* the full addressable LDS across all five segments. A kernel that fits in one
* segment requests `MAX_SHARED_MEMORY_PER_SEGMENT`; one that needs more requests
* a larger dynamic shared-memory size at launch via `hipFuncSetAttribute`.
*/
constexpr int MAX_SHARED_MEMORY_PER_SEGMENT = 65536;
constexpr int SHARED_MEMORY_NUM_SEGMENTS = 5;
constexpr int MAX_SHARED_MEMORY = MAX_SHARED_MEMORY_PER_SEGMENT * SHARED_MEMORY_NUM_SEGMENTS;
constexpr int NUM_XCDS = 1;
constexpr int CUS_PER_XCD = 64;
constexpr int NUM_CUS = CUS_PER_XCD * NUM_XCDS;
/* ---------- CUSTOM TYPES ---------- */
typedef uint32_t uint2_t __attribute__((ext_vector_type(2)));
/* ---------- TYPE HELPERS ---------- */
/**
* @namespace ducks
*
* @brief ThunderKittens' namespace for template metaprogramming..
*
* This includes primarily dummy types and concept wrappers, along
* with a few additional utilities.
*/
namespace ducks {
/**
* @brief A type representing an empty default for a template.
*/
struct default_type {};
// This macro can't be done as a template, so it doesn't really have a location in kittens.
#define typeof(A) typename std::remove_const<typename std::remove_reference<decltype(A)>::type>::type
}
/* ---------- SHUFFLE UTILS ---------- */
/**
* @brief Mask constant for all active threads in a warp.
*/
static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF;
/**
* @brief Perform a shuffle down operation on a packed type synchronously across a warp.
* @tparam T The type of the value to be shuffled.
* @param mask[in] The mask of active threads.
* @param f[in] The value to be shuffled.
* @param delta[in] The number of positions to shuffle down.
* @return The result of the shuffle operation.
*/
template<typename T>
__device__ static inline T packed_shfl_down(uint64_t mask, const T &f, int delta) {
if constexpr (std::is_same_v<T, bf16_2> || std::is_same_v<T, bf16>) {
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
union {
__hip_bfloat162 bf162;
unsigned int ui;
} u;
if constexpr (std::is_same_v<T, bf16_2>) {
u.bf162 = *reinterpret_cast<const __hip_bfloat162*>(&f);
} else {
u.bf162 = __hip_bfloat162{*reinterpret_cast<const __hip_bfloat16*>(&f),
*reinterpret_cast<const __hip_bfloat16*>(&f)};
}
u.ui = __shfl_down_sync<unsigned long long, unsigned int>(mask, u.ui, delta, 64);
if constexpr (std::is_same_v<T, bf16>) {
return *reinterpret_cast<const T*>(&u.bf162.x); // Extract single bf16 from the .x component
} else {
return u.bf162; // Return full bf162 for bf16_2 case
}
} else {
return __shfl_down(f, delta);
}
}
template<>
__device__ inline float2 packed_shfl_down<float2>(uint64_t mask, const float2 &f, int delta) {
float2 r;
r.x = __shfl_down(f.x, delta);
r.y = __shfl_down(f.y, delta);
return r;
}
/**
* @brief Perform a packed shuffle operation synchronously across a warp.
* @tparam T The type of the value to be shuffled.
* @param mask[in] The mask of active threads.
* @param f[in] The value to be shuffled.
* @param src[in] The source lane from which to shuffle.
* @return The result of the shuffle operation.
*/
template<typename T>
__device__ static inline T packed_shfl(uint64_t mask, const T &f, int src) {
return __shfl(f, src);
}
template<>
__device__ inline bf16 packed_shfl(uint64_t mask, const bf16 &f, int src) {
float r = __shfl(base_types::convertor<float, bf16>::convert(f), src);
return base_types::convertor<bf16, float>::convert(r);
}
template<>
__device__ inline bf16_2 packed_shfl(uint64_t mask, const bf16_2 &f, int src) {
float2 r;
r.x = __shfl(base_types::convertor<float, bf16>::convert(f.x), src);
r.y = __shfl(base_types::convertor<float, bf16>::convert(f.y), src);
return base_types::convertor<bf16_2, float2>::convert(r);
}
template<>
__device__ inline half packed_shfl(uint64_t mask, const half &f, int src) {
float r = __shfl(base_types::convertor<float, half>::convert(f), src);
return base_types::convertor<half, float>::convert(r);
}
template<>
__device__ inline half_2 packed_shfl(uint64_t mask, const half_2 &f, int src) {
float2 r;
r.x = __shfl(base_types::convertor<float, half>::convert(f.x), src);
r.y = __shfl(base_types::convertor<float, half>::convert(f.y), src);
return base_types::convertor<half_2, float2>::convert(r);
}
template<>
__device__ inline float2 packed_shfl<float2>(uint64_t mask, const float2 &f, int src) {
float2 r;
r.x = __shfl(f.x, src);
r.y = __shfl(f.y, src);
return r;
}
using bytes_4 = HIP_vector_type<float, 1>;
using bytes_8 = HIP_vector_type<float, 2>;
using bytes_16 = HIP_vector_type<float, 4>;
/* ---------- SHARED MEMORY UTILS ---------- */
// namespace ducks {
// namespace sb {
// struct identifier {};
// }
// }
// template<typename Args...>
// struct sb {
// using identifier = ducks::sb::identifier;
// Args... args;
// };
// namespace ducks {
// namespace sb {
// template<typename T> concept all = requires {
// typename T::identifier;
// } && std::is_same_v<T::identifier, identifier>;
// }
// }
#define KITTENS_ALIGN_AS(n) alignas(n)
#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16)
/**
* @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls.
*/
struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; };
namespace detail {
/// @brief 16B (`int4`) vector types tagged with the address spaces the gfx1250
/// `*_load_async_to_lds_b128` builtins require (AS1 = global, AS3 = LDS).
using i32x4_vec = int __attribute__((__vector_size__(16)));
using i32x4_gvec = int __attribute__((__vector_size__(16))) __attribute__((address_space(1)));
using i32x4_lvec = int __attribute__((__vector_size__(16))) __attribute__((address_space(3)));
} // namespace detail
/**
* @brief Compile-time tag selecting an LDS segment for tile placement on gfx1250.
*
* Background. LDS and L1 share one 384 KB SRAM pool per Compute Unit (CU),
* partitioned at dispatch into six 64 KB segments (see `MAX_SHARED_MEMORY`
* above). Up to five segments (indices 0..4, total 320 KB) are addressable as
* LDS scratchpad; at least one segment must remain L1. By convention we leave
* segment 5 as L1, so LDS-tile placement uses indices 0..4.
*
* Why segments matter. The LDS half of the pool is fronted by two read ports
* delivering 256 B/cycle each. The two ports can issue in the same cycle only
* when they target **different** segments, so placing operand `A` in
* `segment<0>` and operand `B` in `segment<1>` lets the hardware satisfy both
* reads in parallel and reach the full 512 B/cycle peak. Co-locating `A` and
* `B` in the same segment serialises them at 256 B/cycle.
*
* @tparam IDX 0..4 -- segment index. The allocator aligns the allocation start
* to `IDX * 64 KB` so multiple tiles can share a single segment.
*/
template<int IDX>
struct segment {
static_assert(IDX >= 0 && IDX < SHARED_MEMORY_NUM_SEGMENTS,
"segment index must be in [0, 5)");
static constexpr int index = IDX;
static constexpr int byte_offset = IDX * MAX_SHARED_MEMORY_PER_SEGMENT;
};
namespace ducks {
namespace segment_tag {
template<typename T> struct is_segment : std::false_type {};
template<int I> struct is_segment<::kittens::segment<I>> : std::true_type {};
template<typename T> concept all = is_segment<T>::value;
} // namespace segment_tag
} // namespace ducks
/**
* @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments.
*
* Maintains a bump cursor `ptr` that advances on every `allocate*()` call. On
* gfx1250 the allocator also remembers `base` -- the unmoved origin of the
* shared-memory region captured at construction -- so segment-aware
* allocations (`allocate_in<segment<IDX>>`) can jump to `base + IDX * 64 KB`
* regardless of how far the bump cursor has already advanced.
*
* @tparam default_alignment The default alignment this allocator will enforce. If <=0 (default -1) it will not align.
*/
template<int default_alignment=16>
struct shared_allocator {
int *ptr; ///< Bump cursor; advances on every allocate*() call.
int *base; ///< Frozen origin captured at construction; never moves.
///< Reference point for `allocate_in<segment<IDX>>` segment starts.
private:
// Recursive template to generate N-dimensional array type
template<typename A, size_t... dims>
struct variadic_array;
template<typename A, size_t first_dim, size_t... rest_dims>
struct variadic_array<A, first_dim, rest_dims...> {
using type = typename variadic_array<A, rest_dims...>::type[first_dim];
};
template<typename A>
struct variadic_array<A> {
using type = A;
};
template<typename A, size_t... dims>
using variadic_array_t = typename variadic_array<A, dims...>::type;
template<int alignment>
__device__ inline void align_ptr() {
if constexpr (alignment > 0) {
uint64_t p = reinterpret_cast<uint64_t>(ptr);
if(p % alignment != 0) {
ptr = (int*)(p + (alignment-(p%alignment)));
}
}
}
public:
/**
* @brief Construct a new shared allocator using a pointer to extern shared memory.
*
* `_ptr` is captured into the bump cursor `ptr`; on gfx1250 it is also
* stashed into `base` so segment-aware allocations can recover the
* original origin regardless of how far the cursor has advanced.
*
* @param[in] _ptr Pointer to the start of the extern shared memory.
*/
__device__ shared_allocator(int *_ptr): ptr(_ptr), base(_ptr) {}
/**
* @brief Allocate shared memory for a single instance or N-dimensional array of type A.
* @tparam A The type of the object to allocate.
* @tparam dims... A list of dimensions for the N-dimensional array.
* @return Reference to the allocated object.
*/
template<typename A, size_t... dims>
__device__ inline variadic_array_t<A, dims...>& allocate() {
// static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation");
align_ptr<default_alignment>();
using at = variadic_array_t<A, dims...>;
at*p = reinterpret_cast<at*>(ptr);
ptr += sizeof(at)/sizeof(int);
return *p;
}
/**
* @brief Allocate shared memory for a single instance or N-dimensional array of type A.
* @tparam alignment An alignment to enforce for this particular object.
* @tparam A The type of the object to allocate.
* @tparam dims... A list of dimensions for the N-dimensional array.
* @return Reference to the allocated object.
*/
template<int alignment, typename A, size_t... dims>
__device__ inline variadic_array_t<A, dims...>& allocate() {
// static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation");
align_ptr<alignment>();
using at = variadic_array_t<A, dims...>;
at*p = reinterpret_cast<at*>(ptr);
ptr += sizeof(at)/sizeof(int);
return *p;
}
/**
* @brief Allocate shared memory inside a specific LDS segment on gfx1250.
*
* Positions the allocator pointer at `base + IDX * 64KB` (where `base`
* is the dynamic-shared-memory pointer this allocator was constructed
* with), then allocates the requested type there. Multiple
* `allocate_in<segment<IDX>>` calls into the same segment pack tightly.
*
* @tparam SEG A `kittens::segment<IDX>` tag.
* @tparam A The type of the object to allocate.
* @tparam dims Optional array dimensions.
*/
template<typename SEG, typename A, size_t... dims>
requires ducks::segment_tag::all<SEG>
__device__ inline variadic_array_t<A, dims...>& allocate_in() {
int* target = base + (SEG::byte_offset / sizeof(int));
// If we've already allocated past the requested segment, keep
// packing where we are; otherwise jump forward to the segment.
if (ptr < target) ptr = target;
using at = variadic_array_t<A, dims...>;
at* p = reinterpret_cast<at*>(ptr);
ptr += sizeof(at) / sizeof(int);
return *p;
}
};
} // namespace kittens

View file

@ -0,0 +1,10 @@
/**
* @file
* @brief Aggregate header for UDNA1 (gfx1250) ThunderKittens core include tree.
*/
#pragma once
#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"

View file

@ -0,0 +1,32 @@
/**
* @file
* @brief An aggregate header of all group (multi-warp) operations defined by ThunderKittens
*/
#pragma once
#include "../../common/common.cuh"
#include "../../types/types.cuh"
#include "../warp/warp.cuh" // several group memory ops rely on underlying warp-scope ops
// A "warpgroup" is a special group of 4 consecutive warps defined by NVIDIA for certain SM_90+ operations.
#define KITTENS_CHECK_WARPGROUP static_assert(N_WARPS==4, "PTX warpgroup (N_WARPS=4) function called from a non-warpgroup group.");
namespace kittens {
/*
This is meant to be used with a `using group_N = kittens::group<NUM_WORKERS>;` at the start of every kernel.
*/
template<int N_WARPS>
struct group {
static constexpr int GROUP_WARPS = N_WARPS; // This alias produces nice parallelism.
static constexpr int GROUP_THREADS = N_WARPS * kittens::WARP_THREADS; // This alias produces nice parallelism.
__device__ static inline int laneid() { return threadIdx.x % GROUP_THREADS; }
__device__ static inline int warpid() { return laneid() / kittens::WARP_THREADS; }
__device__ static inline int groupid() { return threadIdx.x / GROUP_THREADS; }
#include "memory/memory.cuh"
};
using warpgroup = group<4>; // special scope commonly used by SM_90 and later.
}

View file

@ -0,0 +1,7 @@
/**
* @file
* @brief An aggregate header of colaborative group memory movement operations
*/
#include "tile/tile.cuh"
#include "vec/vec.cuh"

View file

@ -0,0 +1,12 @@
/**
* @file
* @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory.
*/
template<int axis, bool assume_aligned, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
kittens::store<axis, assume_aligned, ST, GL, COORD, GROUP_THREADS>(dst, src, idx);
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>> // default case
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
kittens::store<2, false, ST, GL, COORD, GROUP_THREADS>(dst, src, idx);
}

View file

@ -0,0 +1,8 @@
/**
* @file
* @brief An aggregate header of group memory operations on tiles.
*/
#include "global_to_shared.cuh"

View file

@ -0,0 +1,9 @@
/**
* @file
* @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
*/
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> // default case
__device__ static inline void store(const GL &dst, const SV &src, const COORD &idx) {
kittens::store<SV, GL, COORD, GROUP_THREADS>(dst, src, idx);
}

View file

@ -0,0 +1,6 @@
/**
* @file
* @brief An aggregate header of group memory operations on vectors.
*/
#include "global_to_shared.cuh"

View file

@ -0,0 +1,9 @@
/**
* @file
* @brief A collection of all of the operations that ThunderKittens defines.
*/
#pragma once
#include "warp/warp.cuh"
#include "group/group.cuh"

Some files were not shown because too many files have changed in this diff Show more