Compare commits

...

110 commits

Author SHA1 Message Date
chenyu
687ade119e
IMAGE hand_coded_optimizations update (#16720) 2026-06-23 21:55:28 -04:00
George Hotz
0a8e61d0c5
switch to the new memory coaleser [pr] (#16716)
* switch to the new memory coalese

* move that stuff

* copy in allowed length logic

* mulitple buffers

* new coalese is better

* fine

* earlier

* fixes

* work

* work

* valid

* stack on index const
2026-06-23 18:03:48 -07:00
wozeparrot
dfea9e7994
llama: fused silu mul quantize mxfp8 (#16704) 2026-06-23 16:59:50 -07:00
chenyu
ce87d80911
better _drop_valid_stmts [pr] (#16719)
also dropped the unused is_increasing
2026-06-23 19:35:01 -04:00
George Hotz
5a2b3b7b06
early dtype decomp (#16718)
* early dtype decomp

* simplify

* cleanup

* that goes there

* doing too much

* stupid symbolic rules
2026-06-23 16:07:20 -07:00
Christopher Milan
116045cc8e
ci: remove tensorflow from testoptim (#16717) 2026-06-23 18:11:48 -04:00
nimlgen
7c1d0b6d9a
hcq2: use shrink(bitcast) (#16713)
* hcq2: use shrink(bitcast)

* x
2026-06-23 18:11:39 +03:00
George Hotz
c9dc1d63cc
small changes from new codegen (#16712)
* small changes from new codegen

* shrink/flatten
2026-06-22 17:44:15 -07:00
Christopher Milan
da98fae9e1
ci: try parallelizing tc tests (#16710) 2026-06-22 20:43:32 -04:00
chenyu
15988b5941
contiguous to mixin and cleanups [PR] (#16711) 2026-06-22 20:18:18 -04:00
Christopher Milan
cbfcf36e44
ci: remove generate_dataset and CL misc (#16709) 2026-06-22 18:01:07 -04:00
nimlgen
f9c8c697d6
hcq2: drop args after inner deps (#16708) 2026-06-22 23:26:11 +03:00
chenyu
0138480910
dropout and scaled_dot_product_attention to mixin (#16707) 2026-06-22 16:17:45 -04:00
chenyu
33b635d23a
Tensor.train -> TRAINING [PR] (#16705)
* Tensor.train -> TRAINING [PR]

* doc
2026-06-22 15:13:22 -04:00
chenyu
625d8bbd0d
TRAINING ContextVar (#16703) 2026-06-22 13:03:08 -04:00
wozeparrot
fe9b19b12d
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes

* clean: unused

* fix: batch
2026-06-22 10:54:35 -04:00
chenyu
267af9c601
full_like to CreationMixin [PR] (#16702) 2026-06-22 09:33:23 -04:00
chenyu
97da54b9d6
more method to CreationMixin [PR] (#16698) 2026-06-22 00:01:22 -04:00
chenyu
fd0dc40689
clean up CreationMixin and DTypeMixin [PR] (#16697) 2026-06-21 21:13:40 -04:00
chenyu
2d8b802958
contiguous in wino conv (#16696)
also fixed test_counters
2026-06-21 17:11:46 -04:00
chenyu
ba1d3baae8
masked_select and nonzero to mixin [PR] (#16695)
with a .data stub
2026-06-21 15:10:44 -04:00
chenyu
d80a41d559
some rand method to RandMixin [PR] (#16693) 2026-06-21 12:16:51 -04:00
wozeparrot
5164c21b44
gemm: keep shape thru mxfp8 quantize (#16692) 2026-06-20 22:28:53 -07:00
chenyu
58ff75272e
const_like and invalids to mixin [PR] (#16690)
* const_like and invalids to mixin [PR]

* empty_like

* einsum

* type
2026-06-21 00:02:29 -04:00
chenyu
b50da5c205
move Tensor.__getitem__ to mixin [PR] (#16689) 2026-06-20 22:01:45 -04:00
chenyu
4618d27129
final const cleanups [PR] (#16688) 2026-06-20 21:38:16 -04:00
chenyu
9ae0a93d0e
more const cleanups [PR] (#16682) 2026-06-20 20:41:43 -04:00
George Hotz
30830850a9
small changes from new codegen (#16681)
* small changes from new codegen

* revert that
2026-06-19 18:29:01 -07:00
chenyu
8b07cca9f7
invalid clone try 3+ [PR] (#16679) 2026-06-19 20:13:52 -04:00
Christopher Milan
b2199c54a3
ci: update actions/cache/restore to suppress warnings (#16680) 2026-06-19 18:27:52 -04:00
Christopher Milan
1822eed8d3
ci: only test models on cpu (#16678) 2026-06-19 18:16:59 -04:00
wozeparrot
bba611bb59
gemm: fix mxfp8 on more shapes (#16677) 2026-06-19 13:28:53 -07:00
chenyu
67c3e589a1
invalid clone tests and prereq [PR] (#16675) 2026-06-19 13:20:43 -04:00
George Hotz
649971f02a
remove DEFINE_LOCAL and DEFINE_REG (gpt) (#16673)
* remove define_local and define_reg (gpt)

* fix precommit

* cleanups

* regalloc fix

* cleanups 2
2026-06-19 10:07:50 -07:00
George Hotz
b05bea81ce
x86 cleanups (fable) [pr] (#16591)
* x86 cleanups (fable)

* support shrink

* remove ptr dtype

* move that

* is_lane helper

* Revert "is_lane helper"

This reverts commit ea4571254d.
2026-06-19 09:04:51 -07:00
nimlgen
97c2e7a3d9
spec: add getaddr (#16674) 2026-06-19 15:37:33 +03:00
George Hotz
d7b10c69bc
update placeholder to not create DEFINE_LOCAL/DEFINE_REG (#16671)
* update placeholder to not create DEFINE_LOCAL/DEFINE_REG

* simpler

* define_local
2026-06-18 21:21:06 -07:00
Christopher Milan
091ec8d10d
use tinygrad.llm in benchmarks (#16670) 2026-06-19 00:03:57 -04:00
George Hotz
925c49ce99
use placeholder in tests (#16672) 2026-06-18 20:51:44 -07:00
wozeparrot
05249466ed
llama: fused quantize mxfp8 (#16667) 2026-06-18 16:02:28 -07:00
George Hotz
4a4b6956df
remove DEFINE_VAR from codebase (gpt) (#16666)
* remove DEFINE_VAR from codebase

* junk

* remove junk
2026-06-18 15:33:50 -07:00
nimlgen
eda0a402d1
hcq2: fix multi (#16661) 2026-06-18 22:56:49 +03:00
George Hotz
5989d0b150
remove DEFINE_VAR try 2 (#16651)
* remove DEFINE_VAR try 2

* param

* null index

* fix fuzzing

* fixes

* no gather neg params

* param is just Irreducible

* fixes

* skip stack

* need to filter slots there
2026-06-18 12:34:25 -07:00
wozeparrot
d37248c3ec
gemm: fix mxfp8 on odd shapes (#16664) 2026-06-18 12:03:59 -07:00
chenyu
d74f488376
clean up _function.depth properly [PR] (#16663) 2026-06-18 14:10:22 -04:00
chenyu
d7a1022188
minor function.py cleanups [PR] (#16662) 2026-06-18 13:36:48 -04:00
qazal
924bece1d5
remove some old scheduler tests (#16660) 2026-06-18 22:15:00 +09:00
qazal
b753fb5e4c
viz: view source working even if compile failed (#16657)
* failing test

* hard

* ret_dict

* switch to _data for tests too

* update sqtt

* start work

* Ops.LINEAR looks good

* baseline with depth works

* support depth

* types

* @needs_tracked_pm

* update, marg can error too

* unwrap_or goes to many more places

* move things to soft_err

* soft_err everywhere needed

* diff cleanup

* use list

* rewrite it

* change

* update depth number

* small comment change
2026-06-18 17:34:53 +09:00
qazal
31094a794f
viz: data not sent to client side starts with _ (#16659)
* ret_dict

* switch to _data for tests too

* update sqtt

* rename to filter_keys

* not cfg
2026-06-18 15:25:22 +09:00
qazal
1720987dc7
include exception name in Ops.REWRITE_ERROR (#16658) 2026-06-18 14:52:48 +09:00
wozeparrot
bed0c343a3
faster mxfp8 gemm (#16656) 2026-06-17 22:35:36 -07:00
Christopher Milan
e0fe6e542e
ci: fewer pydeps (#16654) 2026-06-17 22:52:14 -04:00
chenyu
a74b7130b4
Revert "invalid clone try 2 [PR] (#16648)" (#16653)
This reverts commit 1bd4551ee1.
2026-06-17 22:05:30 -04:00
chenyu
df015ad541
remove many type ignores [PR] (#16652) 2026-06-17 21:38:45 -04:00
chenyu
1bd4551ee1
invalid clone try 2 [PR] (#16648) 2026-06-17 19:44:35 -04:00
George Hotz
53a1226a49
STACK 0 is dtype void (#16650)
* STACK 0 is dtype void

* spec for stack

* fix gemm group + END shape

* bump
2026-06-17 16:28:32 -07:00
George Hotz
aef85ddc4d
addrspace special/range (#16647)
* addrspace special/range

* just include indexing

* define var is alu

* bring old ignore indexing back

* mults to fix

* fixes

* ALU

* fixes
2026-06-17 15:57:37 -07:00
chenyu
1e08c0a07c
remove NOOP from AFTER with multiple srcs (#16646) 2026-06-17 14:35:02 -04:00
chenyu
1acc40600d
indexing an after with all fully invalid stores is invalid (#16643)
* indexing an after with all fully invalid stores is invalid

* typing cast
2026-06-17 11:06:36 -04:00
nimlgen
0f0c622086
hcq2: multi folders (#16642) 2026-06-17 15:20:25 +03:00
George Hotz
be9b570cb2
late numbering of var params (#16640)
* do_number_param

* fix sort order in x86

* we don't want this
2026-06-17 00:36:08 -07:00
qazal
c7055d658f
viz: only store kernel info (#16641) 2026-06-17 16:21:57 +09:00
George Hotz
d631716858
remove const without STACK (#16639)
* remove const without STACK

* fix GEP rewrite

* fix null tests

* fix openpilot regression

* it's 10 in CI
2026-06-16 21:25:42 -07:00
wozeparrot
36f6d1b064
gemm: fix bf16 atb for mp sharding (#16637) 2026-06-16 15:58:47 -07:00
qazal
1cb6b88d37
viz: show contents of vconst (#16636)
* failing test

* render vconst

* simpler test

* reorder
2026-06-17 02:31:03 +09:00
nimlgen
5644605d92
hcq2: pack bufs (#16635)
* hcq2: pack bufs

* x
2026-06-16 18:58:16 +03:00
chenyu
d5d59a2be6
remove dead rangeify rules [PR] (#16634) 2026-06-16 10:03:08 -04:00
chenyu
f0998e9bba
Revert "invalid clone is anonymous buffer" (#16613) (#16633) 2026-06-16 08:27:48 -04:00
qazal
7d2b0b697d
simple failing test for invalid extra E kernel (#16632)
* simple failing test for invalid extra E kernel

* 6 kernels
2026-06-16 17:57:44 +09:00
wozeparrot
70cac72781
llama: realize weight init (#16623) 2026-06-15 23:00:19 -07:00
Christopher Milan
443f976305
fix buffer overrun in dcache_flush (#16630) 2026-06-15 23:26:32 -04:00
chenyu
aa2bef24a8
no_vectorized_alu in cstyle does nothing now [PR] (#16631) 2026-06-15 23:07:20 -04:00
chenyu
efd03d7153
invalid clone is anonymous buffer [PR] (#16613) 2026-06-15 20:14:26 -04:00
nimlgen
4a0488ae97
hcq2: optims (#16624)
* hcq2: optims

* x
2026-06-15 23:58:28 +03:00
George Hotz
41aa2fe119
test_gemm needs .clone() on eye (#16629) 2026-06-15 12:48:27 -07:00
qazal
10bdb9c9d0
viz: check node exists before anchoring zoom (#16627) 2026-06-15 21:03:24 +09:00
qazal
f998b9930a
fp8 gemm inv_scale in epilogue (#16625)
* fuse scale

* remove python inv_scale

* more inv_scale removal

* more cleanups

* cleaner

* diff polish

* work

* rename

* simpler

* simpler

* compute

* c

* Revert "c"

This reverts commit 8941fec7ca.

* Revert "compute"

This reverts commit 9db573a6d3.

* Revert "simpler"

This reverts commit 910ad33f87.

* Revert "simpler"

This reverts commit bf75d235a1.

* s_g

* update types

* less diff noise

* remove
2026-06-15 18:44:41 +09:00
nimlgen
4dc51aff6e
hcq2: jit (#16621)
* hcq2: jit

* x

* x

* minor
2026-06-15 06:35:35 +07:00
chenyu
2adedf5ccb
clean up fold_divmod_general [pr] (#16622)
genralized fold_binary_numerator in fold_divmod_congruence
2026-06-14 17:15:52 -04:00
George Hotz
a6d7fb9d4d
only SHRINK for non scalar access (#16619) 2026-06-14 10:08:37 -07:00
George Hotz
b1fb39502d delete that test 2026-06-14 09:42:58 -07:00
chenyu
2e181f4259
simpler cancel_divmod [PR] (#16616) 2026-06-14 11:41:31 -04:00
chenyu
5d5ead78da
inline unique_const in invalids [PR] (#16612) 2026-06-13 10:14:32 -04:00
Sieds Lykles
b00dd754a9
Remove if-condition from nested div rule [pr] (#16611)
* add rules and test

* trigger [pr]
2026-06-13 15:47:21 +02:00
nimlgen
5a9227b30a
hcq2: rebind var params (#16610) 2026-06-13 14:55:52 +03:00
nimlgen
8efc8d064f
unique based on opaque in from_buffer (#16609) 2026-06-13 14:31:58 +03:00
nimlgen
c43091a464
fix missing cast in cstyle (#16608)
* fix missing cast in cstyle

* x

* x
2026-06-13 10:04:06 +03:00
qazal
2e77bd01db
fp8 gemm cleanup (#16607) 2026-06-13 13:17:32 +09:00
Christopher Milan
bcdb988df0
split comma benchmark, dsp on c4 [PR] (#16598) 2026-06-12 23:26:05 -04:00
George Hotz
6b8fdfe4ca
alu addrspace is where the math happens (#16606)
* alu addrspace

* fix cstyle/llvm

* on ptx, reg+alu are the same thing
2026-06-12 20:01:28 -07:00
wozeparrot
67a4f129c2
llama: fix bf16 gemm oob (#16603) 2026-06-12 19:43:05 -07:00
Christopher Milan
8862c7549c
new-style dcache_flush (#16602) 2026-06-12 22:25:08 -04:00
chenyu
9e72a6b376
more indexing cleanup [PR] (#16600) 2026-06-12 21:33:47 -04:00
chenyu
aa32d309db
fix rangeify indexing for pad/reduce (#16599) 2026-06-12 20:26:15 -04:00
George Hotz
96b86aad7b
move new style transform up more (#16593)
* move new style transform up more

* pm_move_gates_from_index works on new style
2026-06-12 17:20:12 -07:00
chenyu
a35964493e
UPat method cleanups [PR] (#16596) 2026-06-12 17:22:54 -04:00
chenyu
3036b15ed9
remove Tensor.ufix [PR] (#16594)
* remove Tensor.ufix [PR]

* inline _ufix_keep_dtype
2026-06-12 14:40:28 -04:00
qazal
b2e95b2db3
rangeify: no copies for write+read of same slice (#16585)
* failing test

* cleaner failing tests

* assign and read of same slice shouldn't create copies

* err in the changes

* shrink with no overlapping regions in dest is fine
2026-06-13 02:19:47 +09:00
George Hotz
833cb37574
move up new style transform (#16592)
* simpler names

* move up new style transform

* fix that rule
2026-06-12 10:13:37 -07:00
George Hotz
51100d2c5c
new style cleanups (#16584)
* spec tighten

* revert

* lin fix

* lin fix

* needed for x86

* revert
2026-06-12 08:10:38 -07:00
Philip Sinitsin
76c10cd635
jit: don't memplan buffers reachable from live tensors (#16588)
The memory planner was suballocating BUFFERs created during JIT capture that are still referenced by external lazy tensor graphs, like the .grad tensors assigned by backward(). The replay then only writes the arena slices, so realizing such a tensor after the call reads freshly allocated memory and silently returns zeros. Hold every BUFFER reachable from a live Tensor instead of only the parameters of the return value; true internals are still planned. Fixes #16571.
2026-06-12 17:51:54 +03:00
nimlgen
2bfdf85f87
hcq2: move pre bufferize (#16589)
* hcq2: move pre bufferize

* x
2026-06-12 16:11:59 +03:00
nimlgen
fb74f75485
var params sort after global params (#16590) 2026-06-12 14:33:15 +03:00
qazal
4d34590b7d
llama: less E kernels (#16517) 2026-06-12 19:49:25 +09:00
qazal
12f4cf0e49
rename amd/test_custom_kernel.py to test_asm_kernel (#16586)
* rename amd/test_custom_kernel.py to test_asm_kernel

* update
2026-06-12 16:11:01 +09:00
wozeparrot
e770805d21
llama: mxfp8 (#16574) 2026-06-11 22:15:24 -07:00
George Hotz
b8aec4cce7
port x86 to new_style (fable slop) and now everything is new style (#16581)
* port x86 to new_style (fable slop)

* don't change ops

* port NIR to new_style (fable)

* lil cleanup

* fix tests, and remove new_style
2026-06-11 21:09:34 -07:00
chenyu
762f50bd52
move gradient.py to mixin/ [PR] (#16583) 2026-06-11 23:58:21 -04:00
chenyu
a2cec397f3
UOp cast and bitcast takes DTypeLike [PR] (#16582)
* UOp cast and bitcast takes DTypeLike [PR]

match Tensor

* fix type
2026-06-11 22:38:54 -04:00
George Hotz
b97e3e01e3
port NIR to new_style (fable) (#16580)
* port NIR to new_style (fable)

* lil cleanup
2026-06-11 18:47:30 -07:00
288 changed files with 15741 additions and 2130 deletions

View file

@ -80,7 +80,7 @@ runs:
- name: Cache Python packages (PR)
if: github.event_name == 'pull_request'
id: restore-venv-pr
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: /tmp/.uv-cache
key: uv-${{ runner.os }}-${{ runner.arch }}-python-${{ inputs.python-version }}-${{ inputs.deps }}-${{ inputs.pydeps }}-${{ env.CACHE_VERSION }}
@ -96,7 +96,7 @@ runs:
- name: Cache downloads (PR)
if: inputs.key != '' && github.event_name == 'pull_request'
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: ${{ runner.os == 'Linux' && '~/.cache/tinygrad/downloads/' || '~/Library/Caches/tinygrad/downloads/' }}
key: downloads-${{ github.job }}-${{ inputs.key }}-${{ env.CACHE_VERSION }}
@ -203,7 +203,7 @@ runs:
- name: Cache apt (PR)
if: runner.os == 'Linux' && (inputs.opencl == 'true' || inputs.amd == 'true' || inputs.webgpu == 'true' || inputs.llvm == 'true' || inputs.qemu == 'true') && github.event_name == 'pull_request'
uses: actions/cache/restore@v4
uses: actions/cache/restore@v5
with:
path: /var/cache/apt/archives/
key: ${{ runner.os }}-${{ runner.arch }}-apt-${{ steps.apt-pkgs.outputs.hash }}-${{ env.CACHE_VERSION }}

View file

@ -99,7 +99,6 @@ jobs:
ln -s ~/tinygrad/extra/disassemblers/applegpu extra/disassemblers/applegpu
ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
@ -134,32 +133,10 @@ jobs:
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
- name: Fuzz Padded Tensor Core GEMM
run: DEV=METAL M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run quantized LLaMA
run: |
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
- name: Run quantized LLaMA3
run: |
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run OLMoE
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
- name: Run llama3.2
run: BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run olmoe
run: BENCHMARK_LOG=olmoe JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 -m tinygrad.llm -m olmoe --benchmark --warmup
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
@ -235,9 +212,6 @@ jobs:
- name: Symlink models and datasets
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -279,36 +253,16 @@ jobs:
# TODO: too slow
# - name: Run SDXL
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit DEV=NV JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=NV JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=NV BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=NV BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=NV JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run quantized LLaMA3
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
# - name: Run LLaMA-3 8B on 6 GPUs
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-2 70B
# run: DEV=NV CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=NV CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=NV JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=NV JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=NV HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=NV HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- uses: actions/upload-artifact@v7
with:
name: Speed (NVIDIA)
@ -402,10 +356,7 @@ jobs:
run: |
mkdir -p weights
ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
ln -s /raid/weights/mixtral-8x7b-32kseqlen weights/mixtral-8x7b-32kseqlen
ln -s /raid/weights/LLaMA-2 weights/LLaMA-2
ln -s /raid/weights/LLaMA-3 weights/LLaMA-3
mkdir -p extra/datasets
ln -s /raid/datasets/imagenet extra/datasets/imagenet
@ -458,18 +409,10 @@ jobs:
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 DEV=AMD python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 DEV=AMD python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA 7B
run: |
BENCHMARK_LOG=llama_nojit DEV=AMD JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama DEV=AMD JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA 7B with BEAM
run: BENCHMARK_LOG=llama_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run llama3.2
run: DEV=AMD BENCHMARK_LOG=llama32_3b-f16 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 -m tinygrad.llm -m llama3.2:3b-f16 --benchmark --warmup
- name: Run qwen3.5
run: DEV=AMD BENCHMARK_LOG=qwen35_35b-a3b JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 -m tinygrad.llm -m qwen3.5:35b-a3b --benchmark --warmup
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu DEV=AMD JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-3 8B on 6 GPUs
@ -478,16 +421,6 @@ jobs:
# run: sudo modprobe amdgpu
# - name: Run LLaMA-2 70B
# run: DEV=AMD CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral DEV=AMD python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit DEV=AMD JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 DEV=AMD JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half DEV=AMD HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam DEV=AMD HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run process replay tests
uses: ./.github/actions/process-replay
@ -594,10 +527,10 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testqualcommbenchmark:
name: comma Benchmark
testcommalatest:
name: comma Benchmark (0.11.0)
runs-on: [self-hosted, Linux, comma]
timeout-minutes: 30
timeout-minutes: 10
defaults:
run:
shell: bash -e -o pipefail {0}
@ -622,6 +555,27 @@ jobs:
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.11.0 dmonitoring
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay
testcommaold:
name: comma Benchmark (0.10.1)
runs-on: [self-hosted, Linux, comma]
timeout-minutes: 10
defaults:
run:
shell: bash -e -o pipefail {0}
if: github.repository_owner == 'tinygrad'
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.10.1 driving_vision
@ -630,14 +584,44 @@ jobs:
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3.2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.1 dmonitoring
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay
testqualcommdsp:
name: DSP Benchmark
runs-on: [self-hosted, Linux, comma4]
timeout-minutes: 5
defaults:
run:
shell: bash -e -o pipefail {0}
if: github.repository_owner == 'tinygrad'
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Checkout Code
uses: actions/checkout@v6
- name: setup staging db
if: github.ref == 'refs/heads/update_benchmark_staging'
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: benchmark MobileNetV2 on DSP
run: |
# generate quantized weights
ln -s /data/home/tiny/tinygrad/extra/datasets/imagenet extra/datasets/imagenet
ln -s /data/home/tiny/tinygrad/testsig-*.so .
PYTHONPATH=. CC=clang-19 DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
PYTHONPATH=. DEV=CPU QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
# benchmark on DSP with NOOPT=1, the devectorizer has issues
PYTHONPATH=. CC=clang-19 DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
PYTHONPATH=. DEV=DSP NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
- name: Run process replay tests
uses: ./.github/actions/process-replay

View file

@ -133,46 +133,26 @@ jobs:
run: SKIP_SLOW_TEST=1 DEV=PYTHON python3 -m pytest -n=auto test/backend/test_dtype.py test/backend/test_dtype_alu.py test/backend/test_ops.py test/backend/test_uops.py test/backend/test_symbolic_ops.py test/backend/test_renderer_failures.py::TestRendererFailures --durations=20
- name: Test IMAGE support
run: IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores
- name: Test emulated tensor cores
env:
DEV: 'PYTHON::METAL'
DEBUG: 2
N: 64
CNT: 1
SHOULD_USE_TC: 1
run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_big_gemm
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD tensor cores
env:
DEV: 'PYTHON::gfx1100'
parallel -k --link --tagstring '[{1}]' '{2} python3 ./extra/gemm/simple_matmul.py' \
::: metal gfx950 gfx1100 gfx1100_acchalf gfx1201 gfx1201_acchalf sm_75 sm_80_half sm_80_tf32 \
::: 'DEV=PYTHON::METAL' 'DEV=PYTHON::gfx950 HALF=1 ACC_HALF=0' \
'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1100 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=0' 'DEV=PYTHON::gfx1201 HALF=1 ACC_HALF=1 ATOL=1e-3' \
'DEV=PYTHON::sm_75 HALF=1' 'DEV=PYTHON::sm_80 HALF=1' 'DEV=PYTHON::sm_80 ALLOW_TF32=1'
- name: Run additional tensor core tests
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD MFMA tensor cores
env:
DEV: 'PYTHON::gfx950'
run: |
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated AMD RDNA4 tensor cores
env:
DEV: 'PYTHON::gfx1201'
run: |
DEBUG=2 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
DEBUG=2 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
DEV=PYTHON::METAL python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEV=PYTHON::gfx1100 python3 -m pytest -nauto test/opt/test_tensor_cores.py test/null/test_uops_stats.py::TestUOpsStatsMatmulHalf
DEV=PYTHON::gfx950 python3 -m pytest -nauto test/opt/test_tensor_cores.py
DEV=PYTHON::gfx1201 python3 -m pytest -nauto test/opt/test_tensor_cores.py
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
linter:
@ -218,14 +198,15 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: unittest-13
pydeps: "pillow ftfy regex pre-commit"
deps: testing_unit
llvm: 'true'
amd: 'true'
- name: Run NULL backend tests
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
- name: Run targeted tests on NULL backend
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
run: |
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
# TODO: too slow
# - name: Run SDXL on NULL backend
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
@ -249,7 +230,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: unittest-13
pydeps: "pillow ftfy regex pre-commit"
pydeps: "pre-commit"
deps: testing_unit
llvm: 'true'
- name: Run pre-commit test hooks
@ -266,13 +247,6 @@ jobs:
run: python3 test/external/external_benchmark_schedule.py
- name: Run process replay tests
uses: ./.github/actions/process-replay
- name: Regen dataset on test_tiny
run: |
test/external/process_replay/reset.py
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 25000 lines
run: MAX_LINE_COUNT=25000 python sz.py
@ -337,31 +311,6 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testgpumisc:
name: CL Misc tests
runs-on: *linux
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: gen-dataset
deps: testing
opencl: 'true'
- name: Generate Dataset
run: DEV=CL extra/optimization/generate_dataset.sh
- name: Run Kernel Count Test
run: DEV=CL python -m pytest -n=auto test/external/external_test_opt.py
- name: Run fused optimizer tests
run: DEV=CL FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py test/backend/test_optim.py -k "not muon"
- name: Upload artifact
uses: actions/upload-artifact@v7
with:
name: sops.gz
path: /tmp/sops.gz
testopenpilot:
name: openpilot Compile Tests
runs-on: *linux
@ -378,7 +327,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)
run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -421,7 +370,6 @@ jobs:
with:
key: optim
deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true'
#- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -430,7 +378,7 @@ jobs:
- name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training
@ -468,7 +416,7 @@ jobs:
# ****** Models Tests ******
testmodels:
name: Models (llvm+cpu+gpu)
name: Models
runs-on: *linux
timeout-minutes: 15
steps:
@ -479,34 +427,12 @@ jobs:
with:
key: models
deps: testing
opencl: 'true'
llvm: 'true'
- name: Test models (llvm)
run: DEV=CPU:LLVM python -m pytest -n=auto test/models --durations=20
- name: Test models (opencl)
run: DEV=CL python -m pytest -n=auto test/models --durations=20
- name: Test models (cpu)
run: DEV=CPU python -m pytest -n=auto test/models --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
testmetalmodels:
name: Models (metal)
runs-on: &macos macos-26
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v6
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: metal
deps: testing
- name: Test models (Metal)
run: DEV=METAL python -m pytest -n=auto test/models --durations=20
- name: Test LLaMA compile speed
run: DEV=METAL python test/external/external_test_speed_llama.py
# ****** Feature Tests ******
testdsp:
@ -715,7 +641,7 @@ jobs:
unittestmacos:
name: MacOS (unit)
runs-on: *macos
runs-on: &macos macos-26
timeout-minutes: 20
steps:
- name: Checkout Code

View file

@ -72,7 +72,7 @@ As it turns out, 90% of what you need for neural networks are a decent autograd/
Throw in an optimizer, a data loader, and some compute, and you have all you need.
```python
from tinygrad import Tensor, nn
from tinygrad import Tensor, nn, Context
class LinearNet:
def __init__(self):
@ -86,7 +86,7 @@ optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # replace with real mnist dataloader
with Tensor.train():
with Context(TRAINING=1):
for i in range(10):
optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward()

View file

@ -165,13 +165,14 @@ from extra.datasets import fetch_mnist
Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.
We use `with Tensor.train()` to set the internal flag `Tensor.training` to `True` during training.
We use `with Context(TRAINING=1)` to set the internal flag `Tensor.training` to `True` during training.
Upon exit, the flag is restored to its previous value by the context manager.
```python
from tinygrad import Context
X_train, Y_train, X_test, Y_test = fetch_mnist()
with Tensor.train():
with Context(TRAINING=1):
for step in range(1000):
# random sample a batch
samp = np.random.randint(0, X_train.shape[0], size=(64))

View file

@ -1,6 +1,6 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
from tinygrad import Tensor, TinyJit, nn, Context
import gymnasium as gym
from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import
@ -55,7 +55,7 @@ if __name__ == "__main__":
@TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train():
with Context(TRAINING=1):
log_dist, value = model(x)
action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

View file

@ -122,7 +122,7 @@ if __name__ == "__main__":
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(idxs:Tensor) -> Tensor:
X, Y = X_train[idxs], Y_train[idxs]
if len(GPUS) > 1:

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])

View file

@ -1,6 +1,6 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device, Context
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -31,7 +31,7 @@ if __name__ == "__main__":
@TinyJit
def train_step() -> Tensor:
with Tensor.train():
with Context(TRAINING=1):
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0

View file

@ -1,6 +1,6 @@
import itertools
from typing import Callable
from tinygrad import nn, Tensor, dtypes, Device, TinyJit
from tinygrad import nn, Tensor, dtypes, Device, TinyJit, Context
from tinygrad.helpers import getenv, trange, partition
class Model:
@ -59,7 +59,7 @@ if __name__ == "__main__":
Tensor.realize(*params, *buffers, *adam_params, loss, grads)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def microbatch():
samples = Tensor.randint(BS // ACC_STEPS, high=X_train.shape[0])
for t in params: t.grad = None

View file

@ -359,7 +359,7 @@ def train_cifar():
i = 0
eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
with Context(TRAINING=1):
st = time.monotonic()
while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters, Context
from dataclasses import dataclass
@dataclass
@ -177,7 +177,7 @@ if __name__ == "__main__":
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y)
optimizer.zero_grad()
@ -204,4 +204,3 @@ if __name__ == "__main__":
top_k = 40
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))

View file

@ -1,5 +1,5 @@
# much taken from https://github.com/cloneofsimo/minRF
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, Context
from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@ -135,7 +135,7 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=5e-4)
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step():
if getenv("OVERFIT"): samples = Tensor.zeros(getenv("BS", 256), dtype='int')
else: samples = Tensor.randint(getenv("BS", 256), high=X_train.shape[0])

View file

@ -358,7 +358,7 @@ def eval_stable_diffusion():
batch = batch.cat(batch[-1:].expand(bs - unpadded_bs, *batch[-1].shape))
return batch, unpadded_bs
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_unet(eval_inputs:list[dict], unet:UNetModel, cond_stage:FrozenOpenClipEmbedder, first_stage:AutoencoderKL,
inception:FidInceptionV3, clip:OpenClipEncoder) -> tuple[float, float]:
# Eval is divided into 5 jits, one per model

View file

@ -2,7 +2,7 @@ import os, time, math, functools, random, contextlib
from pathlib import Path
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes, Context
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker, DEBUG
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -614,7 +614,7 @@ def train_retinanet():
if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False):
with Context(TRAINING=0):
if not RUNMLPERF:
i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
else:
@ -784,7 +784,7 @@ def train_unet3d():
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit
@Tensor.train()
@Context(TRAINING=1)
def train_step(model, x, y):
optim.zero_grad()
@ -795,7 +795,7 @@ def train_unet3d():
optim.step()
return loss.realize()
@Tensor.train(mode=False)
@Context(TRAINING=0)
def eval_step(model, x, y):
y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
y_hat, y = Tensor(y_hat), Tensor(y)
@ -1282,7 +1282,7 @@ def train_bert():
previous_step = i
def train_llama3():
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE, MXFP8
from examples.llama3 import MODEL_PARAMS
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
from examples.mlperf.optim import GradAccClipAdamW
@ -1447,7 +1447,12 @@ def train_llama3():
idx = next(j for j, p in enumerate(optim.params) if p is w)
master = optim.master_params[idx]
inv = w._inv_scale if w._inv_scale.device == master.device else w._inv_scale.to(master.device)
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
if MXFP8:
from extra.gemm.cdna_asm_gemm import _mx_block_scale
bs = _mx_block_scale(inv.reshape(-1, inv.shape[-1])).reshape(w.shape)
master.assign((master * bs).contiguous())
else:
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
# realize everything here
if optim.master_params: Tensor.realize(*optim.master_params)
@ -1485,7 +1490,7 @@ def train_llama3():
return lr_cpu, grad_norm_cpu
@TinyJit
@Tensor.train(False)
@Context(TRAINING=0)
def eval_step(tokens:Tensor):
if is_dp: tokens = tokens.to(None).shard(device, 0)
if is_mp: tokens = tokens.shard(device)
@ -1798,7 +1803,7 @@ if __name__ == "__main__":
elif getenv("RUNMLPERF"): bench_log_manager = WallTimeEvent(BenchEvent.MLPERF_RUN)
else: bench_log_manager = contextlib.nullcontext()
with Tensor.train():
with Context(TRAINING=1):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}"
if nm in globals():

View file

@ -24,6 +24,7 @@ FUSED_ADD_NORM_MUL_QUANTIZE = getenv("FUSED_ADD_NORM_MUL_QUANTIZE", 0)
FUSED_SILU_W13 = getenv("FUSED_SILU_W13", 0)
SPLIT_W13 = getenv("SPLIT_W13", 0)
COLUMNWISE_WEIGHT_SCALE = getenv("COLUMNWISE_WEIGHT_SCALE", 0)
MXFP8 = getenv("MXFP8", 0)
FP8_DTYPE = dtypes.fp8e4m3
FP8_GRAD_DTYPE = dtypes.fp8e5m2
@ -36,35 +37,48 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
if not fp8:
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
return (x @ w.T,)
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*l_shape, w.shape[0])
else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None:
if FUSED_INPUT_QUANTIZE and amax_x is not None:
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
x_fp8, x_scale, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
else:
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
x_fp8, _, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w.T):
assert amax_x is not None
if COLUMNWISE_WEIGHT_SCALE:
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
else:
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
return out, x_new_amax, x_fp8
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_x.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
if FUSED_ADD_NORM_MUL_QUANTIZE:
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_rmsnorm_mul_quantize_fp8
x_fp8, x_inv_scale, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
x_fp8, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
return out, x_normed, rrms, ret
x_normed, rrms = rmsnorm(x, eps)
out, *ret = matmul(x_normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale, grad_amax_state=grad_amax_state)
@ -74,8 +88,8 @@ def add_norm_quantize_matmul(x:Tensor, residual:Tensor, norm:Tensor, w:Tensor, w
grad_amax_state:Tensor|None=None):
if FUSED_ADD_NORM_MUL_QUANTIZE:
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_add_rmsnorm_mul_quantize_fp8
x_fp8, x_inv_scale, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
x_fp8, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
return out, h, x_normed, rrms, ret
h = x + residual
x_normed, rrms = rmsnorm(h, eps)
@ -87,8 +101,8 @@ def silu_w13_quantize_matmul(x_w13:Tensor, w2:Tensor, s_2:Tensor,
grad_amax_xw13:Tensor, grad_amax_xout:Tensor):
if FUSED_SILU_W13:
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
x2_fp8, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, amax_x=amax_x2, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
return out, ret
hidden = x_w13.shape[-1] // 2
x_w1, x_w3 = x_w13[..., :hidden], x_w13[..., hidden:]
@ -140,12 +154,17 @@ class FlatTransformer:
self._fp8_grad_amax = {name: [_amax() for _ in range(n_layers)] for name in grad_names}
w_scales = [("wqkv", s_qkv), ("wo", s_o), ("w2", s_2)]
w_scales += [("w1", s_1), ("w3", s_3)] if SPLIT_W13 else [("w13", s_13)]
self._fp8_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
self._fp8_next_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
self._fp8_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
self._fp8_next_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
if w is None:
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
return w_q.reshape(self.n_layers, out_features, in_features), w_e8.reshape(self.n_layers, out_features, in_features // 32)
amax = (w.abs().max(axis=2) if COLUMNWISE_WEIGHT_SCALE else w.abs().flatten(1).max(1)).detach()
scale = FP8_MAX / (amax + 1e-8)
inv_scale = (amax + 1e-8) / FP8_MAX
@ -171,7 +190,7 @@ class FlatTransformer:
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
if getenv("HK_FLASH_ATTENTION"):
from extra.thunder.amd.fa import flash_attention
attn, *save = flash_attention(xq, xk, xv, is_causal=True)
attn, *save = flash_attention(xq, xk, xv, is_causal=True, write_flat=True)
saves.extend(save)
else:
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
@ -197,8 +216,15 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax)
saves.extend([*s, x_w3])
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
if FUSED_SILU_W13 and MXFP8:
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax)
saves.extend([*s, out])
else:
@ -228,20 +254,30 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None)
else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
def _shard_fp8(name:str, axis:int):
getattr(self, name).shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
def _shard_fp8(name:str, axis:int, std:float=0.02):
w = getattr(self, name)
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
w.replace(w_q)
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
else:
w.shard_(device, axis=axis)
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
sstd = 0.02 / math.sqrt(2 * self.n_layers)
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
if SPLIT_W13:
_shard_fp8("w1", 1)
_shard_fp8("w3", 1)
else:
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize()

View file

@ -8,6 +8,7 @@ STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
MXFP8 = getenv("MXFP8", 0)
def stochastic_round_bf16(x:Tensor) -> Tensor:
bits = x.bitcast(dtypes.uint32)
@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
out = stochastic_round_bf16(new_w)
return out.shard_like(t) if offloaded else out
if t.dtype in dtypes.fp8s:
if MXFP8:
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
w_q, w_e8, _ = quantize_mxfp8(new_w.reshape(-1, new_w.shape[-1]))
new_e8 = w_e8.reshape(t._inv_scale.shape)
t._inv_scale.assign(new_e8.shard_like(t._inv_scale) if offloaded else new_e8)
ret = w_q.reshape(new_w.shape)
return ret.shard_like(t) if offloaded else ret
from examples.mlperf.models.flat_llama import FP8_MAX
if IMMEDIATE_SCALE:
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))

View file

@ -3,7 +3,7 @@ import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import trange
from tinygrad.helpers import trange, Context
from tinygrad.nn import optim
from tinygrad.nn.datasets import mnist
@ -86,7 +86,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator), lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
with Tensor.train():
with Context(TRAINING=1):
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):

View file

@ -5,7 +5,7 @@
# - symbolic removal
from examples.beautiful_mnist import Model
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable
from tinygrad import Tensor, nn, getenv, GlobalCounters, Variable, Context
from tinygrad.nn.datasets import mnist
from tinygrad.helpers import trange
@ -26,7 +26,7 @@ if __name__ == "__main__":
X_samp, Y_samp = X_train[samples], Y_train[samples]
print("*** got samples")
with Tensor.train():
with Context(TRAINING=1):
"""
i = UOp.range(samples.shape[0]) # TODO: fix range function on UOp
losses = model(X_samp[i]).sparse_categorical_crossentropy(Y_samp[i]).backward().contract(i)

View file

@ -1,5 +1,5 @@
from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType, dtypes
from tinygrad.dtype import DType, dtypes, AddrSpace
from tinygrad.tensor import Tensor
from tinygrad.device import Device, Buffer
from tinygrad.engine.jit import TinyJit
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
info = prg.arg
functions[info.function_name] = prg.src[3].arg
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
statements.append((info.function_name, cargs, info.global_size, info.local_size))
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
symbolic_vars[var] = var.expr
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]
if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
global_size[j] = f"_{name.expr}[0] + {val.arg}"
prg = ""
if target == "clang":

View file

@ -458,7 +458,8 @@ def test_matmul():
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like()))
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
P_lds = QP_lds[:, :BLOCK_N]
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
rw1 = UOp.range(TM, 296, AxisType.LOOP)
rw2 = UOp.range(TN, 297, AxisType.LOOP)
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)

View file

@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0")
insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
@ -2630,7 +2630,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
@functools.cache
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
scales, extra = args[:n_scales], args[n_scales:]
M, K = A.shape[0]*A.shape[1], A.shape[2]
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
@ -2660,7 +2660,9 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
block_size = 256
threads = UOp.special(64 * 8, "lidx0")
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
e_a = extra[0].base if len(extra) >= 1 else scale_A.base
e_b = extra[1].base if len(extra) >= 2 else scale_B.base
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, e_a, e_b, threads, workgroups)
sink = UOp.sink(*sink_inputs,
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
@ -2672,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis
rows, K = x.shape
scale_K, k_iters = K // 32, K // 128
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
*batch, K = x.shape
scale_K = K // 32
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
def mx_pack(e8:Tensor) -> Tensor:
rows, scale_K = e8.shape
@ -2750,7 +2752,8 @@ def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
assert M % block_m == 0 and N % block_n == 0 and K % block_k == 0, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
threads = UOp.special(64 * num_warps, "lidx0")
workgroups = UOp.special((M // block_m) * (N // block_n), "gidx0")
sink = UOp.sink(C.base, A.base, B.base, threads, workgroups,
b_extra = args[0].base if len(args) >= 1 else B.base
sink = UOp.sink(C.base, A.base, B.base, b_extra, threads, workgroups,
arg=KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize)))
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
src = (kittens_path/"gemm_bf16.cpp").read_text()
@ -2785,15 +2788,21 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
is_multi = isinstance(a.device, tuple)
reduce_out = False
if is_multi:
out = Tensor(Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
ndev = len(a.device)
if a.uop.axis in (0, 1) or b.uop.axis in (0, 1): inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
elif b.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M, N // ndev, dtype=a.dtype, device=a.device), 2
elif a.uop.axis == 2: inv, out_axis = Tensor.invalids(1, M // ndev, N, dtype=a.dtype, device=a.device), 1
else: inv, out_axis, reduce_out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device), 0, True
out = Tensor(inv.uop.multi(out_axis), device=a.device)
dname = a.device[0]
else:
out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
dname = a.device
dname = dname.split(":")[0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_hk_bf16_atb_gemm, dname=dname))[0]
if is_multi: out = out.sum(0)
if reduce_out: out = out.sum(0)
return out.squeeze(0) if out.ndim == 3 else out
@ -2805,13 +2814,15 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
out, a, b = inputs[:3]
i = 3
s_x = inputs[i]; i += 1
has_w = n_scales == 2
has_w = n_scales >= 2
s_w = inputs[i] if has_w else None; i += has_w
s_g = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
w_post = inputs[i] if has_w_post else None
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
s_x_t = Tensor(s_x, device=a.device)
s_w_t = Tensor(s_w, device=a.device) if has_w else None
s_g_t = Tensor(s_g, device=a.device) if s_g is not None else None
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
g_t = g_t[:a.shape[0]]
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
@ -2836,7 +2847,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
store_effect = grad_amax_state.store(new_grad_amax.uop)
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
if s_g_t is not None: g_scale = g_scale * s_g_t
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
# wgrad: no w_scale
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
@ -2844,7 +2856,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
# wgrad: rescale if not scalar
if w_post_t is not None:
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
@ -2876,7 +2888,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
# ** mxfp8 gemm backward
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=False):
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
@ -2890,14 +2902,14 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
grad_b = asm_gemm(g.T, a_phys, mx=True)
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
grad_b = grad_b * _mx_block_scale(be8)
if not w_stored: grad_b = grad_b * _mx_block_scale(be8)
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False, g_scale:Tensor|None=None) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
@ -2939,13 +2951,13 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
b_q, b_e8, b_si = quantize_mxfp8(b.T)
has_w_post = w_post_scale is not None
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post, w_stored=mx_w_stored)
extra = [w_post_scale] if w_post_scale is not None else []
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
elif a.dtype == FP8_DTYPE:
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)

View file

@ -219,7 +219,8 @@ def test_matmul():
def asm_kernel(A, B, C):
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(THREADS, "lidx0")]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -1,145 +0,0 @@
from __future__ import annotations
import time
from typing import cast
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
from tinygrad.dtype import dtypes
from tinygrad.engine.jit import GraphRunner
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
from tinygrad.helpers import round_up, ceildiv
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
# **************** insert deps ****************
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
src = []
for j, call in enumerate(linear.src):
call = call.replace(tag=j)
_, _, bufs, _ = ctx.calls[j]
outs, ins = get_call_outs_ins(call)
deps = ctx._access_resources([bufs[i] for i in outs + ins], list(range(len(outs))), call)
src.append(UOp(Ops.AFTER, call.dtype, (call, *deps), tag=call.tag))
return linear.replace(src=tuple(src))
pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
pm_replace_params = PatternMatcher([
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
])
# **************** graph-only passes ****************
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
if q.arg in ctx.queue_sigs: return None
dev = q.arg[0][0] # TODO: multi device
buf = Buffer(dev, 0x100, dtypes.uint8, options=BufferSpec(host=True, uncached=True, cpu_access=True), preallocate=True)
ctx.queue_sig_bufs.append(buf)
ctx.queue_sigs[q.arg] = UOp.from_buffer(buf, dev)
return None
pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
wrapper, deps, call_idx = after.src[0], after.src[1:], after.tag
def store(q_arg, v): return ctx.queue_sigs[q_arg].store(UOp.const(dtypes.uint32, v))
waits = tuple(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep.src[0].arg], UOp.const(dtypes.uint32, dep.tag),
store(dep.src[0].arg, dep.tag))) for dep in deps)
return wrapper.replace(src=tuple(q.replace(src=(*waits, *q.src, store(q.arg, call_idx))) for q in wrapper.src))
pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
src, seen, pending, queue_sig = [], {}, {}, ctx.queue_sigs[queue.arg]
for x in queue.src:
if x.op is Ops.WAIT:
sig, val = x.src[0], x.src[1]
if sig is queue_sig or seen.get(sig, -1) >= val.arg: continue
if (old:=pending.get(sig)) is None or old.src[1].arg < val.arg: pending[sig] = x
continue
for wait in pending.values():
src.append(wait)
seen[wait.src[0]] = wait.src[1].arg
pending.clear()
src.append(x)
src += pending.values()
return queue.replace(src=tuple(src)) if tuple(src) != queue.src else None
pm_optimize_queue_deps = PatternMatcher([
(UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
])
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
live = {u.src[2] for u in outer.toposort() if u.op is Ops.WAIT}
return outer.replace(src=tuple(q.replace(src=tuple(x for x in q.src if x.op is not Ops.STORE or x in live)) for q in outer.src))
pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
if not ctx.queue_sig_bufs or cmdbuf.tag not in ("compute", "copy"): return None
resets = tuple((b:=UOp.from_buffer(sig)).index(UOp.const(dtypes.int, 0), dtype=b.dtype.ptr())
.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0)) for sig in ctx.queue_sig_bufs)
return x.replace(src=x.src + resets)
pm_add_queue_sig_resets = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"),
add_queue_sig_resets)])
# **************** Graph ****************
class HCQ2Graph(GraphRunner):
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
super().__init__(linear, input_uops)
self.dev = cast(HCQ2Compiled, Device[self.device])
self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")
self.input_addrs = Buffer("CPU", max(len(input_uops), 1), dtypes.uint64, preallocate=True)
self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")
self.linear = graph_rewrite(self.linear, pm_insert_deps, ctx=self, name="hcq: insert deps", walk=True)
self.linear = graph_rewrite(self.linear, pm_replace_params, ctx=self, name="hcq: replace params", walk=True)
self.linear = graph_rewrite(self.linear, pm_prep_runtime, ctx=self.hcq_ctx, name="hcq: prepare runtime")
self.linear = graph_rewrite(self.linear, pm_lower_ops, ctx=self.hcq_ctx, name="hcq: lower ops")
# per-queue signal state — populated as a side-effect by pm_alloc_queue_sigs walking the lowered linear.
self.queue_sig_bufs:list[Buffer] = []
self.queue_sigs:dict[tuple[str, str], UOp] = {}
graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)
self.linear = graph_rewrite(self.linear, pm_lower_queue_deps, ctx=self, name="hcq: lower queue deps")
self.linear = graph_rewrite(self.linear, pm_split_into_queues, ctx=self.hcq_ctx, name="hcq: split into queues")
self.linear = graph_rewrite(self.linear, pm_add_barriers, ctx=self.hcq_ctx, name="hcq: add barriers", walk=True)
self.linear = graph_rewrite(self.linear, pm_optimize_queue_deps, ctx=self, name="hcq: optimize queue deps", walk=True)
self.linear = graph_rewrite(self.linear, pm_drop_dead_stores, ctx=self, name="hcq: drop dead stores")
self.linear = graph_rewrite(self.linear, pm_add_signals, ctx=self.hcq_ctx, name="hcq: add signals", walk=True)
self.linear = graph_rewrite(self.linear, pm_add_timeline_inc, ctx=self.hcq_ctx, name="hcq: add submit", walk=True)
self.linear = graph_rewrite(self.linear, self.dev.pm_lower, ctx=self.hcq_ctx, name=f"hcq: encode cmdbuf {self.dev.device}", walk=True)
graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=(sizes:={}), name=None)
for dev_name, sz in sizes.items():
buf = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
self.hcq_ctx.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(buf, dev_name), UOp.const(dtypes.uint64, buf._buf.va_addr))
self.linear = graph_rewrite(self.linear, pm_bufferize, ctx=self.hcq_ctx, bottom_up=True, name="realize binaries")
self.linear = graph_rewrite(self.linear, pm_lift_patches_to_cmdbuf, ctx=self.hcq_ctx, bottom_up=False, name="lift patches to cmdbuf")
self.linear = graph_rewrite(self.linear, pm_resolve_patches, ctx=self.hcq_ctx, bottom_up=False, name="simplify patches")
self.linear = graph_rewrite(self.linear, pm_add_queue_sig_resets, ctx=self, name="hcq: add queue sig resets", walk=True)
self.linear = graph_rewrite(self.linear, pm_parametrize_host_buffers, ctx=self.hcq_ctx, bottom_up=True, name="parametrize host buffers")
self.host_call = graph_rewrite(self.linear, pm_callify, ctx=self.hcq_ctx, name="hcq: callify")
self.host_rt, self.host_globals = get_runtime("CPU", self.host_call.src[0]), self.host_call.src[0].arg.globals
def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
addrs = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
for i, u in enumerate(input_uops):
buf = next(b for b in u.buffer.bufs if b.device == self.dev.device) if isinstance(u.buffer, MultiBuffer) else u.buffer
addrs[i] = buf._buf.va_addr
self.host_rt(*[self.hcq_ctx.inputs[i].get_buf("CPU") for i in self.host_globals], vals=self.host_call.src[0].arg.vals(var_vals), wait=True)
if wait:
st = time.perf_counter()
self.dev.synchronize()
return time.perf_counter() - st
return None
@staticmethod
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
all_devs = GraphRunner._all_devs(batch_devs, new_call)
return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(all_devs) == 1 and isinstance(all_devs[0], HCQ2Compiled)

View file

@ -1,14 +1,13 @@
from __future__ import annotations
from typing import cast, Callable, TypeVar, Generic, Any
import struct, functools, time, collections, importlib, itertools, weakref
from dataclasses import replace
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
from dataclasses import replace, dataclass, field
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
from tinygrad.helpers import to_tuple, round_up
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
from tinygrad.uop.symbolic import symbolic_simple, symbolic
from tinygrad.dtype import dtypes, DType
from dataclasses import dataclass, field
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.renderer import Renderer, Estimates
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
@ -28,7 +27,7 @@ class HCQ2Compiled(Compiled):
(UPat(Ops.BUFFER, tag="timeline_value"), lambda ctx: ctx.timeline_value()),
(UPat(Ops.BUFFER, tag="sentinel_signal"), lambda ctx: ctx.timeline_signal("sentinel", (1 << 64) - 1)),
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=True, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
Buffer(ctx.device, b.arg, b.dtype, options=BufferSpec(host=False, uncached=True, cpu_access=True, nolru=True))), # TODO: remove nolru
])
super().__init__(device, allocator, compilers, lambda *a, **kw: None, None, arch=arch)
@ -137,21 +136,24 @@ def unwrap_after(uop):
while uop.op is Ops.AFTER: uop = uop.src[0]
return uop
def make_getaddr(u, dev=None):
def make_getaddr(u, device=None):
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=device or to_tuple(u.device)[0])))
def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag):
blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -172,6 +174,10 @@ class HCQInfo:
name:str = ""
estimates:Estimates = Estimates()
outs:tuple[int, ...] = ()
devs:tuple[str, ...] = ()
params:tuple[int, ...] = ()
inputs:int|None = None
@staticmethod
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
@ -208,15 +214,11 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
data, info = prg.arg
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
dtypes.uint64) for i,gi in enumerate(info.globals)] \
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -269,6 +271,14 @@ class DepsCtx:
deps:DepsTracker = field(default_factory=DepsTracker)
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
# TODO: should this be a part of DepsTracker?
if u.op is Ops.PARAM: return ctx.params.setdefault((u.arg.slot, lane), Buffer("NULL", u.max_numel(), u.dtype.base))
if u.op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
if u.op in (Ops.SLICE, Ops.MSELECT): return get_dep_buf(ctx, u.src[0], u.arg if u.op is Ops.MSELECT else lane)
return b.bufs[lane] if isinstance(b:=u.buffer, MultiBuffer) else b
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src = []
@ -277,16 +287,26 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
new_src.append(call)
continue
q = get_submit(call.src[0]).src[0]
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
new_q = ctx.last_per_queue[q.arg] = (q:=get_submit(call.src[0]).src[0]).rtag(next(ctx.opid))
qdevs, refs = to_tuple(new_q.arg[0]), get_call_arg_uops(call)
deps = []
refs = [b.buffer for b in get_call_arg_uops(call)]
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
# per-lane deps, tracked per (device, queue). skip self
dep_lanes:list[tuple[UOp, int]] = []
for lane, d in enumerate(qdevs):
for dep in ctx.deps.access_resources([get_dep_buf(ctx, b, lane) for b in refs], call.arg.aux.outs, new_q.replace(arg=(d, new_q.arg[1]))):
if dep.tag != new_q.tag: dep_lanes.append((dep, lane))
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
# drop self-queue waits, queue self-orders
if qdevs[0].split(":")[0] in {"AMD", "QCOM"} or new_q.arg[1].startswith("COPY"):
dep_lanes = [(dep, lane) for dep, lane in dep_lanes if dep.arg != (qdevs[lane], new_q.arg[1])]
# keep latest dep per lane, group lanes
latest = {(dep.arg, lane): dep for dep, lane in sorted(dep_lanes, key=lambda x: x[0].tag)}
deps:dict[UOp, tuple[int, ...]] = collections.defaultdict(tuple)
for (_, lane), dep in latest.items(): deps[dep] += (lane,)
if deps: new_q = new_q.after(*deps, arg=tuple(deps.values())).rtag("deps")
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}),)))
return linear.replace(src=tuple(new_src))
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
@ -298,11 +318,16 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
zero = UOp.const(dtypes.int, 0)
tl = make_signal_value(devs)
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
# queue is inc with deps
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
# split each (multi-device) queue into per-device deps so each finalizer lane waits on the matching device's signal
lane_queues = [(q.replace(arg=(d, q.arg[1])), (devs.index(d),)) for q in queues for d in to_tuple(q.arg[0])]
submit = submit.replace(src=(submit.src[0].after(*(q for q, _ in lane_queues), arg=tuple(l for _, l in lane_queues)).rtag("deps"),))
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
parts:dict[str, list[UOp]] = collections.defaultdict(list)
@ -319,21 +344,22 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
cur_devs = to_tuple((cur:=deps.src[0]).arg[0])
waits = []
for dep in deps.src[1:]:
devs, queue = dep.arg
for lanes, dep in zip(deps.arg, deps.src[1:]):
dep_dev, queue = dep.arg # dep_dev is a single device (deps are recorded per-device)
ctx.add(dep.tag) # mark op to update signal.
sig = make_mstack([make_signal(d, queue=queue, sentinel=d not in devs) for d in cur_devs])
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
# for lanes that need this dep, wait on the dep device's signal/value; other lanes get a passing sentinel
lanes = set(lanes)
sig = make_mstack([make_signal(dep_dev if j in lanes else d, queue=queue, sentinel=j not in lanes) for j, d in enumerate(cur_devs)])
val = make_mstack([make_signal_value(dep_dev if j in lanes else d, queue=queue) for j, d in enumerate(cur_devs)]).index(UOp.const(dtypes.int, 0))
waits.append(sig.wait(val + dep.tag))
return cur.replace(src=tuple(waits) + cur.src)
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
src = q.src
if q.tag in ctx:
devs, queue = q.arg
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
if q.tag not in ctx: return None
devs, queue = q.arg
src = q.src + (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
return submit.replace(src=(q.replace(src=src, tag=None),))
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
@ -342,34 +368,36 @@ pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", s
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
return new_sink.substitute({new_submit: new_anchor})
def merge_sink(sinks:list[UOp]) -> UOp:
if len(sinks) == 1: return sinks[0]
submits = [get_submit(sink) for sink in sinks]
queues = [submit.src[0] for submit in submits]
anchor = submits[-1].replace(src=(queues[-1].replace(src=tuple(x for q in queues for x in q.src)),))
for sink, submit in zip(sinks[:-1], submits[:-1]):
if sink.src[0] is not submit: anchor = sink.src[0].substitute({submit: anchor}, walk=True)
return sinks[-1].substitute({submits[-1]: anchor}, walk=True)
def merge_queues(linear:UOp) -> UOp:
new_src:list[UOp] = []
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
opened_qs:dict[tuple[tuple[str, ...], str], tuple[list[UOp], HCQInfo]] = {} # (devs, queue) -> (sinks, aux), kept in submit order
for call in linear.src:
if call.tag != "hcq":
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
# finalizer cannot be merged, since it bumps inner signal (this introduces race when multidevs).
if call.tag != "hcq" or (call.tag == "hcq" and call.arg.aux.name == "hcq finalizer"):
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in list(opened_qs)] + [call]
continue
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
aux = call.arg.aux
new_rec = ([new_sink], call.arg.aux)
if (old:=opened_qs.pop((devs, queue), None)) is not None:
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
new_rec = (old[0] + [new_sink], replace(new_rec[1], name=f"{queue.lower()} submit", estimates=old[1].estimates + new_rec[1].estimates))
else:
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
opened_qs[(devs, queue)] = (new_sink, aux)
closing = [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]
new_src += [merge_sink((sa:=opened_qs.pop(k))[0]).call(aux=sa[1]).rtag("hcq") for k in closing]
opened_qs[(devs, queue)] = new_rec
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
return linear.replace(src=tuple(new_src + [merge_sink(sinks).call(aux=aux).rtag("hcq") for sinks, aux in opened_qs.values()]))
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
# *****************
@ -384,6 +412,29 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
# *****************
# 4.3. annotate exec devs
pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
lambda call: call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=get_submit(call.src[0]).src[0].arg[0]))))])
# *****************
# 4.4. replace params with per-submit input address loads
def replace_params(call:UOp) -> UOp|None:
if not (params:={u:u.arg.slot for u in call.src[0].toposort() if u.op is Ops.PARAM and u.addrspace is AddrSpace.GLOBAL}): return None
# fill new info
hcqinfo = replace(call.arg.aux, params=tuple(sorted(set(params.values()))), inputs=len(get_call_arg_uops(call)))
inputs = UOp.new_buffer(get_submit(call.src[0]).src[0].arg[0], len(hcqinfo.params), dtypes.uint64).rtag("inputs")
slot2idx = {s:i for i,s in enumerate(hcqinfo.params)}
body = call.src[0].substitute({u:inputs.index(UOp.const(dtypes.int, slot2idx[s])).load() for u,s in params.items()})
return call.replace(src=(body, *call.src[1:], inputs), arg=replace(call.arg, aux=hcqinfo))
pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
# *****************
# 5.1. encode cmdbufs
@ -411,16 +462,28 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
])
# *****************
# 6. bufferize placeholders: replace placeholders with real buffers.
# 5.3. pack placeholders buffers
def bufferize_buf(buf:UOp) -> UOp|None:
if buf.tag is None: return None
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
return make_mstack(uops)
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
def pack_hcq_placeholders(call:UOp) -> UOp|None:
bufs = [b for b in call.src[0].toposort() if b.op is Ops.BUFFER and b.tag in (maxtags:={"scratch"}) | (sumtags:={"program", "kernargs"})]
off_per_buf:dict[UOp, int] = {}
size_per_tag:dict[str, int] = {}
for b in bufs:
if b.tag in maxtags: size_per_tag[b.tag] = max(size_per_tag.get(b.tag, 0), b.arg)
elif b.tag in sumtags:
off_per_buf[b] = round_up(size_per_tag.get(b.tag, 0), {"program": 0x1000}.get(b.tag, 128))
size_per_tag[b.tag] = off_per_buf[b] + b.arg
count_per_tag = collections.Counter(b.tag for b in bufs)
ref_bufs = {b.tag:b for b in bufs if count_per_tag[b.tag] > 1}
bases = {tag:UOp.new_buffer(b.src[1].arg, size_per_tag[tag], b.dtype).rtag(tag) for tag,b in ref_bufs.items()}
subs = {b:UOp(Ops.SLICE, b.dtype, (bases[b.tag], UOp.const(dtypes.weakint, off_per_buf.get(b, 0))), b.arg) for b in bufs if b.tag in bases}
return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:])) if subs else None
pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
# *****************
# 7.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
def hold_call_buffers(call:UOp) -> UOp|None:
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
@ -428,57 +491,83 @@ def hold_call_buffers(call:UOp) -> UOp|None:
pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)])
# *****************
# 7.2. resolve patches
# 6. bufferize placeholders: replace placeholders with real buffers.
def bufferize_buf(buf:UOp) -> UOp|None:
if buf.tag is None: return None
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), "CPU") for dev in to_tuple(buf.src[1].arg))
return make_mstack(uops)
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
# *****************
# 7. resolve patches
def push_stack(op, s): return UOp(Ops.STACK, op.dtype.scalar().vec(len(s.src)),
tuple(op.replace(dtype=op.dtype.scalar(), src=tuple(x if y is s else y for y in op.src)) for x in s.src))
def fold_blob_store(buf:UOp, blob:UOp) -> UOp:
for b in (buf.src if buf.op is Ops.MSTACK else (buf,)): b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
for b in (mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)): b.ensure_allocated()._buf.cpu_view().mv.cast('B')[:len(blob.arg)] = blob.arg
return UOp(Ops.NOOP)
def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
for b, v in zip((buf.src if buf.op is Ops.MSTACK else (buf,)), (val.src if val.op is Ops.STACK else (val,))):
struct.pack_into(f'<{v.dtype.fmt}', b.buffer.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * b.dtype.base.itemsize, v.arg)
for b, v in zip((bs:=mb.bufs if isinstance((mb:=buf.buffer), MultiBuffer) else (mb,)), val.src if val.op is Ops.STACK else (val,)*len(bs)):
struct.pack_into(f'<{v.dtype.fmt}', b.ensure_allocated()._buf.cpu_view().mv.cast('B'), off.arg * buf.dtype.base.itemsize, v.arg)
return UOp(Ops.NOOP)
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
if isinstance(b:=buf.buffer, Buffer): return UOp.const(dtypes.uint64, b.get_buf(g.src[1].arg).va_addr)
return UOp(Ops.STACK, dtypes.uint64.vec(len(b.bufs)), tuple(UOp.const(dtypes.uint64, x.ensure_allocated()._buf.va_addr) for x in b.bufs))
if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf
devs, b = to_tuple(g.src[1].arg), buf.buffer
bufs = tuple(cast(Buffer, x.buffer) for x in buf.src) if buf.op is Ops.MSTACK else tuple(b.bufs if isinstance(b, MultiBuffer) else (b,)*len(devs))
assert len(bufs) == len(devs), f"can't resolve {len(bufs)} buffers on {len(devs)} devices"
addrs = tuple(UOp.const(dtypes.uint64, x.get_buf(d).va_addr) for x, d in zip(bufs, devs))
return addrs[0] if len(addrs) == 1 else UOp(Ops.STACK, dtypes.uint64.vec(len(addrs)), addrs)
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
itemsize = bv.src[0].dtype.itemsize if unwrap_after(bv.src[0]).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT) else bv.dtype.itemsize
return UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
pm_resolve_patches = PatternMatcher([
# multi
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# shrink on slice is shrink on base at offset
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
# getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
# folders
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
fold_const_store),
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
]) + symbolic_simple
# *****************
# 8. callify hcq programs
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
bufs.append(ref)
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
if ref not in bufs: bufs.append(ref)
return UOp.placeholder((ref.buffer.size,), ref.dtype, bufs.index(ref))
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
def parametrize_host_buffers(call:UOp) -> UOp:
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
# preserve original order of args
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=list(get_call_arg_uops(call))), bottom_up=True, name="parametrize host buffers")
# move vars to new slots
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
return call.replace(src=(body, *bufs) + tuple(x for x in call.src[1:] if x.op is Ops.BIND))
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
def callify_hcq(call:UOp) -> UOp:
prg = to_program(call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1), Device["CPU"].renderer)
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
def hcq_schedule(linear:UOp) -> UOp:
@ -492,14 +581,17 @@ def hcq_schedule(linear:UOp) -> UOp:
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
linear = graph_rewrite(linear, pm_annotate_devs, name="annotate devs")
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
linear = graph_rewrite(linear, pm_pack_placeholders, walk=True, name="pack placeholders")
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
# realize starts from here
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
return linear

View file

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import cast
from typing import cast, Any, Callable
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
assert sys.platform != 'win32'
from dataclasses import dataclass
@ -89,25 +89,25 @@ def memory_barrier(ctx):
reg_done=getattr(ctx.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_DONE').addr[0], value=0xffffffff),
acquire_mem(ctx)))
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.device))
def pm4_wait(ctx, dst, val): return wait_reg_mem(ctx, val, mem=make_getaddr(dst, ctx.devs))
def pm4_barrier(ctx): return memory_barrier(ctx)
def pm4_store(ctx, dst, val):
if val.op is Ops.BINARY: return None
return release_mem(ctx, make_getaddr(dst, ctx.device), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
return release_mem(ctx, make_getaddr(dst, ctx.devs), val, ctx.pm4.data_sel__mec_release_mem__send_32_bit_low,
ctx.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
def pm4_timestamp(ctx, dst):
return release_mem(ctx, make_getaddr(dst, ctx.device), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
return release_mem(ctx, make_getaddr(dst, ctx.devs), 0, ctx.pm4.data_sel__mec_release_mem__send_gpu_clock_counter,
ctx.pm4.int_sel__mec_release_mem__none)
def pm4_program(ctx, prg):
data, info = prg.arg
lib_gpu, args = prg.src
prog_addr = make_getaddr(lib_gpu, ctx.device) + data.entry_point_offset
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.device)
args_addr = make_getaddr(args, ctx.device)
prog_addr = make_getaddr(lib_gpu, ctx.devs) + data.entry_point_offset
scratch_addr = make_getaddr(UOp.new_buffer(lib_gpu.device, data.private_segment_size, dtypes.uint8).rtag("scratch"), ctx.devs)
args_addr = make_getaddr(args, ctx.devs)
user_regs = []
if data.enable_private_segment_sgpr:
@ -174,7 +174,7 @@ pm_pm4_submit = PatternMatcher([(UPat(Ops.LINEAR, name="lin"),
class SDMAOps(FastEnum): COPY = auto(); POLL_REGMEM = auto(); FENCE = auto(); TRAP = auto(); TIMESTAMP = auto() # noqa: E702
def sdma_copy(ctx, dst, src, copy):
src_addr, dst_addr = make_getaddr(src, ctx.device), make_getaddr(dst, ctx.device)
src_addr, dst_addr = make_getaddr(src, ctx.devs), make_getaddr(dst, ctx.devs)
return UOp(Ops.LINEAR, dtypes.void, tuple([make_ins(SDMAOps.COPY,
ctx.sdma.SDMA_OP_COPY | ctx.sdma.SDMA_PKT_COPY_LINEAR_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_COPY_LINEAR),
ctx.sdma.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(min(copy.arg - off, ctx.max_copy_size) - 1), 0,
@ -183,17 +183,17 @@ def sdma_copy(ctx, dst, src, copy):
def sdma_wait(ctx, dst, val):
op = ctx.sdma.SDMA_OP_POLL_REGMEM | ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) \
| ctx.sdma.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1)
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.device)), val, 0xffffffff,
return make_ins(SDMAOps.POLL_REGMEM, op, *data64_le(make_getaddr(dst, ctx.devs)), val, 0xffffffff,
ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | ctx.sdma.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
def sdma_store(ctx, dst, val):
op = ctx.sdma.SDMA_OP_FENCE | (ctx.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if ctx.target[0] != 9 else 0)
return UOp(Ops.LINEAR, dtypes.void, (
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.device)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
make_ins(SDMAOps.FENCE, op, *data64_le(make_getaddr(dst, ctx.devs)), val), make_ins(SDMAOps.TRAP, ctx.sdma.SDMA_OP_TRAP, 0)))
def sdma_timestamp(ctx, dst):
op = ctx.sdma.SDMA_OP_TIMESTAMP | ctx.sdma.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(ctx.sdma.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL)
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.device)))
return make_ins(SDMAOps.TIMESTAMP, op, *data64_le(make_getaddr(dst, ctx.devs)))
pm_sdma_opsel = PatternMatcher([
(UPat(Ops.BARRIER), lambda: UOp(Ops.NOOP, dtypes.void, ())),
@ -516,11 +516,17 @@ class PCIIface(PCIIfaceBase):
def _mock(iface, name=None): return type(name or f"MOCK{iface.__name__}", (iface,), {})
@dataclass(frozen=True)
class AMDEncodeCtx: # encode-time constants for one queue: devs (every cmdbuf address resolves into these) + gfx version + packet/ip modules
devs: tuple[str, ...]; target: tuple[int, ...]; pm4: Any; sdma: Any; soc: Any # noqa: E702
gc: AMDIP; nbio: AMDIP; xccs: int; max_copy_size: int; tmpring_size: Callable # noqa: E702
def encode_queue(q:UOp) -> UOp|None:
if not (isinstance(q.arg, tuple) and len(q.arg) == 2 and isinstance(q.arg[1], str) and q.arg[1].startswith(("COMPUTE", "COPY"))): return None
devs = to_tuple(q.arg[0])
d = Device[(devs:=to_tuple(q.arg[0]))[0]]
ctx = AMDEncodeCtx(devs, d.target, d.pm4, d.sdma, d.soc, d.gc, d.nbio, d.xccs, d.max_copy_size, d.tmpring_size)
opsel, submit = (pm_pm4_opsel, pm_pm4_submit) if q.arg[1].startswith("COMPUTE") else (pm_sdma_opsel, pm_sdma_submit)
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=Device[devs[0]], name=f"{q.arg[1]} opsel"))
return submit.rewrite(graph_rewrite(q, opsel + pm_flatten_linear, walk=True, ctx=ctx, name=f"{q.arg[1]} opsel"))
pm_lower = PatternMatcher([
(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),)), encode_queue),

View file

@ -59,8 +59,8 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
return (None, None, grad_xw13_uop, None, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, new_amax)
# grad_amax_state: delayed amax for grad_xw13 fp8 quantization in the backward.
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
MBS, SEQ, H2 = xw13.shape
@ -72,5 +72,4 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname_of(xw13.device))
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, grad_amax_state,
fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf)
return fp8_out, scalar_amax(amax_buf)

View file

@ -112,8 +112,8 @@ def _fused_add_bwd(*args, **kwargs):
grad_h, grad_w = _bwd_common(fp8_grad_u, h_grad_u, x_u, x_normed_u, rrms_u, weight_u, amax_state_u, kernel)
return (None, None, None, None, None, grad_h, grad_h, grad_w, None)
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, inv_scale, new_amax, x_normed, rrms).
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor]:
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, new_amax, x_normed, rrms).
# x_normed + rrms are saved for the rmsnorm backward (also recomputed here from x regs).
assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
@ -127,13 +127,12 @@ def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, e
fxn = functools.partial(_custom_fwd, dname=dname_of(x.device), eps_val=eps)
fp8_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
fp8_out, x_normed_out, rrms_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_bwd)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf), x_normed_out, rrms_out
return fp8_out, scalar_amax(amax_buf), x_normed_out, rrms_out
def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor, amax_state:Tensor,
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# NOTE: h = x + residual; y_normed = rmsnorm(h); fp8 = quantize(y_normed * weight).
# Returns (fp8, inv_scale, new_amax, h, x_normed, rrms). h is also written so downstream can
# Returns (fp8, new_amax, h, x_normed, rrms). h is also written so downstream can
# reuse it without recomputing x+residual — eliminates the separate residual-add kernel.
assert x.dtype == dtypes.bfloat16 and residual.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
assert x.shape == residual.shape
@ -149,5 +148,4 @@ def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor,
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, x, residual, weight, amax_state,
fxn=fxn, grad_fxn=_fused_add_bwd)
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
return fp8_out, inv_scale, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
return fp8_out, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out

View file

@ -0,0 +1,104 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2
while step:
active = tid < step
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
other = lds[(tid + step).valid(active)].load()
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -0,0 +1,71 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.helpers import prod
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
@functools.cache
def _custom_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x:UOp) -> UOp:
rows, K = x.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x = x.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
x_f = x[idx].cast(dtypes.float)
abs_x = (x_f < 0.0).where(-x_f, x_f)
blk_max = abs_x.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (x_f * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
# pack the 4 e8 of this super-block into one uint32 (little-endian: byte sb), write transposed (sk4, row)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_mxfp8_{n_elems}", opts_to_apply=()))
def _quantize_mxfp8_fused_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x = kernel.src[1:]
device = x.device
rows, K = x.shape
scale_K = K // BLK
e8 = Tensor(e8_out, device=device).reshape(rows, scale_K)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, BLK).reshape(rows, K)
grad_x = (Tensor(gradient, device=device).float() * qscale).cast(dtypes.bfloat16)
return (None, None, None, grad_x.uop)
def quantize_mxfp8_fused(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
assert x.ndim == 2, f"expected 2d (rows, K), got {x.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x.shape
scale_K = K // BLK
axis = x.uop.axis if isinstance(x.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x, fxn=_custom_quantize_mxfp8, grad_fxn=_quantize_mxfp8_fused_bwd)
return fp8_out, e8_out, si_out

View file

@ -9,8 +9,8 @@ EXAMPLES = {
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
"plus":"test/test_tiny.py TestTiny.test_plus",
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_lds_sync",
"handwritten":"test/amd/test_custom_kernel.py TestCustomKernel.test_handwritten",
"sync":"test/amd/test_asm_kernel.py TestAsmKernel.test_lds_sync",
"handwritten":"test/amd/test_asm_kernel.py TestAsmKernel.test_handwritten",
}
if __name__ == "__main__":

View file

@ -255,7 +255,7 @@ def _sign(serial_num, out_dir):
final = _build_elf(ehdr, allph, segs)
os.makedirs(out_dir, exist_ok=True)
out = os.path.join(out_dir, "testsig-0x%08X.so" % serial_num)
out = os.path.join(out_dir, "testsig-0x%08x.so" % serial_num)
with open(out, 'wb') as f:
f.write(final)
print("Signing complete! Output saved at %s" % out)

View file

@ -51,7 +51,8 @@ def _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, sha
return None, None, dq.uop, dk.uop, dv.uop
return grad
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
# TODO: remove write_flat once scheduler can remove reshapes between custom_kernel. TestCustomKernel.test_simple_reshape
def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False, write_flat:bool=False):
assert attn_mask is None, "attn_mask not supported"
assert is_causal, "only causal attention supported"
@ -73,6 +74,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
arch = Device[single_device].renderer.target.arch
attn = _sharded_empty_like(xq, axis=shard_axis)
attn = _sharded_empty((B, N, H * D), xq, axis=shard_axis) if write_flat else _sharded_empty_like(xq, axis=shard_axis)
l_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
grad = _fa_grad_fxn(B, H, N, D, H_local, H_KV_local, H_KV, B_local, shard_axis, shard_axis_t, single_device, arch)

View file

@ -26,7 +26,7 @@ constexpr int NUM_THREADS = WARP_THREADS * NUM_WARPS;
using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr) {
__global__ __launch_bounds__(NUM_THREADS, 2) void hk_bf16_gemm(bf16 *C_ptr, bf16 *A_ptr, bf16 *B_ptr, bf16 *b_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
static_assert(M % BLOCK_SIZE == 0 && N % BLOCK_SIZE == 0 && K % K_STEP == 0);

View file

@ -99,12 +99,14 @@ using G = kittens::group<NUM_WARPS>;
#endif
__global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr
#if SCALE_MODE == 1
#if SCALE_MODE & 1
, float *x_scale_ptr
#elif SCALE_MODE == 2
#endif
#if SCALE_MODE & 2
, float *w_scale_ptr
#elif SCALE_MODE == 3
, float *x_scale_ptr, float *w_scale_ptr
#endif
#if SCALE_MODE & 4
, float *g_scale_ptr
#endif
) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
@ -346,20 +348,19 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
}
// apply x_scale * w_scale before bf16 store to prevent overflow
#if SCALE_MODE == 1
float scale = *x_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 2
float scale = *w_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 3
float scale = *x_scale_ptr * *w_scale_ptr;
#if SCALE_MODE != 0
float scale = 1.0f;
#if SCALE_MODE & 1
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
scale *= x_scale;
#endif
#if SCALE_MODE & 2
scale *= *w_scale_ptr;
#endif
#if SCALE_MODE & 4
scale *= *g_scale_ptr;
#endif
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);

View file

@ -26,9 +26,9 @@ constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter) {
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr,
const uint8_t *__restrict__ a_e8_unused,
const uint8_t *__restrict__ b_e8_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
@ -38,15 +38,19 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
constexpr int k_iters = K / BLOCK_K;
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
kittens::gl<fp8e8m0, k_iters * (M / BLOCK_ROW), 1, 16, 64> scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr};
kittens::gl<fp8e8m0, k_iters * (N / BLOCK_COL), 1, 16, 64> scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr};
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using ST_Scale = st<fp8e8m0, 16, 64, st_16x64_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
__shared__ ST_A As[2][2];
__shared__ ST_B Bs[2][2];
__shared__ uint8_t smem_scales[2048];
__shared__ ST_Scale scale_A_smem[2], scale_B_smem[2];
RT_A a;
RT_B b0, b1;
@ -79,54 +83,87 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
const T *a_base = (const T *)&A[{0, 0, 0, 0}];
const T *b_base = (const T *)&B[{0, 0, 0, 0}];
const int a_row_stride = A.template stride<2>() * sizeof(T);
const int b_row_stride = B.template stride<2>() * sizeof(T);
i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride);
i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride);
const int wid = warpid() % NUM_WARPS;
constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS;
uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T)));
uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T)));
int a_row_h0 = warp_m * REG_M;
int a_row_h1 = HALF_ROW + warp_m * REG_M;
int b_row_h0 = warp_n * REG_N;
int b_row_h1 = HALF_COL + warp_n * REG_N;
int tic = 0, toc = 1;
int tic_scales = 0, toc_scales = 1;
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}};
uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}};
if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) {
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1]));
if (warp_m == 1) __builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt vmcnt(4)");
__builtin_amdgcn_s_barrier();
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0]));
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0]));
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1]));
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0});
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
int sa_idx = block_row, sb_idx = block_col;
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) {
if (k + 1 < k_iters) {
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
asm volatile("s_waitcnt lgkmcnt(8)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
@ -135,10 +172,9 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0]));
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
@ -148,7 +184,7 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1]));
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
@ -160,19 +196,24 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
if (k + 1 < k_iters) {
sa_idx += tiles_M; sb_idx += tiles_N;
G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0});
G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0});
}
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1]));
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
@ -215,17 +256,17 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_sched_barrier(0);
tic ^= 1; toc ^= 1;
tic_scales ^= 1; toc_scales ^= 1;
}
{ // Epilogue k = k_iters - 1 (final)
int k = k_iters - 1;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
@ -261,6 +302,38 @@ __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3
__builtin_amdgcn_s_barrier();
}
} else {
#pragma unroll 1
for (int kk = 0; kk < k_iters; kk++) {
G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0]));
G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1]));
G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0]));
G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1]));
G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0});
G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0});
asm volatile("s_waitcnt vmcnt(0)");
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][0], {warp_n, 0}); load(b0, bs0);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[0][1], {warp_n, 0}); load(b1, bs1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[0][0], {warp_m, 0}); load(a, as0);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[0][1], {warp_m, 0}); load(a, as1);
asm volatile("s_waitcnt lgkmcnt(0)");
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_barrier();
}
}
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});

View file

@ -438,6 +438,35 @@ __device__ __forceinline__ T buffer_load_dwordx2(
return result;
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) {
asm volatile("buffer_load_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
} else {
asm volatile("buffer_load_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
// Returns a u32x3 (ext_vector_type(3)). Note sizeof(u32x3) == 16 due to vector
// alignment padding, but the compiler still allocates exactly 3 consecutive VGPRs
// for the "=v" operand and emits buffer_load_dwordx3 v[N:N+2].
template<typename T = u32x3>
__device__ __forceinline__ T buffer_load_dwordx3(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
T result;
asm volatile("buffer_load_dwordx3 %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR_START >= 256) {
@ -465,6 +494,116 @@ __device__ __forceinline__ T buffer_load_dwordx4(
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ubyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sbyte(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sbyte %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_ushort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_sshort(
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
T result;
asm volatile("buffer_load_sshort %0, %1, %2, %3 offen offset:%4"
: "=v"(result)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
return result;
}
// d16 variants merge the loaded sub-dword into the low (or high, for _hi) 16 bits
// of the destination register, leaving the other half untouched: read-modify-write.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_ubyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_ubyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_sbyte_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_sbyte_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16 %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_load_short_d16_hi(
T& dst, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_load_short_d16_hi %0, %1, %2, %3 offen offset:%4"
: "+v"(dst)
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR>
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
@ -519,6 +658,32 @@ __device__ __forceinline__ void buffer_store_dwordx2(
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx3(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("buffer_store_dwordx3 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 2 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dwordx3 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 2), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<typename T = u32x3>
__device__ __forceinline__ void buffer_store_dwordx3(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
asm volatile("buffer_store_dwordx3 %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
@ -546,6 +711,50 @@ __device__ __forceinline__ void buffer_store_dwordx4(
: "memory");
}
// Stores the low 8 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits' low 8 bits (bits [23:16]) of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_byte_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_byte_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the low 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
// Stores the high 16 bits of the value register.
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_short_d16_hi(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("buffer_store_short_d16_hi %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR>
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR >= 256) {

View file

@ -52,6 +52,8 @@ __device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
using i32x3 = int32_t __attribute__((ext_vector_type(3)));
using u32x3 = uint32_t __attribute__((ext_vector_type(3)));
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));

View file

@ -0,0 +1,10 @@
/**
* @file
* @brief Aggregate header for CDNA4 ThunderKittens core include tree.
*/
#pragma once
#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"

View file

@ -248,7 +248,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
"ds_read_b64_tr_b16 %0, %2 offset:%3\n"
"ds_read_b64_tr_b16 %1, %2 offset:%4\n"
// "s_waitcnt lgkmcnt(0)\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
: "v"(addr), "i"(offset), "i"(offset + 4 * ST::underlying_subtile_row_bytes)
: "memory"
@ -259,7 +259,7 @@ __device__ inline static void load(RT &dst, const ST &src) {
"ds_read_b64_tr_b16 %0, %2 offset:%4\n"
"ds_read_b64_tr_b16 %1, %3 offset:%4\n"
// "s_waitcnt lgkmcnt(0)\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[register_row][register_col].data[idx + 2]))
: "v"(addr), "v"(next_addr), "i"(offset)
: "memory"
@ -688,4 +688,88 @@ __device__ inline static void store(ST &dst, const RT &src) {
}
}
template<ducks::rt::col_layout RT, ducks::st::all ST>
__device__ inline static void load(RT &dst, const ST &src, int col_offset) {
static_assert(RT::cols == ST::rows,
"col_layout load with col_offset: ST.rows must equal RT::cols");
static_assert(RT::width == 1, "col_layout load with col_offset: only width==1 supported");
using T2 = RT::dtype;
using U = ST::dtype;
using U2 = base_types::packing<U>::packed_type;
const int laneid = kittens::laneid();
const uint32_t src_ptr = (uint32_t)(uintptr_t)(&src.data[0]);
constexpr int subtile_bytes = ST::underlying_subtile_bytes;
if constexpr (std::is_same_v<T2, fp8e4m3_4>) {
const int block_id = laneid / 16;
const int l_within = laneid % 16;
const int tr_k_grp = l_within / 2;
const int m_half = l_within & 1;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + m_half * 8;
const int subtile_base = block_id;
uint32_t addr = src_ptr
+ (uint32_t)(subtile_base * subtile_bytes)
+ src.swizzle({tr_k_grp, m_col});
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[4]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
addr ^= 1088u;
asm volatile(
"ds_read_b64_tr_b8 %0, %2 offset:0\n"
"ds_read_b64_tr_b8 %1, %2 offset:%3\n"
: "=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2])),
"=&v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[6]))
: "v"(addr), "i"(4 * subtile_bytes)
: "memory"
);
}
} else if constexpr (std::is_same_v<U2, bf16_2> || std::is_same_v<U2, half_2>) {
const int row_offset = ((laneid % 16) / 4) + ((laneid / 16) * 4);
const int col_offset_lane = (laneid % 4) * 4;
#pragma unroll
for (int i = 0; i < RT::height; i++) {
const int m_col = col_offset + i * RT::base_tile_rows + col_offset_lane;
const uint32_t addr = src_ptr + src.swizzle({row_offset, m_col});
const uint32_t next_addr = src_ptr + src.swizzle({row_offset + 4, m_col});
if constexpr (RT::base_tile_stride == 8) {
asm volatile(
"ds_read_b64_tr_b16 %0, %2 offset:0\n"
"ds_read_b64_tr_b16 %1, %3 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0])),
"=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[2]))
: "v"(addr), "v"(next_addr)
: "memory"
);
} else if constexpr (RT::base_tile_stride == 4) {
asm volatile(
"ds_read_b64_tr_b16 %0, %1 offset:0\n"
: "=v"(*reinterpret_cast<float2*>(&dst.tiles[i][0].data[0]))
: "v"(addr)
: "memory"
);
} else {
static_assert(false, "Unsupported stride for col_layout load with col_offset");
}
}
} else {
static_assert(false, "Unsupported type for col_layout load with col_offset");
}
}
} // namespace kittens

View file

@ -201,10 +201,10 @@ __device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
*/
template<ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
template<ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T=bf16>
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b, // in row-major mode
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c) {
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
@ -243,6 +243,32 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
}
}
/**
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
*
* @param smem_scales LDS pointer to scale tile data.
* @param row_offset Starting row within the scale region (warp's tile offset).
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
*/
__device__ __forceinline__ fp8e8m0_4 pack_scales(
const fp8e8m0 *smem_scales, int row_offset) {
int lid = laneid();
int r16 = lid % 16;
int k_sub = lid / 16;
const fp8e8m0_4 *s4 = (const fp8e8m0_4 *)smem_scales;
fp8e8m0_4 w0 = s4[row_offset + 0 * 16 + r16];
fp8e8m0_4 w1 = s4[row_offset + 1 * 16 + r16];
fp8e8m0_4 w2 = s4[row_offset + 2 * 16 + r16];
fp8e8m0_4 w3 = s4[row_offset + 3 * 16 + r16];
fp8e8m0_4 sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
fp8e8m0_4 lo = __builtin_amdgcn_perm(w0, w1, sel);
fp8e8m0_4 hi = __builtin_amdgcn_perm(w2, w3, sel);
return lo | (hi << 16);
}
/**
* @brief Base dot product operation for row layout.
*
@ -254,10 +280,10 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
*/
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_layout::all A_layout, ducks::rt_shape::all A_shape, ducks::rt_layout::all B_layout, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
const rt_base<MM_Operand_T, A_layout, A_shape> &a,
const rt_base<MM_Operand_T, B_layout, B_shape> &b,
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
const fp8e8m0_4 *scale_a,
const fp8e8m0_4 *scale_b) {
@ -441,7 +467,7 @@ __device__ static inline void mma_AB(D &d,
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
* @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
*/
template<ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
template<ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt(D &d,
const A &a,
const B &b, // notice row and (M, K) instead of col and (K, M)
@ -499,7 +525,7 @@ __device__ static inline void mma_ABt(D &d,
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
*/
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::all A, ducks::rt::all B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt_scaled(D &d,
const A &a,
const B &b,

View file

@ -9,6 +9,5 @@
#include "maps.cuh"
#include "reductions.cuh"
#include "mma.cuh"
#include "scales.cuh"
#include "assembly/tile.cuh"

Some files were not shown because too many files have changed in this diff Show more