mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
54 commits
remove_def
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
687ade119e |
||
|
|
0a8e61d0c5 |
||
|
|
dfea9e7994 |
||
|
|
ce87d80911 |
||
|
|
5a2b3b7b06 |
||
|
|
116045cc8e |
||
|
|
7c1d0b6d9a |
||
|
|
c9dc1d63cc |
||
|
|
da98fae9e1 |
||
|
|
15988b5941 |
||
|
|
cbfcf36e44 |
||
|
|
f9c8c697d6 |
||
|
|
0138480910 |
||
|
|
33b635d23a |
||
|
|
625d8bbd0d |
||
|
|
fe9b19b12d |
||
|
|
267af9c601 |
||
|
|
97da54b9d6 |
||
|
|
fd0dc40689 |
||
|
|
2d8b802958 |
||
|
|
ba1d3baae8 |
||
|
|
d80a41d559 |
||
|
|
5164c21b44 |
||
|
|
58ff75272e |
||
|
|
b50da5c205 |
||
|
|
4618d27129 |
||
|
|
9ae0a93d0e |
||
|
|
30830850a9 |
||
|
|
8b07cca9f7 |
||
|
|
b2199c54a3 |
||
|
|
1822eed8d3 |
||
|
|
bba611bb59 |
||
|
|
67c3e589a1 |
||
|
|
649971f02a |
||
|
|
b05bea81ce |
||
|
|
97c2e7a3d9 |
||
|
|
d7b10c69bc |
||
|
|
091ec8d10d |
||
|
|
925c49ce99 |
||
|
|
05249466ed |
||
|
|
4a4b6956df |
||
|
|
eda0a402d1 |
||
|
|
5989d0b150 |
||
|
|
d37248c3ec |
||
|
|
d74f488376 |
||
|
|
d7a1022188 |
||
|
|
924bece1d5 |
||
|
|
b753fb5e4c |
||
|
|
31094a794f |
||
|
|
1720987dc7 |
||
|
|
bed0c343a3 |
||
|
|
e0fe6e542e |
||
|
|
a74b7130b4 |
||
|
|
df015ad541 |
255 changed files with 15025 additions and 1666 deletions
6
.github/actions/setup-tinygrad/action.yml
vendored
6
.github/actions/setup-tinygrad/action.yml
vendored
|
|
@ -80,7 +80,7 @@ runs:
|
|||
- name: Cache Python packages (PR)
|
||||
if: github.event_name == 'pull_request'
|
||||
id: restore-venv-pr
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /tmp/.uv-cache
|
||||
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -96,7 +96,7 @@ runs:
|
|||
|
||||
- name: Cache downloads (PR)
|
||||
if: inputs.key != '' && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
|
||||
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
|
||||
|
|
@ -203,7 +203,7 @@ runs:
|
|||
|
||||
- name: Cache apt (PR)
|
||||
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: /var/cache/apt/archives/
|
||||
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}
|
||||
|
|
|
|||
91
.github/workflows/benchmark.yml
vendored
91
.github/workflows/benchmark.yml
vendored
|
|
@ -99,7 +99,6 @@ jobs:
|
|||
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
|
||||
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
- name: setup staging db
|
||||
if: github.ref == 'refs/heads/update_benchmark_staging'
|
||||
|
|
@ -134,32 +133,10 @@ jobs:
|
|||
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run quantized LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
|
||||
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
|
||||
- name: Run quantized LLaMA3
|
||||
run: |
|
||||
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
|
||||
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
|
||||
#- name: Run LLaMA 7B on 4 (virtual) GPUs
|
||||
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run OLMoE
|
||||
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
|
||||
- name: Run llama3.2
|
||||
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run olmoe
|
||||
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
|
||||
|
||||
|
|
@ -235,9 +212,6 @@ jobs:
|
|||
- name: Symlink models and datasets
|
||||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -279,36 +253,16 @@ jobs:
|
|||
# TODO: too slow
|
||||
# - name: Run SDXL
|
||||
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run quantized LLaMA3
|
||||
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: Speed (NVIDIA)
|
||||
|
|
@ -402,10 +356,7 @@ jobs:
|
|||
run: |
|
||||
mkdir -p weights
|
||||
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
|
||||
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
|
||||
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
|
||||
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
|
||||
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
|
||||
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
|
||||
mkdir -p extra/datasets
|
||||
ln -s /raid/datasets/imagenet extra/datasets/imagenet
|
||||
|
|
@ -458,18 +409,10 @@ jobs:
|
|||
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
|
||||
- name: Run SDXL
|
||||
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
|
||||
- name: Run LLaMA 7B
|
||||
run: |
|
||||
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA 7B with BEAM
|
||||
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 4 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
# - name: Run LLaMA 7B on 6 GPUs
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run LLaMA-3 8B BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
- name: Run llama3.2
|
||||
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
|
||||
- name: Run qwen3.5
|
||||
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
|
||||
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
|
||||
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
|
|
@ -478,16 +421,6 @@ jobs:
|
|||
# run: sudo modprobe amdgpu
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF
|
||||
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
|
|||
122
.github/workflows/test.yml
vendored
122
.github/workflows/test.yml
vendored
|
|
@ -133,46 +133,26 @@ jobs:
|
|||
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
|
||||
- name: Test IMAGE support
|
||||
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
|
||||
- name: Test emulated METAL tensor cores
|
||||
- name: Test emulated tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::METAL'
|
||||
DEBUG: 2
|
||||
N: 64
|
||||
CNT: 1
|
||||
SHOULD_USE_TC: 1
|
||||
run: |
|
||||
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx1100'
|
||||
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
|
||||
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
|
||||
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
|
||||
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
||||
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
|
||||
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
|
||||
- name: Run additional tensor core tests
|
||||
run: |
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD MFMA tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx950'
|
||||
run: |
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated AMD RDNA4 tensor cores
|
||||
env:
|
||||
DEV: 'PYTHON::gfx1201'
|
||||
run: |
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test emulated CUDA tensor cores
|
||||
run: |
|
||||
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
|
||||
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
|
||||
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
||||
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
|
||||
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
|
||||
- name: Test device flop counts
|
||||
run: |
|
||||
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
|
||||
|
||||
linter:
|
||||
|
|
@ -218,14 +198,15 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
amd: 'true'
|
||||
- name: Run NULL backend tests
|
||||
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
|
||||
- name: Run targeted tests on NULL backend
|
||||
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
run: |
|
||||
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
|
||||
# TODO: too slow
|
||||
# - name: Run SDXL on NULL backend
|
||||
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
|
|
@ -249,7 +230,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-13
|
||||
pydeps: "pillow ftfy regex pre-commit"
|
||||
pydeps: "pre-commit"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Run pre-commit test hooks
|
||||
|
|
@ -266,13 +247,6 @@ jobs:
|
|||
run: python3 test/external/external_benchmark_schedule.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
- name: Regen dataset on test_tiny
|
||||
run: |
|
||||
test/external/process_replay/reset.py
|
||||
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
|
||||
python extra/optimization/extract_dataset.py
|
||||
gzip -c /tmp/sops > extra/datasets/sops.gz
|
||||
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
|
||||
- name: Repo line count < 25000 lines
|
||||
run: MAX_LINE_COUNT=25000 python sz.py
|
||||
|
||||
|
|
@ -337,31 +311,6 @@ jobs:
|
|||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testgpumisc:
|
||||
name: CL Misc tests
|
||||
runs-on: *linux
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: gen-dataset
|
||||
deps: testing
|
||||
opencl: 'true'
|
||||
- name: Generate Dataset
|
||||
run: DEV=CL extra/optimization/generate_dataset.sh
|
||||
- name: Run Kernel Count Test
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
|
||||
- name: Run fused optimizer tests
|
||||
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: sops.gz
|
||||
path: /tmp/sops.gz
|
||||
|
||||
testopenpilot:
|
||||
name: openpilot Compile Tests
|
||||
runs-on: *linux
|
||||
|
|
@ -378,7 +327,7 @@ jobs:
|
|||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: Test openpilot CL compile fp32 (test correctness)
|
||||
run: |
|
||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||
|
|
@ -421,7 +370,6 @@ jobs:
|
|||
with:
|
||||
key: optim
|
||||
deps: testing
|
||||
pydeps: "tensorflow==2.19"
|
||||
opencl: 'true'
|
||||
#- name: Test Optimization Helpers
|
||||
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
|
|
@ -430,7 +378,7 @@ jobs:
|
|||
- name: Test Beam Search
|
||||
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
||||
- name: Test MLPerf stuff
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||
- name: DEV=NULL beautiful_mnist_multigpu
|
||||
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
||||
- name: Test Bert training
|
||||
|
|
@ -468,7 +416,7 @@ jobs:
|
|||
# ****** Models Tests ******
|
||||
|
||||
testmodels:
|
||||
name: Models (llvm+cpu+gpu)
|
||||
name: Models
|
||||
runs-on: *linux
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
|
@ -479,34 +427,12 @@ jobs:
|
|||
with:
|
||||
key: models
|
||||
deps: testing
|
||||
opencl: 'true'
|
||||
llvm: 'true'
|
||||
- name: Test models (llvm)
|
||||
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (opencl)
|
||||
run: DEV=CL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (cpu)
|
||||
run: DEV=CPU python -m pytest -n=auto test/models --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testmetalmodels:
|
||||
name: Models (metal)
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v6
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: metal
|
||||
deps: testing
|
||||
- name: Test models (Metal)
|
||||
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test LLaMA compile speed
|
||||
run: DEV=METAL python test/external/external_test_speed_llama.py
|
||||
|
||||
# ****** Feature Tests ******
|
||||
|
||||
testdsp:
|
||||
|
|
@ -715,7 +641,7 @@ jobs:
|
|||
|
||||
unittestmacos:
|
||||
name: MacOS (unit)
|
||||
runs-on: *macos
|
||||
runs-on: &macos macos-26
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
|
|||
Throw in an optimizer, a data loader, and some compute, and you have all you need.
|
||||
|
||||
```python
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad import Tensor, nn, Context
|
||||
|
||||
class LinearNet:
|
||||
def __init__(self):
|
||||
|
|
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
|
|||
|
||||
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for i in range(10):
|
||||
optim.zero_grad()
|
||||
loss = model(x).sparse_categorical_crossentropy(y).backward()
|
||||
|
|
|
|||
|
|
@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
|
|||
Now we have everything we need to start training our neural network.
|
||||
We will be training for 1000 steps with a batch size of 64.
|
||||
|
||||
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
|
||||
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
|
||||
Upon exit, the flag is restored to its previous value by the context manager.
|
||||
|
||||
```python
|
||||
from tinygrad import Context
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for step in range(1000):
|
||||
# random sample a batch
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(64))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Tuple
|
||||
import time
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
from tinygrad import Tensor, TinyJit, nn, Context
|
||||
import gymnasium as gym
|
||||
from tinygrad.helpers import trange
|
||||
import numpy as np # TODO: remove numpy import
|
||||
|
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
|||
|
||||
@TinyJit
|
||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
log_dist, value = model(x)
|
||||
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()
|
||||
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ if __name__ == "__main__":
|
|||
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(idxs:Tensor) -> Tensor:
|
||||
X, Y = X_train[idxs], Y_train[idxs]
|
||||
if len(GPUS) > 1:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ class Model:
|
|||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
|||
|
||||
@TinyJit
|
||||
def train_step() -> Tensor:
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import itertools
|
||||
from typing import Callable
|
||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
|
||||
from tinygrad.helpers import getenv, trange, partition
|
||||
|
||||
class Model:
|
||||
|
|
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
|||
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def microbatch():
|
||||
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
|
||||
for t in params: t.grad = None
|
||||
|
|
|
|||
|
|
@ -359,7 +359,7 @@ def train_cifar():
|
|||
i = 0
|
||||
eval_acc_pct = 0.0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
st = time.monotonic()
|
||||
while i <= STEPS:
|
||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
import os, math, time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
|
||||
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
|
|
@ -177,7 +177,7 @@ if __name__ == "__main__":
|
|||
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def step(x:Tensor, y:Tensor) -> Tensor:
|
||||
_, loss = model(x, y)
|
||||
optimizer.zero_grad()
|
||||
|
|
@ -204,4 +204,3 @@ if __name__ == "__main__":
|
|||
top_k = 40
|
||||
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
||||
print(decode(y[0].tolist()))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# much taken from https://github.com/cloneofsimo/minRF
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
|
||||
from tinygrad.helpers import getenv, trange
|
||||
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
||||
|
||||
|
|
@ -135,7 +135,7 @@ if __name__ == "__main__":
|
|||
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step():
|
||||
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
|
||||
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])
|
||||
|
|
|
|||
|
|
@ -358,7 +358,7 @@ def eval_stable_diffusion():
|
|||
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
|
||||
return batch, unpadded_bs
|
||||
|
||||
@Tensor.train(mode=False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
|
||||
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
|
||||
# Eval is divided into 5 jits, one per model
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
|
|||
from pathlib import Path
|
||||
import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
||||
|
|
@ -614,7 +614,7 @@ def train_retinanet():
|
|||
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
|
||||
with Tensor.train(mode=False):
|
||||
with Context(TRAINING=0):
|
||||
if not RUNMLPERF:
|
||||
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
|
||||
else:
|
||||
|
|
@ -784,7 +784,7 @@ def train_unet3d():
|
|||
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
@Context(TRAINING=1)
|
||||
def train_step(model, x, y):
|
||||
optim.zero_grad()
|
||||
|
||||
|
|
@ -795,7 +795,7 @@ def train_unet3d():
|
|||
optim.step()
|
||||
return loss.realize()
|
||||
|
||||
@Tensor.train(mode=False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_step(model, x, y):
|
||||
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
|
||||
y_hat, y = Tensor(y_hat), Tensor(y)
|
||||
|
|
@ -1490,7 +1490,7 @@ def train_llama3():
|
|||
return lr_cpu, grad_norm_cpu
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train(False)
|
||||
@Context(TRAINING=0)
|
||||
def eval_step(tokens:Tensor):
|
||||
if is_dp: tokens = tokens.to(None).shard(device, 0)
|
||||
if is_mp: tokens = tokens.shard(device)
|
||||
|
|
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
|
|||
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
|
||||
else: bench_log_manager = contextlib.nullcontext()
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
|||
|
||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
|
||||
if not fp8:
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
|
|
@ -47,12 +47,14 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
||||
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
|
||||
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
|
||||
if can_use_asm_gemm(x_q, w.T):
|
||||
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
||||
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
|
||||
mx_w_stored=True).reshape(*l_shape, w.shape[0])
|
||||
else:
|
||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
|
||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
|
||||
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
||||
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
||||
if x_fp8 is None:
|
||||
|
|
@ -126,10 +128,8 @@ class FlatTransformer:
|
|||
|
||||
# FeedForward
|
||||
if SPLIT_W13:
|
||||
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
|
||||
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
|
||||
else:
|
||||
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
|
||||
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
|
||||
|
|
@ -160,7 +160,7 @@ class FlatTransformer:
|
|||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
||||
if w is None:
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||
|
|
@ -216,8 +216,15 @@ class FlatTransformer:
|
|||
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
||||
amaxs.append(new_amax)
|
||||
saves.extend([*s, x_w3])
|
||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||
grad_amax_state=kwargs["grad_amax_xout"])
|
||||
if FUSED_SILU_W13 and MXFP8:
|
||||
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
|
||||
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
|
||||
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
|
||||
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
|
||||
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
|
||||
else:
|
||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||
grad_amax_state=kwargs["grad_amax_xout"])
|
||||
amaxs.append(new_amax)
|
||||
saves.extend([*s, out])
|
||||
else:
|
||||
|
|
@ -247,20 +254,30 @@ class FlatTransformer:
|
|||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
def _shard_fp8(name:str, axis:int, std:float=0.02):
|
||||
w = getattr(self, name)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
|
||||
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
|
||||
w.replace(w_q)
|
||||
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
else:
|
||||
w.shard_(device, axis=axis)
|
||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
sstd = 0.02 / math.sqrt(2 * self.n_layers)
|
||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
|
||||
if SPLIT_W13:
|
||||
_shard_fp8("w1", 1)
|
||||
_shard_fp8("w3", 1)
|
||||
else:
|
||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import trange
|
||||
from tinygrad.helpers import trange, Context
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ if __name__ == "__main__":
|
|||
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# - symbolic removal
|
||||
|
||||
from examples.beautiful_mnist import Model
|
||||
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
|
||||
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
|
||||
from tinygrad.nn.datasets import mnist
|
||||
from tinygrad.helpers import trange
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ if __name__ == "__main__":
|
|||
X_samp, Y_samp = X_train[samples], Y_train[samples]
|
||||
print("*** got samples")
|
||||
|
||||
with Tensor.train():
|
||||
with Context(TRAINING=1):
|
||||
"""
|
||||
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
|
||||
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Tuple, Dict, List, Optional
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.dtype import DType, dtypes, AddrSpace
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
|
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
|
|||
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
|
||||
info = prg.arg
|
||||
functions[info.function_name] = prg.src[3].arg
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
|
||||
statements.append((info.function_name, cargs, info.global_size, info.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
|
||||
|
|
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
|
|||
symbolic_vars = OrderedDict()
|
||||
for i, (_, args, global_size, _) in enumerate(statements):
|
||||
for j, var in enumerate(args):
|
||||
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
|
||||
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
|
||||
if var not in symbolic_vars:
|
||||
symbolic_vars[var] = var.arg[0]
|
||||
symbolic_vars[var] = var.expr
|
||||
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
|
||||
statements[i][1][j] = symbolic_vars[var]
|
||||
|
||||
if global_size:
|
||||
for j, dim in enumerate(global_size):
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
|
||||
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
|
||||
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
|
||||
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
|
||||
global_size[j] = f"_{name.expr}[0] + {val.arg}"
|
||||
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
|
|
|
|||
|
|
@ -458,7 +458,8 @@ def test_matmul():
|
|||
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
|||
|
||||
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(acc.zeros_like()))
|
||||
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
|
||||
|
||||
if use_wmma:
|
||||
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
|||
P_lds = QP_lds[:, :BLOCK_N]
|
||||
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
|
||||
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
|
||||
rw1 = UOp.range(TM, 296, AxisType.LOOP)
|
||||
rw2 = UOp.range(TN, 297, AxisType.LOOP)
|
||||
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
|
||||
|
|
|
|||
|
|
@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
insts = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
|
||||
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
|
||||
|
|
@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
|
||||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
rows, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
*batch, K = x.shape
|
||||
scale_K = K // 32
|
||||
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
||||
x_scaled = x.float() * qscale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
||||
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
|
||||
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
|
||||
|
||||
def mx_pack(e8:Tensor) -> Tensor:
|
||||
rows, scale_K = e8.shape
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ def test_matmul():
|
|||
def asm_kernel(A, B, C):
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(THREADS, "lidx0")]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
|
||||
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.helpers import to_tuple, round_up
|
|||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -143,14 +143,17 @@ def make_getaddr(u, device=None):
|
|||
def make_ins(op, *srcs):
|
||||
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
||||
|
||||
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
|
||||
dt = dtype or val.dtype
|
||||
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
|
||||
|
||||
def make_cmdbuf(lin, devs, tag):
|
||||
blob, patches = b'', []
|
||||
for s in (s for ins in lin.src for s in ins.src):
|
||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
|
||||
|
||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||
|
||||
|
|
@ -211,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
|||
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
||||
|
||||
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
||||
data, info = prg.arg
|
||||
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
|
||||
dtypes.uint64) for i,gi in enumerate(info.globals)] \
|
||||
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||
|
||||
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
|
||||
|
||||
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
|
||||
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
|
||||
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
|
||||
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
|
||||
|
||||
pm_prep_runtime = PatternMatcher([
|
||||
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
||||
|
|
@ -288,19 +287,26 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
new_src.append(call)
|
||||
continue
|
||||
|
||||
q = get_submit(call.src[0]).src[0]
|
||||
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
|
||||
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
|
||||
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
|
||||
|
||||
refs = get_call_arg_uops(call)
|
||||
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0]))))
|
||||
# per-lane deps, tracked per (device, queue). skip self
|
||||
dep_lanes:list[tuple[UOp, int]] = []
|
||||
for lane, d in enumerate(qdevs):
|
||||
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
|
||||
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
|
||||
|
||||
# optims: keep only the max wait per queue, and drop self-queue waits when the queue self-orders
|
||||
deps = {dep.arg:dep for dep in sorted(deps, key=lambda x: x.tag)}
|
||||
if to_tuple(new_q.arg[0])[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
deps.pop(new_q.arg, None)
|
||||
# drop self-queue waits, queue self-orders
|
||||
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
|
||||
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
|
||||
|
||||
new_q = new_q.after(*deps.values()).rtag("deps") if deps else new_q
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
# keep latest dep per lane, group lanes
|
||||
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
|
||||
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
|
||||
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
|
||||
|
||||
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
||||
|
|
@ -314,7 +320,10 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
|
||||
# queue is inc with deps
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
|
||||
submit = submit.replace(src=(submit.src[0].after(*queues).rtag("deps"),))
|
||||
|
||||
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
|
||||
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
|
||||
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
|
||||
|
|
@ -335,12 +344,14 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
|
||||
|
||||
waits = []
|
||||
for dep in deps.src[1:]:
|
||||
devs, queue = dep.arg
|
||||
for lanes, dep in zip(deps.arg, deps.src[1:]):
|
||||
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
|
||||
ctx.add(dep.tag) # mark op to update signal.
|
||||
|
||||
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
|
||||
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
|
||||
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
|
||||
lanes = set(lanes)
|
||||
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
|
||||
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(sig.wait(val + dep.tag))
|
||||
return cur.replace(src=tuple(waits) + cur.src)
|
||||
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
|
@ -371,7 +382,8 @@ def merge_queues(linear:UOp) -> UOp:
|
|||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
|
||||
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
|
||||
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
|
|
@ -410,7 +422,7 @@ pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
|
|||
# 4.4. replace params with per-submit input address loads
|
||||
|
||||
def replace_params(call:UOp) -> UOp|None:
|
||||
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is not None}): return None
|
||||
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
|
||||
|
||||
# fill new info
|
||||
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
|
||||
|
|
@ -504,8 +516,11 @@ def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
|||
|
||||
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
|
||||
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
|
||||
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
|
||||
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
|
||||
devs, b = to_tuple(g.src[1].arg), buf.buffer
|
||||
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
|
||||
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
|
||||
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
|
||||
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
|
||||
|
||||
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
|
||||
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
|
||||
|
|
@ -516,9 +531,9 @@ pm_resolve_patches = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||
|
||||
# index on slice is index
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
|
||||
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
|
||||
# shrink on slice is shrink on base at offset
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
|
||||
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
|
||||
|
||||
# getaddr
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
|
|
@ -526,8 +541,8 @@ pm_resolve_patches = PatternMatcher([
|
|||
|
||||
# folders
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
fold_const_store),
|
||||
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
|
||||
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
|
||||
]) + symbolic_simple
|
||||
|
||||
# *****************
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast
|
||||
from typing import cast, Any, Callable
|
||||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -89,25 +89,25 @@ def memory_barrier(ctx):
|
|||
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
|
||||
acquire_mem(ctx)))
|
||||
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
|
||||
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
|
||||
|
||||
def pm4_barrier(ctx): return memory_barrier(ctx)
|
||||
|
||||
def pm4_store(ctx, dst, val):
|
||||
if val.op is Ops.BINARY: return None
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
def pm4_timestamp(ctx, dst):
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
|
||||
ctx.pm4.int_sel__mec_release_mem__none)
|
||||
|
||||
def pm4_program(ctx, prg):
|
||||
data, info = prg.arg
|
||||
lib_gpu, args = prg.src
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
|
||||
args_addr = make_getaddr(args, ctx.device)
|
||||
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
|
||||
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
|
||||
args_addr = make_getaddr(args, ctx.devs)
|
||||
|
||||
user_regs = []
|
||||
if data.enable_private_segment_sgpr:
|
||||
|
|
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
|
|||
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
|
||||
|
||||
def sdma_copy(ctx, dst, src, copy):
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
|
||||
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
|
||||
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
|
||||
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
|
||||
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
|
||||
|
|
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
|
|||
def sdma_wait(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
|
||||
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
|
||||
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
|
||||
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
|
||||
|
||||
def sdma_store(ctx, dst, val):
|
||||
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
|
||||
return UOp(Ops.LINEAR, dtypes.void, (
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
|
||||
|
||||
def sdma_timestamp(ctx, dst):
|
||||
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
|
||||
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
|
||||
|
||||
pm_sdma_opsel = PatternMatcher([
|
||||
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
|
||||
|
|
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
|
|||
|
||||
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
|
||||
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
|
||||
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
|
||||
|
||||
def encode_queue(q:UOp) -> UOp|None:
|
||||
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
|
||||
devs = to_tuple(q.arg[0])
|
||||
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
|
||||
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
|
||||
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
|
||||
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
|
||||
|
||||
pm_lower = PatternMatcher([
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),
|
||||
|
|
|
|||
104
extra/llama_kernels/fused_silu_mul_quantize_mxfp8/__init__.py
Normal file
104
extra/llama_kernels/fused_silu_mul_quantize_mxfp8/__init__.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
||||
|
||||
BLK = 32
|
||||
PACK = 4
|
||||
LOG2E = 1.4426950408889634
|
||||
|
||||
@functools.cache
|
||||
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
n_super = n_elems // (BLK * PACK)
|
||||
sk4 = scale_K // PACK
|
||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
||||
nwg = n_super // THREADS_PER_WG
|
||||
|
||||
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
e8_out = e8_out.reshape(rows * scale_K)
|
||||
si_out = si_out.reshape(sk4 * rows)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
||||
|
||||
super_idx = wg * THREADS_PER_WG + tid
|
||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
||||
|
||||
w1 = x_w1[idx].cast(dtypes.float)
|
||||
w3 = x_w3[idx].cast(dtypes.float)
|
||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
||||
act = w1 * sig * w3
|
||||
abs_a = (act < 0.0).where(-act, act)
|
||||
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
|
||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
||||
qscale = (127.0 - e8f).exp2()
|
||||
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
e8u8 = e8f.cast(dtypes.uint8)
|
||||
|
||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
||||
row, col4 = super_idx // sk4, super_idx % sk4
|
||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
@functools.cache
|
||||
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
VEC = 8
|
||||
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
|
||||
nwg = n_elems // (THREADS_PER_WG * VEC)
|
||||
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
|
||||
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
lane = UOp.range(VEC, 2, AxisType.UNROLL)
|
||||
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
|
||||
|
||||
e8v = e8[idx // BLK].cast(dtypes.float)
|
||||
qscale = (127.0 - e8v).exp2()
|
||||
ga = grad_aq[idx].cast(dtypes.float) * qscale
|
||||
w1 = x_w1[idx].cast(dtypes.float)
|
||||
w3 = x_w3[idx].cast(dtypes.float)
|
||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
||||
s = w1 * sig
|
||||
sprime = sig * (1.0 + w1 * (1.0 - sig))
|
||||
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
|
||||
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
|
||||
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
|
||||
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
|
||||
device = x_w1.device
|
||||
rows, K = x_w1.shape
|
||||
axis = x_w1.axis if isinstance(device, tuple) else None
|
||||
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
||||
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
||||
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
|
||||
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
|
||||
fxn=_custom_silu_mul_bwd_mxfp8)
|
||||
return (None, None, None, gx1.uop, gx3.uop)
|
||||
|
||||
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
|
||||
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
|
||||
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
|
||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
||||
rows, K = x_w1.shape
|
||||
scale_K = K // BLK
|
||||
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
|
||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
|
||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
|
||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
|
||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
|
||||
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
|
||||
return fp8_out, e8_out, si_out
|
||||
|
|
@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
|
|||
step = THREADS_PER_WG // 2
|
||||
while step:
|
||||
active = tid < step
|
||||
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||
other = lds[(tid + step).valid(active)].load()
|
||||
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
|
||||
step //= 2
|
||||
|
||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||
|
|
|
|||
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
71
extra/llama_kernels/quantize_mxfp8_fused/__init__.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
||||
|
||||
BLK = 32
|
||||
PACK = 4
|
||||
|
||||
@functools.cache
|
||||
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
n_elems = rows * K
|
||||
n_super = n_elems // (BLK * PACK)
|
||||
sk4 = scale_K // PACK
|
||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
||||
nwg = n_super // THREADS_PER_WG
|
||||
|
||||
x = x.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
e8_out = e8_out.reshape(rows * scale_K)
|
||||
si_out = si_out.reshape(sk4 * rows)
|
||||
|
||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
||||
|
||||
super_idx = wg * THREADS_PER_WG + tid
|
||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
||||
|
||||
x_f = x[idx].cast(dtypes.float)
|
||||
abs_x = (x_f < 0.0).where(-x_f, x_f)
|
||||
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
|
||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
||||
qscale = (127.0 - e8f).exp2()
|
||||
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
e8u8 = e8f.cast(dtypes.uint8)
|
||||
|
||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
||||
|
||||
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
|
||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
||||
row, col4 = super_idx // sk4, super_idx % sk4
|
||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
||||
|
||||
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
|
||||
_, e8_out, _, x = kernel.src[1:]
|
||||
device = x.device
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
e8 = Tensor(e8_out, device=device).reshape(rows, scale_K)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, BLK).reshape(rows, K)
|
||||
grad_x = (Tensor(gradient, device=device).float() * qscale).cast(dtypes.bfloat16)
|
||||
return (None, None, None, grad_x.uop)
|
||||
|
||||
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
|
||||
assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
|
||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
||||
rows, K = x.shape
|
||||
scale_K = K // BLK
|
||||
axis = x.uop.axis if isinstance(x.device, tuple) else None
|
||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
|
||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
|
||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
|
||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
|
||||
return fp8_out, e8_out, si_out
|
||||
|
|
@ -26,9 +26,7 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter,
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
|
||||
const uint8_t *__restrict__ a_e8_unused,
|
||||
const uint8_t *__restrict__ b_e8_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
|
@ -40,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
constexpr int k_iters = K / BLOCK_K;
|
||||
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
|
||||
__shared__ ST_A As[2][2];
|
||||
__shared__ ST_B Bs[2][2];
|
||||
__shared__ uint8_t smem_scales[2048];
|
||||
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
|
||||
|
||||
RT_A a;
|
||||
RT_B b0, b1;
|
||||
|
|
@ -81,54 +83,87 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
||||
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
||||
|
||||
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
|
||||
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
|
||||
const int a_row_stride = A.template stride<2>() * sizeof(T);
|
||||
const int b_row_stride = B.template stride<2>() * sizeof(T);
|
||||
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
|
||||
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
|
||||
|
||||
const int wid = warpid() % NUM_WARPS;
|
||||
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
|
||||
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
|
||||
|
||||
int a_row_h0 = warp_m * REG_M;
|
||||
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
||||
int b_row_h0 = warp_n * REG_N;
|
||||
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
||||
|
||||
int tic = 0, toc = 1;
|
||||
int tic_scales = 0, toc_scales = 1;
|
||||
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
|
||||
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
|
||||
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
|
||||
|
||||
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
|
||||
|
||||
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
int sa_idx = block_row, sb_idx = block_col;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
|
||||
if (k + 1 < k_iters) {
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
asm volatile("s_waitcnt lgkmcnt(8)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -137,10 +172,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -150,7 +184,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
|
|
@ -162,19 +196,24 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
|
||||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
if (k + 1 < k_iters) {
|
||||
sa_idx += tiles_M; sb_idx += tiles_N;
|
||||
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
|
||||
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
|
|
@ -217,17 +256,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
tic ^= 1; toc ^= 1;
|
||||
tic_scales ^= 1; toc_scales ^= 1;
|
||||
}
|
||||
|
||||
{ // Epilogue k = k_iters - 1 (final)
|
||||
int k = k_iters - 1;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
|
||||
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
|
|
@ -263,6 +302,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
|
|||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
} else {
|
||||
#pragma unroll 1
|
||||
for (int kk = 0; kk < k_iters; kk++) {
|
||||
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
|
||||
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
|
||||
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
|
||||
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
|
||||
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
|
||||
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
}
|
||||
|
||||
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
||||
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
|
|
|
|||
|
|
@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
|
|||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
|
||||
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
|
||||
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
|
||||
template<typename T = u32x3>
|
||||
__device__ __forceinline__ T buffer_load_dwordx3(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
T result;
|
||||
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
|
|
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
|
|||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ubyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sbyte(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_ushort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_sshort(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
|
||||
// of the destination register, leaving the other half untouched: read-modify-write.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_load_short_d16_hi(
|
||||
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(dst)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T = u32x3>
|
||||
__device__ __forceinline__ void buffer_store_dwordx3(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
|
|
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
|
|||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 8 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the high 16 bits' low 8 bits (bits [23:16]) of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_byte_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the low 16 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Stores the high 16 bits of the value register.
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ void buffer_store_short_d16_hi(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR >= 256) {
|
||||
|
|
@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
|
|||
|
||||
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
|
||||
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
|
||||
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
|
||||
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
|
||||
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
10
extra/thunder/amd/include/cdna4/includes.cuh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
|
|
@ -248,7 +248,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%3\n"
|
||||
"ds_read_b64_tr_b16 %1, %2 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "i"(offset), "i"(offset + 4 * ST::underlying_subtile_row_bytes)
|
||||
: "memory"
|
||||
|
|
@ -259,7 +259,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
|
|||
"ds_read_b64_tr_b16 %0, %2 offset:%4\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:%4\n"
|
||||
// "s_waitcnt lgkmcnt(0)\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
|
||||
: "v"(addr), "v"(next_addr), "i"(offset)
|
||||
: "memory"
|
||||
|
|
@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
|
|||
}
|
||||
}
|
||||
|
||||
template<ducks::rt::col_layout RT, ducks::st::all ST>
|
||||
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
|
||||
static_assert(RT::cols == ST::rows,
|
||||
"col_layout load with col_offset: ST.rows must equal RT::cols");
|
||||
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
|
||||
|
||||
using T2 = RT::dtype;
|
||||
using U = ST::dtype;
|
||||
using U2 = base_types::packing<U>::packed_type;
|
||||
|
||||
const int laneid = kittens::laneid();
|
||||
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
|
||||
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
|
||||
|
||||
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
|
||||
const int block_id = laneid / 16;
|
||||
const int l_within = laneid % 16;
|
||||
const int tr_k_grp = l_within / 2;
|
||||
const int m_half = l_within & 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
|
||||
|
||||
const int subtile_base = block_id;
|
||||
uint32_t addr = src_ptr
|
||||
+ (uint32_t)(subtile_base * subtile_bytes)
|
||||
+ src.swizzle({tr_k_grp, m_col});
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
|
||||
addr ^= 1088u;
|
||||
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
|
||||
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
|
||||
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
|
||||
: "v"(addr), "i"(4 * subtile_bytes)
|
||||
: "memory"
|
||||
);
|
||||
}
|
||||
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
|
||||
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
|
||||
const int col_offset_lane = (laneid % 4) * 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < RT::height; i++) {
|
||||
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
|
||||
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
|
||||
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
|
||||
|
||||
if constexpr (RT::base_tile_stride == 8) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
|
||||
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
|
||||
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
|
||||
: "v"(addr), "v"(next_addr)
|
||||
: "memory"
|
||||
);
|
||||
} else if constexpr (RT::base_tile_stride == 4) {
|
||||
asm volatile(
|
||||
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
|
||||
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
|
||||
: "v"(addr)
|
||||
: "memory"
|
||||
);
|
||||
} else {
|
||||
static_assert(false, "Unsupported stride for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(false, "Unsupported type for col_layout load with col_offset");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kittens
|
||||
|
|
@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
|
||||
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
|
||||
|
||||
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
|
||||
|
|
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale tile data.
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
__device__ __forceinline__ fp8e8m0_4 pack_scales(
|
||||
const fp8e8m0 *smem_scales, int row_offset) {
|
||||
int lid = laneid();
|
||||
int r16 = lid % 16;
|
||||
int k_sub = lid / 16;
|
||||
|
||||
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
|
||||
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
|
||||
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
|
||||
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
|
||||
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
|
||||
|
||||
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
|
||||
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
|
||||
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
|
||||
|
||||
return lo | (hi << 16);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base dot product operation for row layout.
|
||||
*
|
||||
|
|
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
|
||||
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
|
||||
*/
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
|
||||
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
|
|
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
|
|||
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt(D &d,
|
||||
const A &a,
|
||||
const B &b, // notice row and (M, K) instead of col and (K, M)
|
||||
|
|
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
|
||||
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
|
||||
*/
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt_scaled(D &d,
|
||||
const A &a,
|
||||
const B &b,
|
||||
|
|
@ -9,6 +9,5 @@
|
|||
#include "maps.cuh"
|
||||
#include "reductions.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "scales.cuh"
|
||||
|
||||
#include "assembly/tile.cuh"
|
||||
|
|
@ -41,7 +41,7 @@ struct st_16x16 {
|
|||
} else {
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -205,6 +205,23 @@ struct st_8x32 {
|
|||
}
|
||||
};
|
||||
|
||||
struct st_16x64 {
|
||||
static constexpr int rows = 16;
|
||||
static constexpr int cols = 64;
|
||||
|
||||
template<typename _T>
|
||||
static constexpr int bytes_per_thread() {
|
||||
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
|
||||
return 16;
|
||||
}
|
||||
|
||||
template<typename _T>
|
||||
__device__ __forceinline__ static const uint32_t swizzle (int2 coord) {
|
||||
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
|
||||
return sizeof(_T) * (coord.x * cols + coord.y);
|
||||
}
|
||||
};
|
||||
|
||||
struct st_16x128 {
|
||||
static constexpr int rows = 16;
|
||||
static constexpr int cols = 128;
|
||||
|
|
@ -236,12 +253,13 @@ struct st_16x128 {
|
|||
};
|
||||
|
||||
template<typename T>
|
||||
concept all = std::is_same_v<T, st_16x16> ||
|
||||
std::is_same_v<T, st_16x16_swizzled> ||
|
||||
std::is_same_v<T, st_32x32> ||
|
||||
std::is_same_v<T, st_16x32> ||
|
||||
std::is_same_v<T, st_32x16> ||
|
||||
concept all = std::is_same_v<T, st_16x16> ||
|
||||
std::is_same_v<T, st_16x16_swizzled> ||
|
||||
std::is_same_v<T, st_32x32> ||
|
||||
std::is_same_v<T, st_16x32> ||
|
||||
std::is_same_v<T, st_32x16> ||
|
||||
std::is_same_v<T, st_8x32> ||
|
||||
std::is_same_v<T, st_16x64> ||
|
||||
std::is_same_v<T, st_16x128>;
|
||||
|
||||
|
||||
|
|
@ -75,6 +75,7 @@ using st_32x32_s = ducks::st_shape::st_32x32;
|
|||
using st_16x32_s = ducks::st_shape::st_16x32;
|
||||
using st_32x16_s = ducks::st_shape::st_32x16;
|
||||
using st_8x32_s = ducks::st_shape::st_8x32;
|
||||
using st_16x64_s = ducks::st_shape::st_16x64;
|
||||
using st_16x128_s = ducks::st_shape::st_16x128;
|
||||
|
||||
}
|
||||
|
|
@ -5,9 +5,12 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "common/common.cuh"
|
||||
#include "types/types.cuh"
|
||||
#include "ops/ops.cuh"
|
||||
#if defined(KITTENS_CDNA4)
|
||||
#include "cdna4/includes.cuh"
|
||||
#elif defined(KITTENS_UDNA1)
|
||||
#include "udna1/includes.cuh"
|
||||
#endif
|
||||
|
||||
#include "pyutils/util.cuh"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief MXFP8 block scale loading and packing utilities.
|
||||
*
|
||||
* Provides functions for staging E8M0 block scales in LDS and packing them
|
||||
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.cuh"
|
||||
|
||||
namespace kittens {
|
||||
|
||||
/**
|
||||
* @brief Load iteration-major packed E8M0 scales from global memory into LDS.
|
||||
*
|
||||
* First 256 threads each load one uint32 (4 packed E8M0 bytes) for A and B.
|
||||
* A scales are placed at smem[0..1023], B scales at smem[1024..2047].
|
||||
*
|
||||
* @param smem_scales LDS buffer, must be >= 2048 bytes.
|
||||
* @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
|
||||
* @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
|
||||
* @param block_m Starting row offset for A within the current block.
|
||||
* @param block_n Starting row offset for B within the current block.
|
||||
* @param k_iter Current K iteration index.
|
||||
* @param M_dim M dimension of the matrix.
|
||||
* @param N_dim N dimension of the matrix.
|
||||
*/
|
||||
__device__ __forceinline__ void load_scales_to_lds(
|
||||
uint8_t *smem_scales,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter,
|
||||
int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid < 256) {
|
||||
uint32_t sa = scale_A_iter[k_iter * M_dim + block_m + tid];
|
||||
uint32_t sb = scale_B_iter[k_iter * N_dim + block_n + tid];
|
||||
*(uint32_t *)&smem_scales[tid * 4] = sa;
|
||||
*(uint32_t *)&smem_scales[1024 + tid * 4] = sb;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* Each lane (r16 = laneid%16, k_sub = laneid/16) loads 4 dwords from
|
||||
* consecutive 16-row groups, then uses v_perm_b32 to extract byte k_sub
|
||||
* from each, producing the packed scale register for scaled MFMA.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale region.
|
||||
* @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
__device__ __forceinline__ fp8e8m0_4 pack_scales(
|
||||
const uint8_t *smem_scales, int lds_base, int row_offset) {
|
||||
int lid = laneid();
|
||||
int r16 = lid % 16;
|
||||
int k_sub = lid / 16;
|
||||
|
||||
const uint32_t *s32 = (const uint32_t *)(smem_scales + lds_base);
|
||||
uint32_t w0 = s32[row_offset + 0 * 16 + r16];
|
||||
uint32_t w1 = s32[row_offset + 1 * 16 + r16];
|
||||
uint32_t w2 = s32[row_offset + 2 * 16 + r16];
|
||||
uint32_t w3 = s32[row_offset + 3 * 16 + r16];
|
||||
|
||||
uint32_t sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
|
||||
uint32_t lo = __builtin_amdgcn_perm(w0, w1, sel);
|
||||
uint32_t hi = __builtin_amdgcn_perm(w2, w3, sel);
|
||||
|
||||
return (fp8e8m0_4)(lo | (hi << 16));
|
||||
}
|
||||
} // namespace kittens
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue