Compare commits

...

219 commits

Author SHA1 Message Date
George Hotz
07102624a1
Merge branch 'master' into dsp_search_merged 2025-04-02 20:32:23 +08:00
George Hotz
ed76dd71eb
Merge branch 'master' into dsp_search_merged 2025-04-02 18:11:25 +08:00
George Hotz
5d6e8bd681 you can render where 2025-04-02 17:37:40 +08:00
George Hotz
f90656c647 cleanup dsp junk 2025-04-02 17:29:15 +08:00
George Hotz
86727875f9 delete junk 2025-04-02 16:51:08 +08:00
George Hotz
7e4ac744ac
Merge branch 'master' into dsp_search_merged 2025-04-02 16:45:20 +08:00
George Hotz
4496cc6e61 cleanup comments and flip rewrite 2025-04-02 16:20:26 +08:00
George Hotz
13d3bcb6e1 fix for many 2025-04-02 15:38:08 +08:00
George Hotz
efad1ebd0d
Merge branch 'master' into dsp_search_merged 2025-04-02 15:29:42 +08:00
George Hotz
e20eed6208 touch ups 2025-04-02 15:03:41 +08:00
George Hotz
7dc265ef93 syntax error 2025-04-02 14:55:36 +08:00
George Hotz
13dec71ab0 syntax error 2025-04-02 14:54:13 +08:00
George Hotz
bb453613ac syntax error 2025-04-02 14:53:04 +08:00
George Hotz
d6c3ae186b syntax error 2025-04-02 14:52:09 +08:00
George Hotz
c066653428 optimize is an option 2025-04-02 14:50:27 +08:00
George Hotz
66c6d35fe2 fix speed 2025-04-02 14:41:37 +08:00
George Hotz
64e1ddf2a9 fix bad merge 2025-04-02 13:21:55 +08:00
George Hotz
d6013a2d50 faster 2025-04-02 13:18:37 +08:00
George Hotz
1d36aa8116
Merge branch 'master' into dsp_search_merged 2025-04-02 13:09:13 +08:00
George Hotz
6251ab3d90 Merge remote-tracking branch 'origin/master' into dsp_search_merged 2025-04-02 12:25:01 +08:00
George Hotz
dd51728795 fix all bugs 2025-04-02 11:07:32 +08:00
George Hotz
a59b1ed970 bring speed back 2025-04-02 10:21:08 +08:00
George Hotz
147fc0e648 minor l2fetch tweak 2025-04-02 10:09:24 +08:00
George Hotz
17f7b226cb dsp masked stores are stores 2025-04-02 09:41:56 +08:00
George Hotz
9c34d9eb6e mask the full index 2025-04-02 09:18:15 +08:00
George Hotz
95261b6193 opts back 2025-04-02 09:13:30 +08:00
George Hotz
1e2becfeae fix pad 2025-04-02 09:07:27 +08:00
George Hotz
e18cdbcbe2 correct 2025-04-02 00:13:38 +08:00
George Hotz
f3cb4c3eef oops 2025-04-01 23:44:44 +08:00
George Hotz
6ecaf11224 ugh many hacks 2025-04-01 23:33:09 +08:00
George Hotz
8b24f9cb0d oops, didn't mean to change that 2025-04-01 17:55:04 +08:00
George Hotz
797e512c00 all correct 2025-04-01 17:51:24 +08:00
George Hotz
f600482982 correctness 2025-04-01 17:27:16 +08:00
George Hotz
da35edbb55 reenable that upcast 2025-04-01 17:09:02 +08:00
George Hotz
661431ee75 correctness 2025-04-01 17:01:46 +08:00
George Hotz
8340d9c1c2 disable padding 2025-04-01 16:27:54 +08:00
George Hotz
910cddbbca correct but slower 2025-04-01 16:11:47 +08:00
George Hotz
e6e0c0ec86 should work 2025-04-01 15:25:15 +08:00
George Hotz
d0eedb5a79 hack 2025-04-01 15:05:13 +08:00
George Hotz
f69deddbd4 opt 2025-04-01 14:43:36 +08:00
George Hotz
be11fbbf78 works 2025-04-01 14:38:38 +08:00
George Hotz
812c391617 fp mul 2025-04-01 13:43:16 +08:00
George Hotz
3306083f42 YOU DIDNT FOIL 2025-04-01 12:32:00 +08:00
George Hotz
18d7e9d3f1 oops 2025-04-01 11:56:57 +08:00
George Hotz
1c3f249ecf fix multicore flop tracking 2025-04-01 10:16:01 +08:00
nimlgen
bb7b89475c
dsp multicore 2 (#9644)
* dsp multicore 2

* hmm

* better
2025-03-31 23:56:54 +08:00
George Hotz
8005e6c974 write test pkl imagenet 2025-03-31 19:37:28 +08:00
George Hotz
a3d61a0372 save pkl from benchmark 2025-03-31 19:31:48 +08:00
George Hotz
c73e35aa24 non const fix 2025-03-31 19:10:06 +08:00
George Hotz
0b4b9f61b9 simpler 2025-03-31 19:03:06 +08:00
George Hotz
ee3ddfcdc1 many l2fetch 2025-03-31 18:58:52 +08:00
George Hotz
220d682489 prefetch l2 is so winning 2025-03-31 18:29:12 +08:00
George Hotz
9c388c3539 try to be smarter 2025-03-31 18:23:49 +08:00
George Hotz
4b3a4c8c46 fix prefetch l2 2025-03-31 18:09:48 +08:00
George Hotz
eb606d7230 MULTICORE=1 PYTHONPATH=. QUANTIZE=1 DEBUG=2 DEVECTORIZE=0 python3 extra/replay_pkl.py /tmp/im.pkl 2025-03-31 15:37:07 +08:00
George Hotz
49d52a2763 support acc in __builtin_HEXAGON_A2_vraddub 2025-03-31 15:12:00 +08:00
George Hotz
a59c3dd09a err, that's a bug 2025-03-31 14:56:15 +08:00
George Hotz
a640292aed delete extra 2025-03-31 14:35:32 +08:00
George Hotz
2f48c12441
Merge branch 'master' into dsp_search 2025-03-31 14:27:27 +08:00
George Hotz
be3b5efc64 fix precommit a bit 2025-03-31 14:26:19 +08:00
George Hotz
996d0ac1d2 multicore all the way 2025-03-31 14:17:19 +08:00
George Hotz
77e897b3b1
Merge branch 'master' into dsp_search 2025-03-31 13:03:29 +08:00
George Hotz
273dde69bd remove range split support 2025-03-31 12:43:21 +08:00
George Hotz
a64030d8c8 ignore hacks 2025-03-31 12:36:39 +08:00
George Hotz
9b19129e87 mc 2025-03-31 11:34:22 +08:00
George Hotz
48221d9024 2 global dim 2025-03-31 11:25:12 +08:00
George Hotz
bcfcd60f55 opt weights 2025-03-31 11:02:03 +08:00
George Hotz
abc90024ac hand coded opts 2025-03-31 10:44:09 +08:00
George Hotz
f0e6d8394c
Merge branch 'master' into dsp_search 2025-03-31 10:01:19 +08:00
George Hotz
a1c1ecd597
Merge branch 'master' into dsp_search 2025-03-29 10:34:32 +08:00
nimlgen
489a5e24c4
Merge branch 'master' into dsp_search 2025-03-28 19:08:17 +07:00
George Hotz
e0fd84dd64 add locals 2025-03-28 18:52:48 +08:00
George Hotz
1a9d7a1628 upcast small 2025-03-28 18:31:17 +08:00
George Hotz
45646fe102 optional l2 fetch 2025-03-28 18:12:48 +08:00
George Hotz
9c928afafe tighter l2fetch foce 2025-03-28 18:08:29 +08:00
George Hotz
d4f1c5049b tighter l2fetch 2025-03-28 18:05:40 +08:00
George Hotz
11b478f85d prefetch l2 2025-03-28 17:59:05 +08:00
George Hotz
0aa7031b5f simpler 2025-03-28 17:42:14 +08:00
George Hotz
ab67d5ff6e unused 2025-03-28 17:37:38 +08:00
George Hotz
cbe23e13c2 ignore there 2025-03-28 17:01:21 +08:00
George Hotz
9bbd12dc65 bugfixes 2025-03-28 16:49:36 +08:00
George Hotz
b09142a893 where on two adds 2025-03-28 15:00:09 +08:00
George Hotz
1d7faf4777 simpler mult 2025-03-28 14:47:41 +08:00
George Hotz
59438be39b fix fixed point mult 2025-03-28 12:14:50 +08:00
George Hotz
cc23836a38 add 128 stores 2025-03-28 09:21:43 +08:00
George Hotz
e4354effa2
Merge branch 'master' into dsp_search 2025-03-28 09:09:14 +08:00
George Hotz
d180e909a3 debug simplify 2025-03-27 17:42:13 +08:00
George Hotz
52364231dc fast kernel 1 2025-03-27 17:12:10 +08:00
George Hotz
d32ad080c3 fast 66 2025-03-27 16:47:58 +08:00
George Hotz
a8bd26d9bc full_shape 2025-03-27 16:39:15 +08:00
George Hotz
6d860389f4 issue 2025-03-27 16:21:02 +08:00
George Hotz
5d5286489d block those ones 2025-03-27 16:12:47 +08:00
George Hotz
917e0e925b load rewrites 2025-03-27 16:04:53 +08:00
George Hotz
6081f8427e we can do that in dsp_pm 2025-03-27 16:01:08 +08:00
George Hotz
23035bf028 not consts 2025-03-27 15:40:32 +08:00
George Hotz
5e33163ef3
Merge branch 'master' into dsp_search 2025-03-27 13:16:48 +08:00
George Hotz
cee9fc7540 new_ignore 2025-03-27 13:16:30 +08:00
George Hotz
9041072dea
Merge branch 'master' into dsp_search 2025-03-27 12:16:27 +08:00
George Hotz
444d6279ac flod better 2025-03-27 12:03:15 +08:00
George Hotz
f27f484621 hacks 2025-03-27 10:51:41 +08:00
George Hotz
38488ec3b0 extend to 128 2025-03-27 10:35:06 +08:00
George Hotz
ff96f0adae 7 2025-03-27 00:29:00 +08:00
George Hotz
5dd59a6096 touchup 2025-03-27 00:23:58 +08:00
George Hotz
6bec82b918 on 7 2025-03-27 00:05:18 +08:00
George Hotz
a436d7542f up 7 2025-03-27 00:00:40 +08:00
George Hotz
5d98688de6 ugh 2025-03-26 23:39:33 +08:00
George Hotz
09d877ed8c
Merge branch 'master' into dsp_search 2025-03-26 23:35:21 +08:00
George Hotz
6ff894d674 looks fast 2025-03-26 22:48:02 +08:00
George Hotz
da03b4520a
Merge branch 'master' into dsp_search 2025-03-26 22:38:41 +08:00
George Hotz
013c6e0b10 index 2025-03-26 21:55:25 +08:00
George Hotz
31ffa1607e
Merge branch 'master' into dsp_search 2025-03-26 21:43:22 +08:00
George Hotz
928994c6ea bugfix 2025-03-26 20:10:15 +08:00
George Hotz
e283bec62e l1prefetch back 2025-03-26 19:51:18 +08:00
George Hotz
c4f5db8467 validate index 2025-03-26 19:49:34 +08:00
George Hotz
bf0d928417 add back index check 2025-03-26 19:42:44 +08:00
George Hotz
f823324eb9 fix 2025-03-26 19:34:50 +08:00
George Hotz
6995e0c91b
Merge branch 'master' into dsp_search 2025-03-26 18:37:55 +08:00
George Hotz
b934b5b907 cleanup expander 2025-03-26 18:11:40 +08:00
George Hotz
290ba9ee37 more cleanups 2025-03-26 17:59:26 +08:00
George Hotz
e0d63696d7 cleanups 2025-03-26 17:55:48 +08:00
George Hotz
acafd57f14
Merge branch 'master' into dsp_search 2025-03-26 17:49:15 +08:00
George Hotz
905f847d10 fix dq 2025-03-26 17:43:10 +08:00
George Hotz
9e19cdfbbe e kernels 2025-03-26 16:49:23 +08:00
George Hotz
f7b38fa94c make that 8 2025-03-26 16:46:49 +08:00
George Hotz
bd03942bd8 fix reduce acc 2025-03-26 16:41:59 +08:00
George Hotz
880b4a5e47 put that back 2025-03-26 15:34:57 +08:00
George Hotz
2e4cae342b less terrible first 2025-03-26 15:21:38 +08:00
George Hotz
8660fecb02 unroll both sides 2025-03-26 15:12:02 +08:00
George Hotz
e3e43df0c9 knum 5 split 2025-03-26 15:00:08 +08:00
George Hotz
a47e61b097 big hacks 2025-03-26 13:50:11 +08:00
George Hotz
f1ff18acec prepad kernel weights 2025-03-26 12:13:46 +08:00
George Hotz
60cbfe4222
Merge branch 'master' into dsp_search 2025-03-26 10:50:00 +08:00
George Hotz
311df3ff21 fixes 2025-03-25 19:14:57 +08:00
George Hotz
f6e64a5e8e optional conv 2025-03-25 19:00:48 +08:00
George Hotz
622ff115a3 back 2025-03-25 18:47:40 +08:00
George Hotz
5a6e8ee268 fix test 2025-03-25 18:37:43 +08:00
George Hotz
a9f1227625 4 faster 2025-03-25 18:32:27 +08:00
George Hotz
74c2587ef4 4 -> 8 2025-03-25 18:31:43 +08:00
George Hotz
bce252e0b8 devec 0 2025-03-25 17:12:43 +08:00
George Hotz
66a90a3c92 ugh, fast 2? 2025-03-25 17:08:35 +08:00
George Hotz
0d76b0d461 acc2 2025-03-25 15:33:38 +08:00
George Hotz
5e4505d363 kernel 2 54 GFLOPS 2025-03-25 14:04:03 +08:00
George Hotz
29920b74d5 unsafe disable on device 2025-03-25 13:57:57 +08:00
George Hotz
ccd18a803c faster? 2025-03-25 13:41:34 +08:00
George Hotz
943bde47ab fast k26 2025-03-25 13:32:11 +08:00
George Hotz
0d10c7ae2f working on kernel 15 2025-03-25 12:00:16 +08:00
George Hotz
3cab6a3d4a 3x3 2025-03-24 16:36:53 +08:00
George Hotz
22a56cbaea something for 8 2025-03-24 16:16:59 +08:00
George Hotz
afd61730b4 kernel 5 2025-03-24 16:10:08 +08:00
George Hotz
536556434b padding ish 2025-03-24 15:57:03 +08:00
George Hotz
52bff5f39d more 2025-03-24 15:33:48 +08:00
George Hotz
64d0f14d3d broken 2025-03-24 15:21:43 +08:00
George Hotz
1b61cc6ec3 unaligned 2025-03-24 15:13:57 +08:00
George Hotz
6f792e8045 vmemu 2025-03-24 15:02:40 +08:00
George Hotz
b1f8018bf4 unaligned load 2025-03-24 14:54:11 +08:00
George Hotz
2eb9241329 better conv 2025-03-24 13:07:14 +08:00
George Hotz
554a490751
Merge branch 'master' into dsp_search 2025-03-24 12:29:22 +08:00
George Hotz
651c678edf work 2025-03-24 09:49:53 +08:00
George Hotz
3274bd2d81 output 2025-03-23 15:13:00 +08:00
George Hotz
30f4d64148 rules 2025-03-22 19:17:16 +08:00
George Hotz
2634975d5a 5 and 8 2025-03-22 19:14:04 +08:00
George Hotz
fd73ec2b1b knum 2025-03-22 18:59:54 +08:00
George Hotz
e1d2bec4a4 opt 2025-03-22 18:52:56 +08:00
George Hotz
1b4e9f5e91 more opt rules 2025-03-22 18:07:31 +08:00
George Hotz
25c023bcbe more 2025-03-22 17:49:34 +08:00
George Hotz
07abf9e6bc multi_add_int32 2025-03-22 17:33:56 +08:00
George Hotz
26b02a037c fix 33 2025-03-22 17:17:47 +08:00
George Hotz
5089a601c6 name it 2025-03-22 14:44:01 +08:00
George Hotz
6b49a63c48 linearizer workaround 2025-03-22 14:18:02 +08:00
George Hotz
dca95428a5 touch 2025-03-22 11:05:36 +08:00
George Hotz
8a477ba4e1 knum 3 2025-03-21 20:36:18 +08:00
George Hotz
264dd91b8a 70 GFLOPS 2025-03-21 20:31:14 +08:00
George Hotz
bdf716b915 mul work 2025-03-21 20:05:29 +08:00
George Hotz
cf41c803d0 fast 13 2025-03-21 18:10:59 +08:00
George Hotz
3cf9224df5 a scale and b scale 2025-03-21 18:07:53 +08:00
George Hotz
af94addb3a ish 2025-03-21 17:46:45 +08:00
George Hotz
dc1469a188 double reduce 2025-03-21 17:33:48 +08:00
George Hotz
0416b0998d revert those 2025-03-21 17:15:38 +08:00
George Hotz
c715c25420
Merge branch 'master' into dsp_search 2025-03-21 17:13:10 +08:00
George Hotz
f66b03f0a6 dsp ish 2025-03-21 16:28:08 +08:00
George Hotz
2729a46ca6 don't do that 2025-03-21 16:04:21 +08:00
George Hotz
dbb50e4a00 knum 4 2025-03-21 15:48:50 +08:00
George Hotz
71c7c455a6 quantize 2025-03-21 14:55:29 +08:00
George Hotz
ff3438be4e fast 2025-03-21 13:04:18 +08:00
George Hotz
bc5e23061b diasm 2025-03-21 11:22:40 +08:00
George Hotz
5ce951fb34 l2 2025-03-21 11:14:12 +08:00
George Hotz
4a49d05a3f
Merge branch 'master' into dsp_search 2025-03-21 10:26:38 +08:00
George Hotz
c3c85c64ee simpler 2025-03-21 09:24:33 +08:00
George Hotz
61c02ca634 cleanups 2025-03-20 23:27:06 +08:00
George Hotz
325044bcaf okay that should actually prefetch 2025-03-20 22:59:59 +08:00
George Hotz
91ac508878 prefetch 2025-03-20 22:56:38 +08:00
George Hotz
2ed30f5366 correct flops 2025-03-20 21:46:13 +08:00
George Hotz
d0b9c7e7ca fast like nascar? 2025-03-20 21:27:26 +08:00
George Hotz
f6ed8f4a27 8 folds 2025-03-20 21:20:46 +08:00
George Hotz
87718170d2 more generic 2025-03-20 21:14:33 +08:00
George Hotz
b67af4049c knum 20 2025-03-20 20:59:06 +08:00
George Hotz
16e425a4c0 work 2025-03-20 20:24:21 +08:00
George Hotz
c867a48ab4 custom 2025-03-20 20:02:35 +08:00
George Hotz
2dc82c0604 should be fast 2025-03-20 19:49:04 +08:00
George Hotz
e7402e6643 KNUM=13 will be fast like roadrunner 2025-03-20 18:45:53 +08:00
George Hotz
e5ccd9e846 work 2025-03-20 15:20:03 +08:00
George Hotz
624197f169 swizzle better 2025-03-20 12:41:24 +08:00
George Hotz
d42350a401 simple test 2025-03-20 12:37:29 +08:00
George Hotz
223feb2118
Merge branch 'master' into dsp_search 2025-03-20 10:52:30 +08:00
George Hotz
8eb9093fb8 lil 2025-03-17 19:57:15 +08:00
George Hotz
45f7c08111 work 2025-03-17 19:22:12 +08:00
George Hotz
58fc77fdb3 improve render 2025-03-17 18:50:44 +08:00
George Hotz
e57258b17b prettier rendering 2025-03-17 18:46:25 +08:00
George Hotz
31cd00e72f fix name get 2025-03-17 18:09:39 +08:00
George Hotz
b00ccc08c3 ms target 2025-03-17 17:49:48 +08:00
George Hotz
94d578aec5 gep pushing 2025-03-17 17:43:02 +08:00
George Hotz
45010f7eff Revert "dont do that"
This reverts commit 249141026e.
2025-03-17 17:26:00 +08:00
George Hotz
249141026e dont do that 2025-03-17 17:15:59 +08:00
George Hotz
a913c1aab7 multi unroll 2025-03-17 17:12:45 +08:00
George Hotz
469ec6b6b4 support tuple in beam 2025-03-17 17:02:32 +08:00
George Hotz
1a84d504b7
Merge branch 'master' into dsp_search 2025-03-17 16:43:07 +08:00
George Hotz
14c9f14125 dsp beam search 2025-03-17 16:42:32 +08:00
George Hotz
cc0041cb8c padding 2025-03-17 16:30:29 +08:00
George Hotz
e4615e0cd9 dsp work try 3 2025-03-17 16:20:46 +08:00
11 changed files with 640 additions and 54 deletions

View file

@ -415,6 +415,7 @@ def get_onnx_ops():
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
if W.shape[1:] == (1,3,3) and group > 1: group = W.shape[0]
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
@ -724,6 +725,19 @@ def get_onnx_ops():
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
else:
ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
# you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
if getenv("NHWC") and len(ret.shape) == 4:
in_chans = ret.shape[1]
if ret.shape[1] == 3 or in_chans%32 != 0:
return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
else:
if in_chans%32 != 0: ret = ret.pad(((0,0), (0,32-(in_chans%32)), (0,0), (0,0)))
ret = ret.reshape(ret.shape[0], ret.shape[1]//32, 32, ret.shape[-2], ret.shape[-1])
order = (0, 1, 3, 4, 2)
ret = ret.permute(order).contiguous().permute(*argsort(order))
ret = ret.reshape(ret.shape[0], -1, ret.shape[-2], ret.shape[-1])
if in_chans%32 != 0: ret = ret[:, :in_chans, :, :]
return ret
return ret.contiguous()
def DynamicQuantizeLinear(x: Tensor):
@ -735,6 +749,66 @@ def get_onnx_ops():
return y, scale, zero_point
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
if getenv("NHWC"):
# pad channels
in_shape = x.shape
if len(x.shape) == 4 and x.shape[1:] == (1,3,3) and x.shape[0]%32 != 0:
# 3x3 depthwise (C,1,3,3). pad C to 32
x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
elif len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[0] != 1:
# 1x1 conv (C_out,C_in,1,1), pad C_out and C_in to 32
if x.shape[0]%32 != 0: x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
if x.shape[1]%32 != 0: x = x.pad(((0,0), (0,32-(x.shape[1]%32)), (0,0), (0,0)))
elif len(x.shape) == 1 and x.shape[0]%32 != 0 and x.shape[0] != 1000:
# bias
x = x.pad(((0,32-(x.shape[0]%32)),))
if in_shape != x.shape:
xzp = x_zero_point.item()
print(f"{in_shape} -> {x.shape}", xzp)
# fix up the zero point in the padded area
pp = (Tensor.full(in_shape, -xzp, dtype=dtypes.int).pad(tuple([(0, so-si) for si,so in zip(in_shape, x.shape)])) + xzp).cast(x.dtype)
x = (x + pp).contiguous()
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (3,3,3):
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
assert x.shape[0] == 32
order = (1,2,0,3)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x[:, :, :, :3]
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (1,3,3):
# 3x3 depthwise (C,1,3,3)
# "width multiple of 4 depth multiple of 32 aligned to 128bytes"
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
if x.shape[0]%32 == 0:
# depth/32 is a loop -- lsr(depth, #5)
# width/4 is a loop -- lsr(out_width, #2)
# height is a loop
x = x.reshape(-1, 32, 1, 3, 4)
order = (0,3,1,2,4)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x.reshape(-1, 1, 3, 4)
else:
assert False # (doesn't happen anymore)
#print("HERE", x.shape)
order = (2,0,1,3)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x[:, :, :, :3]
# we increase the filts to 4-aligned for speed (75% util)
WEIGHT_SHIFT = 4
if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
if x.shape[0]%32 == 0:
# DSP swizzle memory (big)
x = x.reshape(x.shape[0]//32, 32, x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(0,2,1,3).contiguous().permute(0,2,1,3).reshape(x.shape)
else:
# DSP swizzle memory
x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
if getenv("NHWC") and x.shape == (1000, 1280):
x = x.reshape(-1, 320, 4)
order = (1,0,2)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x.reshape(-1, 1280)
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)

View file

@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, commutative
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@ -12,10 +12,16 @@ from tinygrad.renderer import Renderer
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
vectorize_mask = getenv("VECTORIZE_MASK", 0) and buf.arg == 0 and mask is not None
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
if vectorize_mask:
# no load_store_indexing if we are doing vectorized mask
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic_flat+commutative, name=f"index_buf_{buf.arg}")
else:
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
symbolic_flat+commutative+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
@ -24,7 +30,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
if len(midx.src[i].src) == 3 and not vectorize_mask: root_src = (midx.src[i].src[2], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# the buf.dtype is always a pointer
@ -35,10 +41,26 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for offsets in offsets_rootsrc.values():
if 0 in offsets:
match = True
for i in range(0, max(offsets.keys()), 4):
if i in offsets and i+1 in offsets and i+2 in offsets and i+3 not in offsets: pass
else: match = False
if match:
for i in range(0, max(offsets.keys()), 4):
assert i+3 not in offsets
offsets[i+3] = {}
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
if vectorize_mask:
allgrp = [midx.src[offsets[g][0]] for g in grp]
base = [x.src[2].cast(dtypes.uchar) if len(x.src) > 2 else UOp.const(dtypes.uchar, 1) for x in allgrp]
vecmask = UOp(Ops.VECTORIZE, dtypes.uchar.vec(len(base)), tuple(base))
lidx = lidx.replace(src=lidx.src[0:2]+(vecmask,))
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
# set the idxs of the output
for i,g in enumerate(grp):
@ -167,7 +189,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if global_offset+fold_length > sz: continue
oidx = idx.src[1] + global_offset
if must_divide and oidx.simplify().divides(fold_length) is None: continue
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
if len(idx.src) > 2 and idx.src[2].dtype.count > 1:
# vectorized
lidx = buf.index(oidx, idx.src[2].gep(tuple(range(global_offset, global_offset+fold_length))))
else:
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
@ -271,7 +297,34 @@ pm_render = PatternMatcher([
# *** uop graph ***
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
from dataclasses import dataclass
from tinygrad.ops import identity_element
from tinygrad.helpers import partition
@dataclass
class ReduceContext:
acc_num: int = 0
def reduce_to_acc(ctx:ReduceContext, x:UOp):
ret = x.src[0]
reduce_range, reduce_expand = partition(x.src, lambda y: y.op is Ops.RANGE)
if len(reduce_range) == 0: return ret
if all(y not in reduce_range for y in ret.toposort):
# TODO: this shouldn't be here
return ret*prod([y.src[1] for y in reduce_range]).broadcast(ret.dtype.count)
alu_op = x.arg
# create acc
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+list(reduce_expand))
# create ACC and assign
return acc.assign(ret)
pm_reduce = PatternMatcher([
(UPat(Ops.REDUCE, name="x"), reduce_to_acc)
])
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) -> UOp:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
@ -282,7 +335,15 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
# optional pre matcher
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
if opts is not None and opts.pre_matcher is not None:
if is_conv:
from tinygrad.runtime.ops_dsp import conv_pm
sink = graph_rewrite(sink, conv_pm+opts.pre_matcher)
else:
sink = graph_rewrite(sink, opts.pre_matcher)
# remove reduce
sink = graph_rewrite(sink, pm_reduce, ctx=ReduceContext(), name="remove_reduce")
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)

View file

@ -1,7 +1,8 @@
# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, getenv
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
from tinygrad.codegen.symbolic import sym
@ -123,9 +124,21 @@ pm_store_ignore = PatternMatcher([
lambda store,mask: store.replace(src=(store.src[0], UOp(Ops.IGNORE, src=(store.src[1], mask)))) if store.src[1].op is not Ops.IGNORE else None),
])
def debug_ignore(x, y):
if getenv("DEBUG_IGNORE"):
print("****")
print("seen ", x.render())
print("ignore", y.render())
# this is totally wrong
return x.const_like(True)
pm_move_ignore = PatternMatcher([
# IGNORE on SELF is nothing
(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x"))), lambda x: x.const_like(True)),
# IGNORE debug
(UPat(Ops.IGNORE, src=(UPat(dtype=dtypes.bool, name="x"), UPat(dtype=dtypes.bool, name="y"))), debug_ignore),
# IGNORE with AND on SELF is nothing (is this right?)
#(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x") & UPat())), lambda x: x.const_like(True)),
# IGNORE on a CONST is nothing
(UPat(Ops.IGNORE, src=(UPat((Ops.CONST, Ops.VCONST), name="c"), UPat())), lambda c: c),
# move the IGNOREs

View file

@ -437,6 +437,61 @@ class Kernel:
return self
def hand_coded_optimizations(self) -> Kernel:
if self.opts.device == "DSP":
k = self
# special path for DSP
if k.full_shape[-3:] == (32,3,3):
# 3x3 dwconv
# kernel 49 is broken
if k.full_shape[-4]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-3, 32))
if k.full_shape[len(k.full_shape)-4]%4 == 0:
#if k.full_shape[len(k.full_shape)-4] <= 8: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 0))
#else: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
# if this is small, swap it
# NOTE: this is breaking something (should be fixed w/o padto)
# kernel 23 is broken with this
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.SWAP, 0, 1))
elif k.full_shape[-4:] == (32,3,3,3):
# 3x3 normal conv
k.apply_opt(Opt(OptOps.UNROLL, 2, 0))
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
# more UNROLLs aren't working well here, but they should be
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
elif len(k.full_shape) == 3 and k.full_shape[1] == 32 and k.first_reduce == 2:
# weight that's exactly 32
# NOTE: this pad might be broken
if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
if k.full_shape[0]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
elif len(k.full_shape) == 4 and k.full_shape[2] == 32 and k.first_reduce == 3:
# weight that has more than 32
# NOTE: this pad is broken on kernel 50
if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
# weight with more
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
if k.full_shape[1]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
# if the more is small, upcast it (kernel 50 is broken with this)
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.UPCAST, 0, 0))
elif len(k.full_shape) == 2 and k.first_reduce == 1:
# unroll to 4 if we can
if k.full_shape[k.first_reduce]%4 == 0: k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
# always pad to 128
# NOTE: this breaks kernel 66
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
elif len(k.full_shape) == 1 and k.full_shape[0] > 1000:
# pad to 128 and run on 128
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
return self
self.required_optimizations()
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
@ -678,7 +733,11 @@ class Kernel:
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
#if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
is_conv = (len(self.full_shape) == 6 and self.full_shape[2:4] == (3,3))
is_conv = is_conv or (len(self.full_shape) == 6 and self.full_shape[3:5] == (3,3))
is_conv = is_conv or (len(self.full_shape) == 7 and self.full_shape[3:5] == (3,3))
is_conv = is_conv or self.full_shape[-2:] == (3,3)
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts, is_conv))
if DEBUG >= 6: print_uops(self.uops)
return self

