Compare commits

...

54 commits

Author SHA1 Message Date
chenyu
687ade119e
IMAGE hand_coded_optimizations update (#16720) 2026-06-23 21:55:28 -04:00
George Hotz
0a8e61d0c5
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese

* move that stuff

* copy in allowed length logic

* mulitple buffers

* new coalese is better

* fine

* earlier

* fixes

* work

* work

* valid

* stack on index const
2026-06-23 18:03:48 -07:00
wozeparrot
dfea9e7994
llama: fused silu mul quantize mxfp8 (#16704) 2026-06-23 16:59:50 -07:00
chenyu
ce87d80911
better _drop_valid_stmts [pr] (#16719)
also dropped the unused is_increasing
2026-06-23 19:35:01 -04:00
George Hotz
5a2b3b7b06
early dtype decomp (#16718)
* early dtype decomp

* simplify

* cleanup

* that goes there

* doing too much

* stupid symbolic rules
2026-06-23 16:07:20 -07:00
Christopher Milan
116045cc8e
ci: remove tensorflow from testoptim (#16717) 2026-06-23 18:11:48 -04:00
nimlgen
7c1d0b6d9a
hcq2: use shrink(bitcast) (#16713)
* hcq2: use shrink(bitcast)

* x
2026-06-23 18:11:39 +03:00
George Hotz
c9dc1d63cc
small changes from new codegen (#16712)
* small changes from new codegen

* shrink/flatten
2026-06-22 17:44:15 -07:00
Christopher Milan
da98fae9e1
ci: try parallelizing tc tests (#16710) 2026-06-22 20:43:32 -04:00
chenyu
15988b5941
contiguous to mixin and cleanups [PR] (#16711) 2026-06-22 20:18:18 -04:00
Christopher Milan
cbfcf36e44
ci: remove generate_dataset and CL misc (#16709) 2026-06-22 18:01:07 -04:00
nimlgen
f9c8c697d6
hcq2: drop args after inner deps (#16708) 2026-06-22 23:26:11 +03:00
chenyu
0138480910
dropout and scaled_dot_product_attention to mixin (#16707) 2026-06-22 16:17:45 -04:00
chenyu
33b635d23a
Tensor.train -> TRAINING [PR] (#16705)
* Tensor.train -> TRAINING [PR]

* doc
2026-06-22 15:13:22 -04:00
chenyu
625d8bbd0d
TRAINING ContextVar (#16703) 2026-06-22 13:03:08 -04:00
wozeparrot
fe9b19b12d
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes

* clean: unused

* fix: batch
2026-06-22 10:54:35 -04:00
chenyu
267af9c601
full_like to CreationMixin [PR] (#16702) 2026-06-22 09:33:23 -04:00
chenyu
97da54b9d6
more method to CreationMixin [PR] (#16698) 2026-06-22 00:01:22 -04:00
chenyu
fd0dc40689
clean up CreationMixin and DTypeMixin [PR] (#16697) 2026-06-21 21:13:40 -04:00
chenyu
2d8b802958
contiguous in wino conv (#16696)
also fixed test_counters
2026-06-21 17:11:46 -04:00
chenyu
ba1d3baae8
masked_select and nonzero to mixin [PR] (#16695)
with a .data stub
2026-06-21 15:10:44 -04:00
chenyu
d80a41d559
some rand method to RandMixin [PR] (#16693) 2026-06-21 12:16:51 -04:00
wozeparrot
5164c21b44
gemm: keep shape thru mxfp8 quantize (#16692) 2026-06-20 22:28:53 -07:00
chenyu
58ff75272e
const_like and invalids to mixin [PR] (#16690)
* const_like and invalids to mixin [PR]

* empty_like

* einsum

* type
2026-06-21 00:02:29 -04:00
chenyu
b50da5c205
move Tensor.__getitem__ to mixin [PR] (#16689) 2026-06-20 22:01:45 -04:00
chenyu
4618d27129
final const cleanups [PR] (#16688) 2026-06-20 21:38:16 -04:00
chenyu
9ae0a93d0e
more const cleanups [PR] (#16682) 2026-06-20 20:41:43 -04:00
George Hotz
30830850a9
small changes from new codegen (#16681)
* small changes from new codegen

* revert that
2026-06-19 18:29:01 -07:00
chenyu
8b07cca9f7
invalid clone try 3+ [PR] (#16679) 2026-06-19 20:13:52 -04:00
Christopher Milan
b2199c54a3
ci: update actions/cache/restore to suppress warnings (#16680) 2026-06-19 18:27:52 -04:00
Christopher Milan
1822eed8d3
ci: only test models on cpu (#16678) 2026-06-19 18:16:59 -04:00
wozeparrot
bba611bb59
gemm: fix mxfp8 on more shapes (#16677) 2026-06-19 13:28:53 -07:00
chenyu
67c3e589a1
invalid clone tests and prereq [PR] (#16675) 2026-06-19 13:20:43 -04:00
George Hotz
649971f02a
remove DEFINE_LOCAL and DEFINE_REG (gpt) (#16673)
* remove define_local and define_reg (gpt)

* fix precommit

* cleanups

* regalloc fix

* cleanups 2
2026-06-19 10:07:50 -07:00
George Hotz
b05bea81ce
x86 cleanups (fable) [pr] (#16591)
* x86 cleanups (fable)

* support shrink

* remove ptr dtype

* move that

* is_lane helper

* Revert "is_lane helper"

This reverts commit ea4571254d.
2026-06-19 09:04:51 -07:00
nimlgen
97c2e7a3d9
spec: add getaddr (#16674) 2026-06-19 15:37:33 +03:00
George Hotz
d7b10c69bc
update placeholder to not create DEFINE_LOCAL/DEFINE_REG (#16671)
* update placeholder to not create DEFINE_LOCAL/DEFINE_REG

* simpler

* define_local
2026-06-18 21:21:06 -07:00
Christopher Milan
091ec8d10d
use tinygrad.llm in benchmarks (#16670) 2026-06-19 00:03:57 -04:00
George Hotz
925c49ce99
use placeholder in tests (#16672) 2026-06-18 20:51:44 -07:00
wozeparrot
05249466ed
llama: fused quantize mxfp8 (#16667) 2026-06-18 16:02:28 -07:00
George Hotz
4a4b6956df
remove DEFINE_VAR from codebase (gpt) (#16666)
* remove DEFINE_VAR from codebase

* junk

* remove junk
2026-06-18 15:33:50 -07:00
nimlgen
eda0a402d1
hcq2: fix multi (#16661) 2026-06-18 22:56:49 +03:00
George Hotz
5989d0b150
remove DEFINE_VAR try 2 (#16651)
* remove DEFINE_VAR try 2

* param

* null index

* fix fuzzing

* fixes

* no gather neg params

* param is just Irreducible

* fixes

* skip stack

* need to filter slots there
2026-06-18 12:34:25 -07:00
wozeparrot
d37248c3ec
gemm: fix mxfp8 on odd shapes (#16664) 2026-06-18 12:03:59 -07:00
chenyu
d74f488376
clean up _function.depth properly [PR] (#16663) 2026-06-18 14:10:22 -04:00
chenyu
d7a1022188
minor function.py cleanups [PR] (#16662) 2026-06-18 13:36:48 -04:00
qazal
924bece1d5
remove some old scheduler tests (#16660) 2026-06-18 22:15:00 +09:00
qazal
b753fb5e4c
viz: view source working even if compile failed (#16657)
* failing test

* hard

* ret_dict

* switch to _data for tests too

* update sqtt

* start work

* Ops.LINEAR looks good

* baseline with depth works

* support depth

* types

* @needs_tracked_pm

* update, marg can error too

* unwrap_or goes to many more places

* move things to soft_err

* soft_err everywhere needed

* diff cleanup

* use list

* rewrite it

* change

* update depth number

* small comment change
2026-06-18 17:34:53 +09:00
qazal
31094a794f
viz: data not sent to client side starts with _ (#16659)
* ret_dict

* switch to _data for tests too

* update sqtt

* rename to filter_keys

* not cfg
2026-06-18 15:25:22 +09:00
qazal
1720987dc7
include exception name in Ops.REWRITE_ERROR (#16658) 2026-06-18 14:52:48 +09:00
wozeparrot
bed0c343a3
faster mxfp8 gemm (#16656) 2026-06-17 22:35:36 -07:00
Christopher Milan
e0fe6e542e
ci: fewer pydeps (#16654) 2026-06-17 22:52:14 -04:00
chenyu
a74b7130b4
Revert "invalid clone try 2 [PR] (#16648)" (#16653)
This reverts commit 1bd4551ee1.
2026-06-17 22:05:30 -04:00
chenyu
df015ad541
remove many type ignores [PR] (#16652) 2026-06-17 21:38:45 -04:00
257 changed files with 15027 additions and 1649 deletions

View file

@ -80,7 +80,7 @@ runs:
- name: Cache Python packages (PR) - name: Cache Python packages (PR)
if: github.event_name == 'pull_request' if: github.event_name == 'pull_request'
id: restore-venv-pr id: restore-venv-pr
uses: actions/cache/restore@v4 uses: actions/cache/restore@v5
with: with:
path: /tmp/.uv-cache path: /tmp/.uv-cache
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }} key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
@ -96,7 +96,7 @@ runs:
- name: Cache downloads (PR) - name: Cache downloads (PR)
if: inputs.key != '' && github.event_name == 'pull_request' if: inputs.key != '' && github.event_name == 'pull_request'
uses: actions/cache/restore@v4 uses: actions/cache/restore@v5
with: with:
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }} path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }} key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
@ -203,7 +203,7 @@ runs:
- name: Cache apt (PR) - name: Cache apt (PR)
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request' if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
uses: actions/cache/restore@v4 uses: actions/cache/restore@v5
with: with:
path: /var/cache/apt/archives/ path: /var/cache/apt/archives/
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }} key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}

View file

@ -99,7 +99,6 @@ jobs:
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
- name: setup staging db - name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging' if: github.ref == 'refs/heads/update_benchmark_staging'
@ -134,32 +133,10 @@ jobs:
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
- name: Fuzz Padded Tensor Core GEMM - name: Fuzz Padded Tensor Core GEMM
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA - name: Run llama3.2
run: | run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing - name: Run olmoe
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run quantized LLaMA
run: |
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
- name: Run quantized LLaMA3
run: |
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run OLMoE
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
- name: Train MNIST - name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
@ -235,9 +212,6 @@ jobs:
- name: Symlink models and datasets - name: Symlink models and datasets
run: | run: |
mkdir -p weights mkdir -p weights
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3 ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -279,36 +253,16 @@ jobs:
# TODO: too slow # TODO: too slow
# - name: Run SDXL # - name: Run SDXL
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing # run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA - name: Run llama3.2
run: | run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing - name: Run qwen3.5
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run LLaMA-3 8B on 4 GPUs with BEAM - name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run quantized LLaMA3
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
# - name: Run LLaMA-3 8B on 6 GPUs # - name: Run LLaMA-3 8B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 # run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-2 70B # - name: Run LLaMA-2 70B
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing # run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- uses: actions/upload-artifact@v7 - uses: actions/upload-artifact@v7
with: with:
name: Speed (NVIDIA) name: Speed (NVIDIA)
@ -402,10 +356,7 @@ jobs:
run: | run: |
mkdir -p weights mkdir -p weights
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3 ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -458,18 +409,10 @@ jobs:
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run SDXL - name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA 7B - name: Run llama3.2
run: | run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing - name: Run qwen3.5
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA 7B with BEAM
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run LLaMA-3 8B on 4 GPUs with BEAM - name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-3 8B on 6 GPUs # - name: Run LLaMA-3 8B on 6 GPUs
@ -478,16 +421,6 @@ jobs:
# run: sudo modprobe amdgpu # run: sudo modprobe amdgpu
# - name: Run LLaMA-2 70B # - name: Run LLaMA-2 70B
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing # run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay

View file

@ -133,46 +133,26 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20 run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support - name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores - name: Test emulated tensor cores
env: env:
DEV: 'PYTHON::METAL' DEBUG: 2
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: | run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
python3 -m pytest -nauto test/opt/test_tensor_cores.py ::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
- name: Test emulated AMD tensor cores ::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
env: 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
DEV: 'PYTHON::gfx1100' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: | run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter: linter:
@ -218,14 +198,15 @@ jobs:
uses: ./.github/actions/setup-tinygrad uses: ./.github/actions/setup-tinygrad
with: with:
key: unittest-13 key: unittest-13
pydeps: "pillow ftfy regex pre-commit"
deps: testing_unit deps: testing_unit
llvm: 'true' llvm: 'true'
amd: 'true' amd: 'true'
- name: Run NULL backend tests - name: Run NULL backend tests
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20 run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
- name: Run targeted tests on NULL backend - name: Run targeted tests on NULL backend
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step run: |
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
# TODO: too slow # TODO: too slow
# - name: Run SDXL on NULL backend # - name: Run SDXL on NULL backend
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights # run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
@ -249,7 +230,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad uses: ./.github/actions/setup-tinygrad
with: with:
key: unittest-13 key: unittest-13
pydeps: "pillow ftfy regex pre-commit" pydeps: "pre-commit"
deps: testing_unit deps: testing_unit
llvm: 'true' llvm: 'true'
- name: Run pre-commit test hooks - name: Run pre-commit test hooks
@ -266,13 +247,6 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines - name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py run: MAX_LINE_COUNT=25000 python sz.py
@ -337,31 +311,6 @@ jobs:
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot: testopenpilot:
name: openpilot Compile Tests name: openpilot Compile Tests
runs-on: *linux runs-on: *linux
@ -378,7 +327,7 @@ jobs:
llvm: 'true' llvm: 'true'
- name: Test openpilot model kernel count and gate usage - name: Test openpilot model kernel count and gate usage
run: | run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness) - name: Test openpilot CL compile fp32 (test correctness)
run: | run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -421,7 +370,6 @@ jobs:
with: with:
key: optim key: optim
deps: testing deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true' opencl: 'true'
#- name: Test Optimization Helpers #- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py # run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -430,7 +378,7 @@ jobs:
- name: Test Beam Search - name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff - name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20 run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu - name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training - name: Test Bert training
@ -468,7 +416,7 @@ jobs:
# ****** Models Tests ****** # ****** Models Tests ******
testmodels: testmodels:
name: Models (llvm+cpu+gpu) name: Models
runs-on: *linux runs-on: *linux
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
@ -479,34 +427,12 @@ jobs:
with: with:
key: models key: models
deps: testing deps: testing
opencl: 'true'
llvm: 'true' llvm: 'true'
- name: Test models (llvm)
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
- name: Test models (opencl)
run: DEV=CL python -m pytest -n=auto test/models --durations=20
- name: Test models (cpu) - name: Test models (cpu)
run: DEV=CPU python -m pytest -n=auto test/models --durations=20 run: DEV=CPU python -m pytest -n=auto test/models --durations=20
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testmetalmodels:
name: Models (metal)
runs-on: &macos macos-26
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: metal
deps: testing
- name: Test models (Metal)
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
- name: Test LLaMA compile speed
run: DEV=METAL python test/external/external_test_speed_llama.py
# ****** Feature Tests ****** # ****** Feature Tests ******
testdsp: testdsp:
@ -715,7 +641,7 @@ jobs:
unittestmacos: unittestmacos:
name: MacOS (unit) name: MacOS (unit)
runs-on: *macos runs-on: &macos macos-26
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
- name: Checkout Code - name: Checkout Code

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need. Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python ```python
from tinygrad import Tensor, nn from tinygrad import Tensor, nn, Context
class LinearNet: class LinearNet:
def __init__(self): def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train(): with Context(TRAINING=1):
for i in range(10): for i in range(10):
optim.zero_grad() optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward() loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network. Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64. We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training. We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager. Upon exit, the flag is restored to its previous value by the context manager.
```python ```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist() X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train(): with Context(TRAINING=1):
for step in range(1000): for step in range(1000):
# random sample a batch # random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64)) samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple from typing import Tuple
import time import time
from tinygrad import Tensor, TinyJit, nn from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym import gymnasium as gym
from tinygrad.helpers import trange from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train(): with Context(TRAINING=1):
log_dist, value = model(x) log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float() action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor: def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs] X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1: if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor: def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 # model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit @TinyJit
def train_step() -> Tensor: def train_step() -> Tensor:
with Tensor.train(): with Context(TRAINING=1):
opt.zero_grad() opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0 Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools import itertools
from typing import Callable from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition from tinygrad.helpers import getenv, trange, partition
class Model: class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads) Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def microbatch(): def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0]) samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0 i = 0
eval_acc_pct = 0.0 eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train(): with Context(TRAINING=1):
st = time.monotonic() st = time.monotonic()
while i <= STEPS: while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"): if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os, math, time import os, math, time
import numpy as np import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0) if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor: def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y) _, loss = model(x, y)
optimizer.zero_grad() optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40 top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist())) print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF # much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4) optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(): def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int') if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0]) else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape)) batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs return batch, unpadded_bs
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL, def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]: inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model # Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path from pathlib import Path
import multiprocessing import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset() if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False): with Context(TRAINING=0):
if not RUNMLPERF: if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True)) i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else: else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit @TinyJit
@Tensor.train() @Context(TRAINING=1)
def train_step(model, x, y): def train_step(model, x, y):
optim.zero_grad() optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step() optim.step()
return loss.realize() return loss.realize()
@Tensor.train(mode=False) @Context(TRAINING=0)
def eval_step(model, x, y): def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y) y_hat, y = Tensor(y_hat), Tensor(y)
@ -1490,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu return lr_cpu, grad_norm_cpu
@TinyJit @TinyJit
@Tensor.train(False) @Context(TRAINING=0)
def eval_step(tokens:Tensor): def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0) if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device) if is_mp: tokens = tokens.shard(device)
@ -1803,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN) elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext() else: bench_log_manager = contextlib.nullcontext()
with Tensor.train(): with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","): for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}" nm = f"train_{m}"
if nm in globals(): if nm in globals():

View file

@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None, def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None, x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]: grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
if not fp8: if not fp8:
if ASM_GEMM: if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@ -47,12 +47,14 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)" assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1])) if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T): if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale), out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0]) mx_w_stored=True).reshape(*l_shape, w.shape[0])
else: else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1]) x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None: if x_fp8 is None:
@ -126,10 +128,8 @@ class FlatTransformer:
# FeedForward # FeedForward
if SPLIT_W13: if SPLIT_W13:
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim) self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02) self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
else: else:
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2) self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std) self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
@ -160,7 +160,7 @@ class FlatTransformer:
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None): def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None: if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features) if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize() else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8 from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features)) w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
@ -216,8 +216,15 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"]) x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, x_w3]) saves.extend([*s, x_w3])
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"], if FUSED_SILU_W13 and MXFP8:
grad_amax_state=kwargs["grad_amax_xout"]) from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, out]) saves.extend([*s, out])
else: else:
@ -247,20 +254,30 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None) for v in get_parameters(self): v.shard_(device, axis=None)
else: else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int): def _shard_fp8(name:str, axis:int, std:float=0.02):
getattr(self, name).shard_(device, axis=axis) w = getattr(self, name)
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None if MXFP8:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) from extra.gemm.cdna_asm_gemm import quantize_mxfp8
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) w_q, w_e8, _ = quantize_mxfp8(w_bf16)
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in _shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
if SPLIT_W13: if SPLIT_W13:
_shard_fp8("w1", 1) _shard_fp8("w1", 1)
_shard_fp8("w3", 1) _shard_fp8("w3", 1)
else: else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in _shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize() self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize()

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import trange from tinygrad.helpers import trange, Context
from tinygrad.nn import optim from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium! optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5) optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop # training loop
with Tensor.train(): with Context(TRAINING=1):
for epoch in (t := trange(epochs)): for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0 loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps): for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal # - symbolic removal
from examples.beautiful_mnist import Model from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples] X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples") print("*** got samples")
with Tensor.train(): with Context(TRAINING=1):
""" """
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i) losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -1,5 +1,5 @@
from typing import Tuple, Dict, List, Optional from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType, dtypes from tinygrad.dtype import DType, dtypes, AddrSpace
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer) prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
info = prg.arg info = prg.arg
functions[info.function_name] = prg.src[3].arg functions[info.function_name] = prg.src[3].arg
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR] cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
statements.append((info.function_name, cargs, info.global_size, info.local_size)) statements.append((info.function_name, cargs, info.global_size, info.local_size))
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
symbolic_vars = OrderedDict() symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements): for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args): for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str): if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
if var not in symbolic_vars: if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0] symbolic_vars[var] = var.expr
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var]) bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var] statements[i][1][j] = symbolic_vars[var]
if global_size: if global_size:
for j, dim in enumerate(global_size): for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}: if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src) name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}" global_size[j] = f"_{name.expr}[0] + {val.arg}"
prg = "" prg = ""
if target == "clang": if target == "clang":

View file

@ -458,7 +458,8 @@ def test_matmul():
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp: def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)] gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)] lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds') lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"), sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3))) estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float) # accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG) acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like())) acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma: if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE) k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
P_lds = QP_lds[:, :BLOCK_N] P_lds = QP_lds[:, :BLOCK_N]
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N) P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN) P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization # TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
rw1 = UOp.range(TM, 296, AxisType.LOOP) rw1 = UOp.range(TM, 296, AxisType.LOOP)
rw2 = UOp.range(TN, 297, AxisType.LOOP) rw2 = UOp.range(TN, 297, AxisType.LOOP)
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2) P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)

View file

@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
lidx = UOp.special(WORKGROUP_SIZE, "lidx0") lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0") gidx = UOp.special(NUM_WG, "gidx0")
insts = build_kernel(batch, M, N, K, A.dtype.base) insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds') lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx, sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2))) arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]: def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis # 1x32 block scaling along the last axis
rows, K = x.shape *batch, K = x.shape
scale_K, k_iters = K // 32, K // 128 scale_K = K // 32
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1) amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8) e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K) qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale x_scaled = x.float() * qscale
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8) return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
def mx_pack(e8:Tensor) -> Tensor: def mx_pack(e8:Tensor) -> Tensor:
rows, scale_K = e8.shape rows, scale_K = e8.shape

View file

@ -219,7 +219,8 @@ def test_matmul():
def asm_kernel(A, B, C): def asm_kernel(A, B, C):
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)] gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(THREADS, "lidx0")] lidxs = [UOp.special(THREADS, "lidx0")]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds') lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3))) arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -7,7 +7,7 @@ from tinygrad.helpers import to_tuple, round_up
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
from tinygrad.uop.symbolic import symbolic_simple, symbolic from tinygrad.uop.symbolic import symbolic_simple, symbolic
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.runtime.support.hcq import MMIOInterface from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
@ -143,14 +143,17 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs): def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op) return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag): def make_cmdbuf(lin, devs, tag):
blob, patches = b'', [] blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src): for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s)) if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0) blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag) buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches] return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops)) def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -211,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call)) return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp: def prep_kernargs(call:UOp, prg:UOp) -> UOp:
data, info = prg.arg (data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))), buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
dtypes.uint64) for i,gi in enumerate(info.globals)] \ patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)] + [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([ pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering # bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -288,19 +287,26 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src.append(call) new_src.append(call)
continue continue
q = get_submit(call.src[0]).src[0] new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid)) qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
refs = get_call_arg_uops(call) # per-lane deps, tracked per (device, queue). skip self
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0])))) dep_lanes:list[tuple[UOp, int]] = []
for lane, d in enumerate(qdevs):
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
# optims: keep only the max wait per queue, and drop self-queue waits when the queue self-orders # drop self-queue waits, queue self-orders
deps = {dep.arg:dep for dep in sorted(deps, key=lambda x: x.tag)} if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
if to_tuple(new_q.arg[0])[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"): dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
deps.pop(new_q.arg, None)
new_q = new_q.after(*deps.values()).rtag("deps") if deps else new_q # keep latest dep per lane, group lanes
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:]))) latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
return linear.replace(src=tuple(new_src)) return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)]) pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -314,7 +320,10 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
# queue is inc with deps # queue is inc with deps
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0") submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
submit = submit.replace(src=(submit.src[0].after(*queues).rtag("deps"),))
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])] upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd] patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
@ -335,12 +344,14 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
cur_devs = to_tuple((cur:=deps.src[0]).arg[0]) cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
waits = [] waits = []
for dep in deps.src[1:]: for lanes, dep in zip(deps.arg, deps.src[1:]):
devs, queue = dep.arg dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
ctx.add(dep.tag) # mark op to update signal. ctx.add(dep.tag) # mark op to update signal.
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs]) # for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0)) lanes = set(lanes)
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
waits.append(sig.wait(val + dep.tag)) waits.append(sig.wait(val + dep.tag))
return cur.replace(src=tuple(waits) + cur.src) return cur.replace(src=tuple(waits) + cur.src)
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)]) pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
@ -371,7 +382,8 @@ def merge_queues(linear:UOp) -> UOp:
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
for call in linear.src: for call in linear.src:
if call.tag != "hcq": # finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call] new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
continue continue
@ -410,7 +422,7 @@ pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
# 4.4. replace params with per-submit input address loads # 4.4. replace params with per-submit input address loads
def replace_params(call:UOp) -> UOp|None: def replace_params(call:UOp) -> UOp|None:
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is not None}): return None if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
# fill new info # fill new info
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call))) hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
@ -504,8 +516,11 @@ def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
def resolve_getaddr(buf:UOp, g:UOp) -> UOp: def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr) devs, b = to_tuple(g.src[1].arg), buf.buffer
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs)) bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp: def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
@ -516,9 +531,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack), (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack), (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# index on slice is index # shrink on slice is shrink on base at offset
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True), (UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))), lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
# getaddr # getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x) (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
@ -526,8 +541,8 @@ pm_resolve_patches = PatternMatcher([
# folders # folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store), (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), (UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
fold_const_store), .store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
]) + symbolic_simple ]) + symbolic_simple
# ***************** # *****************

View file

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import cast, Any, Callable
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32' assert sys.platform != 'win32'
from dataclasses import dataclass from dataclasses import dataclass
@ -89,25 +89,25 @@ def memory_barrier(ctx):
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff), reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
acquire_mem(ctx))) acquire_mem(ctx)))
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device)) def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
def pm4_barrier(ctx): return memory_barrier(ctx) def pm4_barrier(ctx): return memory_barrier(ctx)
def pm4_store(ctx, dst, val): def pm4_store(ctx, dst, val):
if val.op is Ops.BINARY: return None if val.op is Ops.BINARY: return None
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low, return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True) ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
def pm4_timestamp(ctx, dst): def pm4_timestamp(ctx, dst):
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter, return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
ctx.pm4.int_sel__mec_release_mem__none) ctx.pm4.int_sel__mec_release_mem__none)
def pm4_program(ctx, prg): def pm4_program(ctx, prg):
data, info = prg.arg data, info = prg.arg
lib_gpu, args = prg.src lib_gpu, args = prg.src
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device) scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
args_addr = make_getaddr(args, ctx.device) args_addr = make_getaddr(args, ctx.devs)
user_regs = [] user_regs = []
if data.enable_private_segment_sgpr: if data.enable_private_segment_sgpr:
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702 class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
def sdma_copy(ctx, dst, src, copy): def sdma_copy(ctx, dst, src, copy):
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device) src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY, return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR), ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0, ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
def sdma_wait(ctx, dst, val): def sdma_wait(ctx, dst, val):
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \ op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1) | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff, return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)) ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
def sdma_store(ctx, dst, val): def sdma_store(ctx, dst, val):
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0) op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
return UOp(Ops.LINEAR, dtypes.void, ( return UOp(Ops.LINEAR, dtypes.void, (
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0))) make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
def sdma_timestamp(ctx, dst): def sdma_timestamp(ctx, dst):
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL) op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device))) return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
pm_sdma_opsel = PatternMatcher([ pm_sdma_opsel = PatternMatcher([
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())), (UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {}) def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
@dataclass(frozen=True)
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
def encode_queue(q:UOp) -> UOp|None: def encode_queue(q:UOp) -> UOp|None:
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
devs = to_tuple(q.arg[0]) d = Device[(devs:=to_tuple(q.arg[0]))[0]]
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit) opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel")) return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
pm_lower = PatternMatcher([ pm_lower = PatternMatcher([
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue), (UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),

View file

@ -0,0 +1,104 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2 step = THREADS_PER_WG // 2
while step: while step:
active = tid < step active = tid < step
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active) other = lds[(tid + step).valid(active)].load()
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier()) lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2 step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0]) amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -0,0 +1,71 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.helpers import prod
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
@functools.cache
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
rows, K = x.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x = x.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
x_f = x[idx].cast(dtypes.float)
abs_x = (x_f < 0.0).where(-x_f, x_f)
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x = kernel.src[1:]
device = x.device
rows, K = x.shape
scale_K = K // BLK
e8 = Tensor(e8_out, device=device).reshape(rows, scale_K)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, BLK).reshape(rows, K)
grad_x = (Tensor(gradient, device=device).float() * qscale).cast(dtypes.bfloat16)
return (None, None, None, grad_x.uop)
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x.shape
scale_K = K // BLK
axis = x.uop.axis if isinstance(x.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
return fp8_out, e8_out, si_out

View file

@ -26,9 +26,7 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
using G = kittens::group<NUM_WARPS>; using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter,
const uint8_t *__restrict__ a_e8_unused, const uint8_t *__restrict__ a_e8_unused,
const uint8_t *__restrict__ b_e8_unused) { const uint8_t *__restrict__ b_e8_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K; constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
@ -40,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
constexpr int k_iters = K / BLOCK_K; constexpr int k_iters = K / BLOCK_K;
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>; kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>; kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>; using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>; using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
__shared__ ST_A As[2][2]; __shared__ ST_A As[2][2];
__shared__ ST_B Bs[2][2]; __shared__ ST_B Bs[2][2];
__shared__ uint8_t smem_scales[2048]; __shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
RT_A a; RT_A a;
RT_B b0, b1; RT_B b0, b1;
@ -81,54 +83,87 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
G::prefill_swizzled_offsets(As[0][0], A, sw_A); G::prefill_swizzled_offsets(As[0][0], A, sw_A);
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B); G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
const int a_row_stride = A.template stride<2>() * sizeof(T);
const int b_row_stride = B.template stride<2>() * sizeof(T);
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
const int wid = warpid() % NUM_WARPS;
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
int a_row_h0 = warp_m * REG_M; int a_row_h0 = warp_m * REG_M;
int a_row_h1 = HALF_ROW + warp_m * REG_M; int a_row_h1 = HALF_ROW + warp_m * REG_M;
int b_row_h0 = warp_n * REG_N; int b_row_h0 = warp_n * REG_N;
int b_row_h1 = HALF_COL + warp_n * REG_N; int b_row_h1 = HALF_COL + warp_n * REG_N;
int tic = 0, toc = 1; int tic = 0, toc = 1;
int tic_scales = 0, toc_scales = 1;
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B); uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A); uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A); if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
if (warp_m == 1) __builtin_amdgcn_s_barrier(); if (warp_m == 1) __builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt vmcnt(4)"); asm volatile("s_waitcnt vmcnt(4)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A); G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B); G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B); G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
asm volatile("s_waitcnt vmcnt(6)"); asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
#pragma unroll 2 G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) { G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N); asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
int sa_idx = block_row, sb_idx = block_col;
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0}); auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0); load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0}); auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0); load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A); G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
asm volatile("s_waitcnt lgkmcnt(8)"); asm volatile("s_waitcnt lgkmcnt(8)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0); fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
__builtin_amdgcn_s_setprio(2); __builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0); mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0}); auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1); load(b1, bs1);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A); G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
@ -137,10 +172,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0}); auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1); load(a, as1);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B); G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
@ -150,7 +184,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B); G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
asm volatile("s_waitcnt vmcnt(6)"); asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
@ -162,19 +196,24 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2 { // Epilogue k = k_iters - 2
int k = k_iters - 2; int k = k_iters - 2;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N); if (k + 1 < k_iters) {
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0}); auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0); load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0}); auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0); load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A); G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)");
@ -217,17 +256,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
tic ^= 1; toc ^= 1; tic ^= 1; toc ^= 1;
tic_scales ^= 1; toc_scales ^= 1;
} }
{ // Epilogue k = k_iters - 1 (final) { // Epilogue k = k_iters - 1 (final)
int k = k_iters - 1; asm volatile("s_waitcnt vmcnt(0)");
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N); asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0}); auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0); load(a, as0);
@ -263,6 +302,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
} }
} else {
#pragma unroll 1
for (int kk = 0; kk < k_iters; kk++) {
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_barrier();
}
}
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n}); store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n}); store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n}); store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});

View file

@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
return result; return result;
} }
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) {
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
} else {
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
template<typename T = u32x3>
__device__ __forceinline__ T buffer_load_dwordx3(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
T result;
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<int GPR_START> template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) { __device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) { if constexpr (GPR_START >= 256) {
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
return result; return result;
} }
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ubyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sbyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ushort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sshort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
// of the destination register, leaving the other half untouched: read-modify-write.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR> template<int GPR>
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) { __device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS // AGPRS
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
: "memory"); : "memory");
} }
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<typename T = u32x3>
__device__ __forceinline__ void buffer_store_dwordx3(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR_START> template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) { __device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS // AGPRS
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
: "memory"); : "memory");
} }
// Stores the low 8 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits' low 8 bits (bits [23:16]) of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the low 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR> template<int GPR>
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) { __device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR >= 256) { if constexpr (GPR >= 256) {

View file

@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
using i32x2 = int32_t __attribute__((ext_vector_type(2))); using i32x2 = int32_t __attribute__((ext_vector_type(2)));
using u32x2 = uint32_t __attribute__((ext_vector_type(2))); using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
using i32x4 = int32_t __attribute__((ext_vector_type(4))); using i32x4 = int32_t __attribute__((ext_vector_type(4)));
using u32x4 = uint32_t __attribute__((ext_vector_type(4))); using u32x4 = uint32_t __attribute__((ext_vector_type(4)));

View file

@ -0,0 +1,10 @@
/**
* @file
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
*/
#pragma once
#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"

View file

@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
} }
} }
template<ducks::rt::col_layout RT, ducks::st::all ST>
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
static_assert(RT::cols == ST::rows,
"col_layout load with col_offset: ST.rows must equal RT::cols");
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
using T2 = RT::dtype;
using U = ST::dtype;
using U2 = base_types::packing<U>::packed_type;
const int laneid = kittens::laneid();
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
const int block_id = laneid / 16;
const int l_within = laneid % 16;
const int tr_k_grp = l_within / 2;
const int m_half = l_within & 1;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
const int subtile_base = block_id;
uint32_t addr = src_ptr
+ (uint32_t)(subtile_base * subtile_bytes)
+ src.swizzle({tr_k_grp, m_col});
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
addr ^= 1088u;
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
}
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
const int col_offset_lane = (laneid % 4) * 4;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
if constexpr (RT::base_tile_stride == 8) {
asm volatile(
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
: "v"(addr), "v"(next_addr)
: "memory"
);
} else if constexpr (RT::base_tile_stride == 4) {
asm volatile(
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
: "v"(addr)
: "memory"
);
} else {
static_assert(false, "Unsupported stride for col_layout load with col_offset");
}
}
} else {
static_assert(false, "Unsupported type for col_layout load with col_offset");
}
}
} // namespace kittens } // namespace kittens

View file

@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode. * @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix. * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
*/ */
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16> template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d, __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a, const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c) { const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape"); static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
} }
} }
/**
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
*
* @param smem_scales LDS pointer to scale tile data.
* @param row_offset Starting row within the scale region (warp's tile offset).
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
*/
__device__ __forceinline__ fp8e8m0_4 pack_scales(
const fp8e8m0 *smem_scales, int row_offset) {
int lid = laneid();
int r16 = lid % 16;
int k_sub = lid / 16;
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
return lo | (hi << 16);
}
/** /**
* @brief Base dot product operation for row layout. * @brief Base dot product operation for row layout.
* *
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix. * @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
* @param[in] c The input rt_base<float, col_layout> accumulator matrix. * @param[in] c The input rt_base<float, col_layout> accumulator matrix.
*/ */
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T> template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d, __device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a, const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c, const rt_base<float, ducks::rt_layout::col, C_shape> &c,
const fp8e8m0_4 *scale_a, const fp8e8m0_4 *scale_a,
const fp8e8m0_4 *scale_b) { const fp8e8m0_4 *scale_b) {
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode. * @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix. * @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
*/ */
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C> template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt(D &d, __device__ static inline void mma_ABt(D &d,
const A &a, const A &a,
const B &b, // notice row and (M, K) instead of col and (K, M) const B &b, // notice row and (M, K) instead of col and (K, M)
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix. * @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix. * @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
*/ */
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C> template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt_scaled(D &d, __device__ static inline void mma_ABt_scaled(D &d,
const A &a, const A &a,
const B &b, const B &b,

View file

@ -9,6 +9,5 @@
#include "maps.cuh" #include "maps.cuh"
#include "reductions.cuh" #include "reductions.cuh"
#include "mma.cuh" #include "mma.cuh"
#include "scales.cuh"
#include "assembly/tile.cuh" #include "assembly/tile.cuh"

View file

@ -205,6 +205,23 @@ struct st_8x32 {
} }
}; };
struct st_16x64 {
static constexpr int rows = 16;
static constexpr int cols = 64;
template<typename _T>
static constexpr int bytes_per_thread() {
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
return 16;
}
template<typename _T>
__device__ __forceinline__ static const uint32_t swizzle (int2 coord) {
static_assert(sizeof(_T) == 1 || sizeof(_T) == 2 || sizeof(_T) == 4, "Unsupported type");
return sizeof(_T) * (coord.x * cols + coord.y);
}
};
struct st_16x128 { struct st_16x128 {
static constexpr int rows = 16; static constexpr int rows = 16;
static constexpr int cols = 128; static constexpr int cols = 128;
@ -242,6 +259,7 @@ concept all = std::is_same_v<T, st_16x16> ||
std::is_same_v<T, st_16x32> || std::is_same_v<T, st_16x32> ||
std::is_same_v<T, st_32x16> || std::is_same_v<T, st_32x16> ||
std::is_same_v<T, st_8x32> || std::is_same_v<T, st_8x32> ||
std::is_same_v<T, st_16x64> ||
std::is_same_v<T, st_16x128>; std::is_same_v<T, st_16x128>;

View file

@ -75,6 +75,7 @@ using st_32x32_s = ducks::st_shape::st_32x32;
using st_16x32_s = ducks::st_shape::st_16x32; using st_16x32_s = ducks::st_shape::st_16x32;
using st_32x16_s = ducks::st_shape::st_32x16; using st_32x16_s = ducks::st_shape::st_32x16;
using st_8x32_s = ducks::st_shape::st_8x32; using st_8x32_s = ducks::st_shape::st_8x32;
using st_16x64_s = ducks::st_shape::st_16x64;
using st_16x128_s = ducks::st_shape::st_16x128; using st_16x128_s = ducks::st_shape::st_16x128;
} }

View file

@ -5,9 +5,12 @@
#pragma once #pragma once
#include "common/common.cuh" #if defined(KITTENS_CDNA4)
#include "types/types.cuh" #include "cdna4/includes.cuh"
#include "ops/ops.cuh" #elif defined(KITTENS_UDNA1)
#include "udna1/includes.cuh"
#endif
#include "pyutils/util.cuh" #include "pyutils/util.cuh"

View file

@ -1,74 +0,0 @@
/**
* @file
* @brief MXFP8 block scale loading and packing utilities.
*
* Provides functions for staging E8M0 block scales in LDS and packing them
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
*/
#pragma once
#include "../../../../common/common.cuh"
namespace kittens {
/**
* @brief Load iteration-major packed E8M0 scales from global memory into LDS.
*
* First 256 threads each load one uint32 (4 packed E8M0 bytes) for A and B.
* A scales are placed at smem[0..1023], B scales at smem[1024..2047].
*
* @param smem_scales LDS buffer, must be >= 2048 bytes.
* @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
* @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
* @param block_m Starting row offset for A within the current block.
* @param block_n Starting row offset for B within the current block.
* @param k_iter Current K iteration index.
* @param M_dim M dimension of the matrix.
* @param N_dim N dimension of the matrix.
*/
__device__ __forceinline__ void load_scales_to_lds(
uint8_t *smem_scales,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter,
int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
int tid = threadIdx.x;
if (tid < 256) {
uint32_t sa = scale_A_iter[k_iter * M_dim + block_m + tid];
uint32_t sb = scale_B_iter[k_iter * N_dim + block_n + tid];
*(uint32_t *)&smem_scales[tid * 4] = sa;
*(uint32_t *)&smem_scales[1024 + tid * 4] = sb;
}
}
/**
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
*
* Each lane (r16 = laneid%16, k_sub = laneid/16) loads 4 dwords from
* consecutive 16-row groups, then uses v_perm_b32 to extract byte k_sub
* from each, producing the packed scale register for scaled MFMA.
*
* @param smem_scales LDS pointer to scale region.
* @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
* @param row_offset Starting row within the scale region (warp's tile offset).
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
*/
__device__ __forceinline__ fp8e8m0_4 pack_scales(
const uint8_t *smem_scales, int lds_base, int row_offset) {
int lid = laneid();
int r16 = lid % 16;
int k_sub = lid / 16;
const uint32_t *s32 = (const uint32_t *)(smem_scales + lds_base);
uint32_t w0 = s32[row_offset + 0 * 16 + r16];
uint32_t w1 = s32[row_offset + 1 * 16 + r16];
uint32_t w2 = s32[row_offset + 2 * 16 + r16];
uint32_t w3 = s32[row_offset + 3 * 16 + r16];
uint32_t sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
uint32_t lo = __builtin_amdgcn_perm(w0, w1, sel);
uint32_t hi = __builtin_amdgcn_perm(w2, w3, sel);
return (fp8e8m0_4)(lo | (hi << 16));
}
} // namespace kittens

Some files were not shown because too many files have changed in this diff Show more