View file

@ -3,9 +3,9 @@ import functools, itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.helpers import all_int, prod, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
from tinygrad.codegen.symbolic import symbolic
@ -112,21 +112,20 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
def lower_reduce_axis(ctx: IndexContext, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
#reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
reduce_indexes = [ctx.ridxs[i] for i in x.axis_arg]
all_nodes = flatten([x.toposort for x in reduce_indexes])
reduce_expand = [x for x in all_nodes if x.op is Ops.UNROLL]
reduce_range = [x for x in all_nodes if x.op is Ops.RANGE]
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
alu_op: Ops = x.arg[0]
ret = x.src[0]
# create acc
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
ctx.acc_num += 1
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
ret = (functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)]),)
else:
ret = acc.alu(alu_op, ret)
if not len(reduce_range): return ret
# create ACC and assign
return acc.assign(ret)
ret = (ret,)
return UOp(Ops.REDUCE, x.dtype, ret+tuple(reduce_range), alu_op)
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
@ -221,6 +220,13 @@ pm_quant = symbolic+PatternMatcher([
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
# MUL by 1/0 on LOAD where the masks match
(UPat(Ops.WHERE, src=(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v1"),)), UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))) * \
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"),
lambda ld,v1,v2: ld if v1.arg.to_indexed_uops()[1].simplify() == v2.arg.to_indexed_uops()[1].simplify()
# NOTE: this clause is completely false and might break things
or True else None),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

View file

@ -168,7 +168,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
quo += q * v
# if numerator is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
if x.vmin < 0 and remainders: return None
if x.vmin < -10000000 and remainders: return None
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
return rem//(c//gcd)+quo
@ -201,7 +201,7 @@ commutative = PatternMatcher([
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])
symbolic = symbolic_simple+commutative+PatternMatcher([
symbolic = symbolic_simple+PatternMatcher([
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **

View file

@ -335,9 +335,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def simplify(self):
# late import!
from tinygrad.codegen.symbolic import symbolic
from tinygrad.codegen.symbolic import symbolic_flat, commutative
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)
return graph_rewrite(self, symbolic_flat+commutative)
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"

View file

@ -78,6 +78,9 @@ class Estimates:
elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpy"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
if u.arg.startswith("__builtin_HEXAGON_A2_vraddub"): flops += mults*(17 if 'acc' in u.arg else 16)
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
@dataclass
@ -106,6 +109,9 @@ class ProgramSpec:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
# DSP masked store
if u.op is Ops.CUSTOM and u.arg.startswith("__builtin_HEXAGON_V6_vS32b"):
self.outs.extend([x.arg for x in u.src[1].toposort if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this

View file

@ -1,6 +1,6 @@
from typing import cast
import itertools
from tinygrad.helpers import dedup, DEBUG, to_function_name
from tinygrad.helpers import dedup, DEBUG, to_function_name, getenv
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
@ -24,17 +24,18 @@ class CPUGraph(GraphRunner):
if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"]
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+", int gl0, void* sync) {"]
for i, ji in enumerate(jit_cache):
args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)}, gl0, 0x0);")
if getenv("MULTICORE", 0) != 0: batched.append(f" qurt_barrier_wait(&(((qurt_barrier_t*)sync)[{i}]));")
batched.append("}")
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
entry = device.renderer._render_entry("batched", targs)
entry = device.renderer._render_entry("batched", targs, sync_cnt=len(jit_cache))
code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
if DEBUG >= 4: print(code)

View file

@ -3,65 +3,417 @@ import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, context
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat
from tinygrad.codegen.symbolic import gep_pushing
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG, dedup, all_same
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
from tinygrad.ops import PatternMatcher, UPat
def multi_mul(a0, a1, b0, b1, c0, c1, d0=None, d1=None, acc=None):
swizzle = []
for i in range(32):
swizzle.append(i)
swizzle.append(32+i)
swizzle.append(64+i)
swizzle.append(96+i)
swizzle = tuple(swizzle)
if a0.op is not Ops.CAST:
#print("rejected on a0")
return None
if a1.op is not Ops.CAST:
#print("rejected on a1")
return None
if d0 is None:
d0 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
if d1 is None:
d1 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
assert a0.op is Ops.CAST
assert b0.op is Ops.CAST
assert c0.op is Ops.CAST
assert d0.op is Ops.CAST
assert a1.op is Ops.CAST
assert b1.op is Ops.CAST
assert c1.op is Ops.CAST
assert d1.op is Ops.CAST
dt1 = a0.src[0].dtype.scalar().vec(128)
dt2 = a1.src[0].dtype.scalar().vec(128)
m0 = UOp(Ops.CAT, dt1, src=(a0.src[0],b0.src[0],c0.src[0],d0.src[0])).gep(swizzle)
m1 = UOp(Ops.CAT, dt2, src=(a1.src[0],b1.src[0],c1.src[0],d1.src[0])).gep(swizzle)
simp_m1 = m1.simplify()
if simp_m1.op is Ops.GEP and simp_m1.arg == simp_m1.arg[0:4]*32:
# Vx32.w+=vrmpy(Vu32.ub,Rt32.b) -> __builtin_HEXAGON_V6_vrmpybus_acc
# Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4])
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_128B({0}, {1})")
def gep_on_reduce(gep, alu):
if gep.dtype.vcount == 1: return None
return UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count),
tuple(x.gep(gep.arg) if x.op is not Ops.RANGE else x for x in alu.src), alu.arg) if not isinstance(gep.dtype, PtrDType) and \
alu.dtype.count >= gep.dtype.count else None
def multi_add_int32(**aa):
if 'acc' in aa:
acc = aa['acc']
del aa['acc']
else:
acc = None
mask = 0x01010101
if 'd0' not in aa:
mask = 0x00010101
d0 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
aa['d0'] = d0
swizzle = []
for i in range(32):
swizzle.append(i)
swizzle.append(32+i)
swizzle.append(64+i)
swizzle.append(96+i)
for x in aa.values():
assert x.src[0].dtype.scalar() is dtypes.uchar
assert x.op is Ops.CAST
swizzle = tuple(swizzle)
m0 = UOp(Ops.CAT, dtypes.uchar.vec(128), src=tuple(aa[k].src[0] for k in sorted(aa.keys()))).gep(swizzle)
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, UOp.const(dtypes.uint, mask)),
"__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, mask)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
def multi_add_int2(**aa):
if 'acc' in aa:
acc = aa['acc']
del aa['acc']
else:
acc = None
eles = []
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(0))
for k in sorted(aa.keys()): eles.append(aa[k].src[0].gep(1))
r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[0:4]+eles[8:12]))
r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[4:8]+eles[12:16]))
# TODO: types aren't right here
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)),
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
conv_pm = PatternMatcher([
# __builtin_HEXAGON_V6_vrmpybus x3
(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
UPat(name="c0")*UPat(name="c1"), multi_mul),
(UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
UPat(name="c0")*UPat(name="c1"), multi_mul),
# __builtin_HEXAGON_V6_vrmpybus x3
(UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
])
dsp_pm = PatternMatcher([
# convert load char32 to load char128
(UPat(Ops.LOAD, (dtypes.uchar.vec(96), dtypes.uchar.vec(64), dtypes.uchar.vec(32)), src=(UPat.var("buf").cast(),), name="load"),
lambda load, buf: load.replace(dtype=dtypes.uchar.vec(128),
src=(buf.cast(buf.dtype.base.vec(128).ptr(size=buf.dtype.size, local=buf.dtype.local)),)+load.src[1:]).gep(tuple(range(0, load.dtype.count)))),
# GEP on REDUCE
(UPat(Ops.GEP, src=(UPat(Ops.REDUCE, name='alu'),), name='gep'), gep_on_reduce),
# no swizzle down convert
(((UPat.var('x').maximum(0) ^ -1).maximum(-256) ^ -1).cast(dtypes.uchar.vec(128)),
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None),
])
# REDUCE int4 -> 2xint2, int8 -> 4xint2
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(4), name="r"),
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r)))),
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(8), name="r"),
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r),
gep_on_reduce(r.gep((4,5)), r), gep_on_reduce(r.gep((6,7)), r)))),
# __builtin_HEXAGON_V6_vrmpybus
(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
(UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
# build __builtin_HEXAGON_V6_vrmpybus_128B
(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+
UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
# build __builtin_HEXAGON_A2_vraddub
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
# we upcast 3 as 4
(UPat(Ops.REDUCE, name="r", src=(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") +
UPat(name="b0")*UPat(name="b1") + UPat(name="c0")*UPat(name="c1"),), allow_any_len=True),
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_mul(**kwargs)) else None),
(UPat(Ops.REDUCE, name="r", src=(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0"),
), allow_any_len=True),
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_add_int32(**kwargs)) else None),
# mul by const on GEP
(UPat(Ops.GEP, src=(UPat.var('x'),), name="gep")*UPat.cvar("c", vec=False),
lambda x, gep, c: (x.gep(gep.arg[0])*c.arg).broadcast(c.dtype.count) if all_same(gep.arg) and c.dtype.count > 1 else None),
])+gep_pushing
def add_to_mul(c:UOp, x:UOp):
if c.arg.startswith("__builtin_HEXAGON_V6_vrmpyub_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
elif c.arg.startswith("__builtin_HEXAGON_V6_vrmpyubv_128B"):
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
elif 'acc' in c.arg and x.op is not Ops.CUSTOM:
return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
else:
return None
def prefetch_l1(ld:UOp, idx:UOp):
if ld.src[-1].op is Ops.CUSTOM: return None
ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1]+UOp.const(dtypes.int, ld.dtype.count*2)),
arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1].substitute({ranges[-1]: ranges[-1].src[0]})),
arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
return ld.replace(src=ld.src+(x1, x2))
def prefetch_l2(ld:UOp, idx:UOp):
if not getenv("PREFETCHL2", 1): return None
if ld.src[-1].op is Ops.CUSTOM and 'l2fetch' in ld.src[-1].arg: return None
ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
if len(ranges):
nidx = idx.src[1]
const = 0
if nidx.op is Ops.ADD and nidx.src[1].op is Ops.CONST:
# NOTE: this causes access alignment issues
#const = nidx.src[1].arg
nidx = nidx.src[0]
zero_ranges = {r:r.const_like(0) for r in ranges[:-1]}
nlen_uop = (nidx.substitute({ranges[-1]: ranges[-1].src[1], **zero_ranges}) -
nidx.substitute({ranges[-1]: ranges[-1].src[0], **zero_ranges})).simplify()
nidx = nidx.substitute({ranges[-1]: ranges[-1].src[0]})
buf_lines_total = ((idx.src[0].dtype.size*idx.src[0].dtype.itemsize)+127)//128
if buf_lines_total < 8192//128:
# if the total buffer size is sub 8k, fetch it all
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], UOp.const(dtypes.int, buf_lines_total)),
arg="__builtin_HEXAGON_Y4_l2fetch({0}, 0x808000|{1});")
else:
fetch_lines = 8
if nlen_uop.op is Ops.CONST and nlen_uop.arg <= 8192: fetch_lines = ((nlen_uop.arg+127)//128)*2+1
fetch_lines = max(fetch_lines, 8)
# fetch up to 8192
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], nidx+const, UOp.const(dtypes.int, fetch_lines)),
arg="__builtin_HEXAGON_Y4_l2fetch({0}+{1}, 0x808000|{2});")
return ld.replace(src=ld.src+(x1,))
def vectorize_shuffle(vec:UOp):
if getenv("DISABLE_VECTORIZED_SHUFFLE", 0): return None
if not all(s.op in {Ops.GEP, Ops.CONST} for s in vec.src): return None
gepped = dedup([s.src[0] for s in vec.src if s.op is Ops.GEP])
if len(gepped) == 0: return None
if len(gepped) == 1:
# this pattern is broken in DSP clang
if gepped[0].dtype.count == 4: return None
#return None
arg = []
for s in vec.src:
if s.op is Ops.GEP:
arg.append(s.arg[0])
else:
arg.append(-1)
str_arg = ','.join([f'{y:4d}' for y in arg])
full_arg = "__builtin_shufflevector({0}, {0}, "+str_arg+")"
return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), full_arg)
if not all(x.dtype.scalar() is dtypes.uchar for x in gepped): return None
if not all_same([x.dtype.count for x in gepped]) or gepped[0].dtype.count != vec.dtype.count: return None
if len(gepped) == 2:
arg = []
for s in vec.src:
if s.op is Ops.GEP:
if s.src[0] is gepped[0]:
arg.append(s.arg[0])
continue
if s.src[0] is gepped[1]:
arg.append(gepped[0].dtype.count + s.arg[0])
continue
arg.append(-1)
str_arg = ','.join([f'{y:4d}' for y in arg])
full_arg = "__builtin_shufflevector({0}, {1}, "+str_arg+")"
return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), full_arg)
if len(gepped) != 3: return None
arg = []
for s in vec.src:
if s.op is Ops.GEP:
if s.src[0] is gepped[0]:
arg.append(s.arg[0])
continue
if s.src[0] is gepped[1]:
arg.append(gepped[0].dtype.count + s.arg[0])
continue
arg.append(-1)
arg2 = []
for i,s in enumerate(vec.src):
if s.op is Ops.GEP:
if s.src[0] is gepped[2]:
arg2.append(vec.dtype.count + s.arg[0])
continue
if s.op is Ops.CONST:
arg2.append(-1)
continue
arg2.append(i)
str_arg = ','.join([f'{y:4d}' for y in arg])
str_arg2 = ','.join([f'{y:4d}' for y in arg2])
full_arg = "__builtin_shufflevector(__builtin_shufflevector({0}, {1}, "+str_arg+"), {2}, "+str_arg2+")"
return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), full_arg)
def multicore_range(r:UOp):
# NOTE: THIS IS BROKEN if this is a reduce range. TODO: check for that
if getenv("MULTICORE", 0) != 1: return None
if any(x.op is Ops.SPECIAL for x in r.toposort): return None
core = UOp(Ops.SPECIAL, dtypes.int, arg=("g0", 2))
start = (core.eq(0)).where(r.src[0], r.src[1]//2)
end = (core.eq(0)).where(r.src[1]//2, r.src[1])
return r.replace(src=(start,end))
def store_with_mask(buf, idx, val, mask, cast):
if val.dtype.count != 128 or mask.dtype.count != 128 or val.dtype.scalar() != dtypes.uchar:
print("DROP MASK", val.dtype.count, mask.dtype.count)
# NOTE: we are dropping the mask
return buf.index(idx).cast(cast.dtype).store(val)
const_0 = UOp.const(dtypes.uchar.vec(128), 0)
cmask = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(mask,),arg="{0}")
# unaligned
min_128 = (buf.index(idx).cast(dtypes.uint)&0x7F)
cmask_l = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(cmask, const_0, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")
cmask_r = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(const_0, mask, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")
val_l = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(val, const_0, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")
val_r = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(const_0, val, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")
store_l = UOp(Ops.CUSTOM, dtypes.void, src=(cmask_l, buf.index(idx).cast(cast.dtype), val_l, const_0),
arg='__builtin_HEXAGON_V6_vS32b_nqpred_ai_128B(__builtin_HEXAGON_V6_veqb_128B({0}, {3}), {1}, {2});')
store_r = UOp(Ops.CUSTOM, dtypes.void, src=(cmask_r, buf.index(idx+128).cast(cast.dtype), val_r, const_0),
arg='__builtin_HEXAGON_V6_vS32b_nqpred_ai_128B(__builtin_HEXAGON_V6_veqb_128B({0}, {3}), {1}, {2});')
return UOp(Ops.CUSTOM, src=(store_l,store_r), arg="")
dsp_pm_late = PatternMatcher([
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
# prefetch L1
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld"), prefetch_l1),
# prefetch L2
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(8), dtypes.uchar.vec(128), dtypes.int.vec(32)),
src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld", allow_any_len=True), prefetch_l2),
# use __builtin_shufflevector
(UPat(Ops.VECTORIZE, dtypes.uchar.vec(128), name="vec"), vectorize_shuffle),
# __builtin_HEXAGON_V6_vrmpyub_acc_128B
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src,
"__builtin_shufflevector({0}, {0}, "+','.join([f'{y:4d}' for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 4 else None),
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")),
lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")),
lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")),
lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
# multicore
(UPat(Ops.RANGE, name="r", arg=0), multicore_range),
# store with mask
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))).cast().named("cast"), UPat.var("val"))),
store_with_mask),
])
# NOTE: this just increases readability of the generated code
dsp_string = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
pretty_render = PatternMatcher([
# makes rendering nicer
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, dtype=(dtypes.uint8, dtypes.int8)), name="v"),
lambda v: UOp(Ops.VECTORIZE, v.dtype, src=tuple(UOp(Ops.CUSTOMI, x.dtype, src=(UOp.const(dtypes.int, x.arg),), arg="{0}") for x in v.src))),
])
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
global_max = (2, 1, 1)
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
string_rewrite = dsp_string+ClangRenderer.string_rewrite
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {k:v for k,v in ClangRenderer.code_for_op.items() if k != Ops.SQRT}
extra_args = ['int global_idx_0', 'void* sync']
code_for_workitem = {"g": lambda x: f"global_idx_{x}"}
def _render_defines(self, uops) -> list[str]:
return ['''/* DSP boilerplate */ struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency;
_Bool set_dcvs_params; short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3];};''','int HAP_power_set(void*, void*);',
'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
'unsigned long long HAP_perf_get_time_us(void);', 'typedef unsigned long qurt_thread_t;', 'void qurt_thread_exit(int);',
'typedef struct _qurt_barrier { char padding[64]; } qurt_barrier_t;', 'int qurt_barrier_init(qurt_barrier_t*, unsigned int);',
'int qurt_barrier_wait(qurt_barrier_t*);',
'typedef struct _qurt_thread_attr { char name[16]; unsigned char tcb_partition; unsigned char affinity; unsigned short priority;',
'unsigned char asid; unsigned char bus_priority; unsigned short timetest_id; unsigned int stack_size;'
'void *stack_addr; char padding[96]; } qurt_thread_attr_t;',
'int qurt_thread_join(qurt_thread_t tid, int *status);', 'void* malloc(unsigned int);', 'void free(void*);',
'int qurt_thread_create (qurt_thread_t *thread_id, qurt_thread_attr_t *attr, void (*entrypoint) (void *), void *arg);',
] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]], sync_cnt=0x0) -> str:
msrc = ['typedef struct all_args {', *[f'int sz_or_val_{i}; int off{i}; void *buf_{i};' for i in range(len(bufs))], 'void* sync; } all_args_t;']
msrc += ['void threader(all_args_t* args) {']
buf_inputs = ', '.join([(f'args->buf_{i}' if isinstance(b[1][0], PtrDType) else f'args->sz_or_val_{i}') for i,b in enumerate(bufs)])
msrc += [f"{function_name}({buf_inputs}, 1, args->sync);"]
msrc += ['qurt_thread_exit(0); }'
'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
if sync_cnt > 0:
msrc += [f"qurt_barrier_t* sync = malloc({sync_cnt} * sizeof(qurt_barrier_t));"]
msrc += [f"qurt_barrier_init(&sync[{i}], 2);" for i in range(sync_cnt)]
else: msrc += ["qurt_barrier_t* sync = 0x0;"]
msrc += ['all_args_t args = { 0 };']
msrc += [f'args.sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'args.off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'args.buf_{i} = HAP_mmap(0,args.sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+args.off{i};'
for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += ['args.sync = sync;']
msrc += ["qurt_thread_attr_t attr = { 0 };"]
msrc += ["attr.name[0] = 't';", "attr.priority = 255;", "attr.asid = 0;"]
msrc += ["attr.stack_size = (64 << 10);", "attr.stack_addr = malloc(attr.stack_size);"]
msrc += [""]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
if getenv("MULTICORE", 0) != 0:
msrc += ["qurt_thread_t thread_ = 0; qurt_thread_create(&thread_, &attr, (void (*)(void*))threader, (void*)&args);"]
buf_inputs = ', '.join([(f'args.buf_{i}' if isinstance(b[1][0], PtrDType) else f'args.sz_or_val_{i}') for i,b in enumerate(bufs)])
msrc += [f"{function_name}({buf_inputs}, 0, args.sync);"]
if getenv("MULTICORE", 0) != 0:
msrc += ['int status;', "qurt_thread_join(thread_, &status);"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'HAP_munmap(args.buf_{i}, args.sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += ['free(attr.stack_addr);']
if sync_cnt > 0: msrc += ['free(sync);']
msrc += ["return 0; }"]
return '\n'.join(msrc)
@ -141,7 +493,7 @@ class DSPDevice(Compiled):
'got', 'got.plt', 'dynsym', 'dynstr', 'symtab', 'shstrtab', 'strtab']
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.write(f"SECTIONS {{ . = 0x0;\n{sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.flush()
from tinygrad.runtime.graph.cpu import CPUGraph
@ -283,7 +635,15 @@ class MockDSPRenderer(DSPRenderer):
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
buf_inputs = ', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])
if getenv("MULTICORE", 0) != 0:
# TODO: get count?
# NOTE: we do them in reverse order to reveal bugs
msrc.append(f"{function_name}({buf_inputs}, 1, 0);")
msrc.append(f"{function_name}({buf_inputs}, 0, 0);")
else:
# huh, why did this change?
msrc.append(f"{function_name}({buf_inputs}, 0, 0);")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")

View file

@ -81,6 +81,9 @@ spec = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
# all loads (we add Ops.CUSTOM as fake sources on this for l1fetch)
(UPat(Ops.LOAD), lambda: True),
# early STORE has a <buf, shapetracker, val>
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
@ -91,6 +94,9 @@ spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
# any mask
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(name="mask")), name="idx"), validate_index),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),