Compare commits

...

158 commits

Author SHA1 Message Date
George Hotz
1568c92f7d cleanups 2026-04-09 20:43:24 +08:00
George Hotz
5d09363b5f simpler 2026-04-09 20:12:38 +08:00
George Hotz
c012f9c5a7 move define -> regalloc 2026-04-09 20:06:54 +08:00
George Hotz
934c0c5797 const 2026-04-09 19:42:33 +08:00
George Hotz
c559c29d0b works 2026-04-09 19:15:05 +08:00
George Hotz
e528bc389b move x86 stuff to correct places 2026-04-09 18:57:19 +08:00
Christopher Milan
bd6d7e22ce c.Struct cleanup (#15640) 2026-04-09 18:22:43 +08:00
qazal
fb40e711dd viz/cli: add pmc printer (#15651)
* viz/cli: add pmc printer

* cli work

* s

* linter

* pack workgroups

* add : to wgp

* counter name
2026-04-09 18:22:43 +08:00
chenyu
62a7b84aba fix merge_reduce_ends (#15659)
* fix merge_reduce_ends

same range with different nesting should not merge, like cumsum twice should not merge

* skip that
2026-04-09 18:22:43 +08:00
ttomsa
a17988a52d add callee saved registers 2026-04-08 21:00:37 +01:00
ttomsa
12f073e137
Merge branch 'master' into new_x86_backend 2026-04-08 20:29:59 +01:00
George Hotz
29582199c1
Merge branch 'master' into new_x86_backend 2026-04-07 21:16:43 +08:00
ttomsa
14be3279c1
Merge branch 'master' into new_x86_backend 2026-04-06 23:29:16 +01:00
ttomsa
2e61817001 fix 2026-04-06 04:55:21 +01:00
ttomsa
8868abe830 fix 2026-04-06 04:23:15 +01:00
ttomsa
8346332061 fix 2026-04-06 02:49:02 +01:00
ttomsa
91bf07e702
Merge branch 'master' into new_x86_backend 2026-04-06 02:34:12 +01:00
ttomsa
891807a1b9 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-04-05 00:05:11 +01:00
ttomsa
0b0ea63439 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-04-03 04:12:01 +01:00
ttomsa
53ef0d36ec Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-04-01 00:49:39 +01:00
ttomsa
d333ac1242
Merge branch 'master' into new_x86_backend 2026-03-29 16:51:49 +01:00
ttomsa
85e6b77c13
Merge branch 'master' into new_x86_backend 2026-03-26 23:34:27 +00:00
ttomsa
a2b32b3abf Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-03-26 16:28:02 +00:00
ttomsa
1fb940f762
Merge branch 'master' into new_x86_backend 2026-03-25 03:08:17 +00:00
ttomsa
cd0152efec
Merge branch 'master' into new_x86_backend 2026-03-23 22:50:56 +00:00
ttomsa
35fc12b839 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-03-21 19:16:59 +00:00
ttomsa
292e1745b2 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-03-19 20:54:58 +00:00
ttomsa
e81878abd9 enable gep noop rule 2026-03-19 20:54:00 +00:00
ttomsa
acdc232d65 fix 2026-03-18 01:11:58 +00:00
ttomsa
0dc615b588 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-03-18 00:52:27 +00:00
ttomsa
449c79ada2 deal with flags correctly 2026-03-18 00:49:05 +00:00
ttomsa
221eafcd8d fix 2026-03-10 01:41:23 +00:00
ttomsa
7115ed0c22 fix 2026-03-10 00:21:04 +00:00
ttomsa
1a52341196 fix 2026-03-10 00:10:24 +00:00
ttomsa
037c5e6f82 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-03-09 23:33:52 +00:00
ttomsa
aaab4407af a lot better 2026-03-09 23:29:49 +00:00
ttomsa
465be0d333
Merge branch 'master' into new_x86_backend 2026-03-06 00:32:35 +00:00
ttomsa
dd5076529b rm this for now 2026-03-06 00:32:17 +00:00
ttomsa
41f2bd8a05 more isel tests 2026-03-06 00:28:36 +00:00
ttomsa
255a788dea enable vector load/store on all dtypes 2026-03-05 19:57:07 +00:00
ttomsa
b172b5d72c
Merge branch 'master' into new_x86_backend 2026-03-04 20:50:26 +00:00
ttomsa
64f574572c regalloc takes renderer 2026-03-04 20:49:47 +00:00
ttomsa
3dc9bbd831 vpsrldq can't access memory 2026-03-02 22:16:36 +00:00
ttomsa
cafa3b74d4 rm bad rewrite 2026-03-02 21:32:52 +00:00
ttomsa
393e591f49
Merge branch 'master' into new_x86_backend 2026-03-02 21:14:07 +00:00
ttomsa
82954c7ca4 support float16 vector load/store 2026-03-02 21:13:07 +00:00
ttomsa
eb4ad1ebf0 move max/min and add test 2026-02-27 22:52:50 +00:00
ttomsa
ce2c690721
Merge branch 'master' into new_x86_backend 2026-02-27 22:40:28 +00:00
ttomsa
a982a8709e canonical max 2026-02-27 22:38:16 +00:00
ttomsa
94c317a437 print correct reg names 2026-02-27 22:23:11 +00:00
ttomsa
13c4f2fb04 linter 2026-02-27 00:22:24 +00:00
ttomsa
ee4455952d
Merge branch 'master' into new_x86_backend 2026-02-26 23:43:17 +00:00
ttomsa
c247bdd9b9 cleanups 2026-02-26 23:42:41 +00:00
ttomsa
e29915e3ff fix vpbroadcast 2026-02-26 18:39:35 +00:00
ttomsa
b4152bff20
Merge branch 'master' into new_x86_backend 2026-02-24 22:14:02 +00:00
ttomsa
0af6d94422 move X86Ops to x86.py 2026-02-24 22:13:44 +00:00
ttomsa
a9b5a368da linter 2026-02-23 22:09:46 +00:00
ttomsa
24bff881e8 linter 2026-02-23 22:01:24 +00:00
ttomsa
563a31f791 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-23 19:41:53 +00:00
ttomsa
510ef99411 fix 2026-02-23 18:28:34 +00:00
ttomsa
8b83cc3aeb no more bool cast 2026-02-23 17:44:57 +00:00
ttomsa
59603f4d93
Merge branch 'master' into new_x86_backend 2026-02-23 17:43:03 +00:00
ttomsa
84b361df95 move x86 tests 2026-02-23 01:41:52 +00:00
ttomsa
82e42ec061 actually no 2026-02-23 01:19:22 +00:00
ttomsa
dacfa01c0d enable float16 and unaligned vector load/store 2026-02-23 01:01:57 +00:00
ttomsa
6b7d75683f oops 2026-02-23 00:51:38 +00:00
ttomsa
18cf8c57e8 fix 8bit idiv? 2026-02-23 00:39:35 +00:00
ttomsa
1e316e025a print formatted assembly 2026-02-23 00:17:20 +00:00
ttomsa
a9f8c06f84 fix 2026-02-22 18:56:12 +00:00
ttomsa
cb3c4b8b47 fix 2026-02-22 18:36:18 +00:00
ttomsa
999483490a fix 2026-02-22 17:49:09 +00:00
ttomsa
ad3882bf08 fix 2026-02-21 17:41:37 +00:00
ttomsa
250b1b2520 more changes 2026-02-21 17:37:47 +00:00
ttomsa
b5db91bfdf Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-21 16:47:52 +00:00
ttomsa
3f01b9970a more changes 2026-02-21 16:44:55 +00:00
ttomsa
9b3b425518 more changes 2026-02-21 16:39:20 +00:00
ttomsa
dd558ecfae add Ops.INS to x86 2026-02-19 00:20:21 +00:00
ttomsa
1f140d9d53 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-17 19:43:43 +00:00
ttomsa
194d498d28 fix idiv 2026-02-17 19:33:16 +00:00
ttomsa
d1c28c2692 simplify live range 2026-02-17 19:31:03 +00:00
ttomsa
72f341a534 cleanup encode 2026-02-10 20:45:49 +00:00
ttomsa
b32bafe1ae also this 2026-02-10 17:31:48 +00:00
ttomsa
0681e3311c Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-10 17:29:19 +00:00
ttomsa
9fbf64e339 more asserts 2026-02-10 17:22:56 +00:00
ttomsa
ce31a4fbec move NOOPs to pre_isel_matcher and rm NOOP from spec 2026-02-10 17:22:21 +00:00
ttomsa
86b5441781 allow set[X86Ops] in upat 2026-02-08 22:04:15 +00:00
ttomsa
878557004c add x86op test 2026-02-08 21:18:12 +00:00
ttomsa
80e68f3706 more linter 2026-02-08 20:33:23 +00:00
ttomsa
f0565ed5dc tell mypy to shut up 2026-02-08 19:50:03 +00:00
ttomsa
78171c4f70 spacing 2026-02-08 19:24:57 +00:00
ttomsa
e1bf9c9e02 rm X86GroupOp from ops.py 2026-02-08 19:18:00 +00:00
ttomsa
6ff67781f1
Merge branch 'master' into new_x86_backend 2026-02-08 19:15:14 +00:00
ttomsa
fe2b08bee3 fix imports 2026-02-08 19:15:03 +00:00
ttomsa
5c2b0b2363 allow for extending enums and move X86Ops out of uop 2026-02-08 19:08:49 +00:00
ttomsa
733789e294 and this 2026-02-06 23:38:34 +00:00
ttomsa
ef76bfa081 rm machine scheduler stuff 2026-02-06 23:37:38 +00:00
ttomsa
fdaad71b6a rm noqa 2026-02-06 22:36:26 +00:00
ttomsa
e2d49fa578 fix 2026-02-06 22:27:43 +00:00
ttomsa
74e24d53c9 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-06 22:18:38 +00:00
ttomsa
c4c69d8276 Ops becomes OpType 2026-02-06 22:16:57 +00:00
ttomsa
a3d1f8435a how much does this fix 2026-02-06 20:00:05 +00:00
ttomsa
1d8a277928 enable emulated int64 tests 2026-02-06 18:39:24 +00:00
ttomsa
4d6ed29af3 use def_reg in test_encodings 2026-02-06 18:36:24 +00:00
ttomsa
93022ac35a do mulacc in isel 2026-02-06 17:42:06 +00:00
ttomsa
6de9da8b9c
Merge branch 'master' into new_x86_backend 2026-02-04 16:18:47 +00:00
ttomsa
74e3d9faf3 add min x86op and neg in decomps 2026-02-04 16:00:53 +00:00
ttomsa
bbe012ac86 more 2026-02-03 16:50:30 +00:00
ttomsa
4c3081b613 switch to PARAM 2026-02-03 16:45:35 +00:00
ttomsa
69a27d9a5c Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-02-03 16:41:19 +00:00
ttomsa
77a28ac3f2 x86 goes in own linearize 2026-02-03 16:32:28 +00:00
ttomsa
0ae5c5e4f9 func arg order independent from op value 2026-02-02 15:47:49 +00:00
ttomsa
e9f2e89f8f
Merge branch 'master' into new_x86_backend 2026-02-02 15:22:49 +00:00
ttomsa
983f7a2155 skip a few tests 2026-02-01 19:23:07 +00:00
ttomsa
f0234b9da3 fix 2026-02-01 18:12:20 +00:00
ttomsa
a198cb54e2 fix const tag hack and add x86ops to _shape 2026-02-01 17:53:32 +00:00
ttomsa
b53bcb3319 skip inf rewrite tests 2026-01-31 20:27:04 +00:00
ttomsa
f1327ebff6 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-01-31 19:42:22 +00:00
ttomsa
6f977100ff skip bounds check when NOOPs exist 2026-01-31 16:15:22 +00:00
ttomsa
db3ed92ae3 fixup isel tests 2026-01-31 15:22:27 +00:00
ttomsa
3fcde08b20 cleaner shuffle functions 2026-01-30 20:29:45 +00:00
ttomsa
c1b2816d8b Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-01-29 19:20:40 +00:00
ttomsa
f8ade82553 more scheduling info 2026-01-29 19:18:00 +00:00
ttomsa
dd48f6a111 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-01-26 02:34:46 +00:00
ttomsa
1fe4185e89 start new scheduler 2026-01-26 02:30:38 +00:00
ttomsa
037c824f9d perhaps 2026-01-12 00:53:04 +00:00
ttomsa
609d9385d8 let's try this 2026-01-12 00:41:58 +00:00
ttomsa
5a61a10547 more 2026-01-11 23:13:31 +00:00
ttomsa
ff5f071ba2 more 2026-01-11 22:51:23 +00:00
ttomsa
7864067e34 more 2026-01-11 22:17:42 +00:00
ttomsa
a5e189794a more linter 2026-01-11 22:04:35 +00:00
ttomsa
7bafe52335 linter 2026-01-11 21:42:43 +00:00
ttomsa
c133d3b1d0 fix DEFINE_VAR/SPECIAL and enable multithreading 2026-01-11 20:46:23 +00:00
ttomsa
0daa1d94d0 Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-01-10 20:23:29 +00:00
ttomsa
423f7e66ca minor 2026-01-10 20:16:55 +00:00
ttomsa
7ab99089fc always fuse index 2026-01-07 02:55:06 +00:00
ttomsa
b4f8d64d2b add float max 2026-01-07 01:08:37 +00:00
ttomsa
f9b2f51554
Merge branch 'master' into new_x86_backend 2026-01-06 00:15:06 +00:00
ttomsa
0fe5d75982 fix remaining seg faults 2026-01-05 23:06:31 +00:00
ttomsa
138e20adcf no TUPLE_ORDER, breaks tests 2026-01-05 19:57:53 +00:00
ttomsa
d0d3272df1 support storing imms 2026-01-05 19:56:55 +00:00
ttomsa
243f6c85b9 add cmoves to the spec 2026-01-04 03:01:49 +00:00
ttomsa
c005ab0122 yup isel fixes the mask stuff too and its beautiful 2026-01-04 00:46:27 +00:00
ttomsa
f92e2d259a test windows ci 2026-01-02 00:24:48 +00:00
ttomsa
bcd8b2b5cc add BARRIER 2026-01-02 00:14:43 +00:00
ttomsa
f4309a3b1a rm TwoAddress2nd 2026-01-01 22:09:31 +00:00
ttomsa
587259976d float16 fix 2026-01-01 18:55:28 +00:00
ttomsa
8d4a48fcd3 add x86 backend to tests 2026-01-01 02:47:16 +00:00
ttomsa
885172f4bc Merge remote-tracking branch 'upstream/master' into new_x86_backend 2026-01-01 02:33:35 +00:00
ttomsa
1eca96ea44 fixes 2026-01-01 02:26:48 +00:00
ttomsa
12714337f0 add movabs instruction and fix idiv 2025-12-23 01:22:37 +00:00
ttomsa
32942f12b7 don't fuse load if used multiple times in src 2025-12-21 00:51:53 +00:00
ttomsa
8365bc84ee add vbroadcastss instruction 2025-12-20 23:51:51 +00:00
ttomsa
54396f5cb3 woops 2025-12-20 20:46:58 +00:00
ttomsa
b8f06970fa Merge remote-tracking branch 'upstream/master' into new_x86_backend 2025-12-20 20:43:21 +00:00
ttomsa
edb592f314 model flag state and support rematerialization 2025-12-20 20:24:24 +00:00
ttomsa
678a6b3689 cleanup test_isel 2025-12-20 20:22:35 +00:00
ttomsa
51e1292200 cleanup test_encodings 2025-12-20 20:21:25 +00:00
ttomsa
98f0ba7fb8 draft 2025-12-13 20:48:01 +00:00
21 changed files with 1620 additions and 72 deletions

View file

@ -793,7 +793,7 @@ jobs:
strategy:
fail-fast: false
matrix:
backend: [llvm, cpu, opencl, lvp]
backend: [llvm, cpu, opencl, lvp, x86]
name: Linux (${{ matrix.backend }})
runs-on: ubuntu-22.04
@ -810,7 +810,7 @@ jobs:
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' }}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
@ -960,7 +960,7 @@ jobs:
strategy:
fail-fast: false
matrix:
backend: [llvm, cpu, webgpu]
backend: [llvm, cpu, webgpu, x86]
name: Windows (${{ matrix.backend }})
runs-on: windows-latest
@ -976,7 +976,7 @@ jobs:
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
- name: Set env
shell: bash
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU'}}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
- name: Run unit tests
if: matrix.backend=='llvm'
# test_newton_schulz hits RecursionError
@ -988,7 +988,7 @@ jobs:
- name: Run pytest (${{ matrix.backend }})
shell: bash
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU', 'X86':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
python -m pytest -n=auto test/test_tiny.py test/backend/test_ops.py --durations=20
# ****** Compile-only Tests ******

View file

@ -24,7 +24,7 @@ List all codegen steps for a kernel: `--rewrites -s E_3`
Get source code: `--rewrites -s E_3 -i "View Source"`
Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"`
# SQTT tracing
## SQTT tracing
Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing).
@ -38,8 +38,12 @@ You can select a specific trace with --source, Example workflow:
VIZ=-2 python extra/gemm/amd_asm_matmul.py
# View barriers
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | rg BARRIER | head -10
extra/viz/cli.py --profile -s "kernel SQTT SE:0 PKTS" | rg BARRIER | head -10
# Get bank conflicts from performance counters
python extra/viz/cli.py -p -s "kernel PMC" -i "SQC_LDS_BANK_CONFLICT"
# Find the EXEC corresponding to a DISPATCH at cycle 410
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410'
extra/viz/cli.py --profile -s "kernel SQTT SE:0 PKTS" | awk '/EXEC/ && $1 - $5 == 410'
```

View file

@ -47,7 +47,9 @@ def decode_profile(data:bytes) -> dict:
def get(data:dict, key:str):
for k,v in data.items():
if ansistrip(k) == key: return v
raise RuntimeError(f'item "{key}" not found in list')
import difflib
match = difflib.get_close_matches(key, [ansistrip(k) for k in data], n=1, cutoff=0.6)
raise RuntimeError(f'item "{key}" not found in list'+(f", did you mean {match[0]!r}?" if match else ''))
def main(args) -> None:
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
@ -59,8 +61,8 @@ def main(args) -> None:
events:list = viz.load_pickle(args.profile_path, default=[])
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
profile = decode_profile(profile_bytes)
profile["layout"].update([(f'{c["name"]} {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("SQTT") for s in c["steps"]
if "PKTS" in s["name"]])
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.ctxs
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
if args.src is None:
for k in profile["layout"]:
print(f" {format_colored(k)}")
@ -99,6 +101,20 @@ def main(args) -> None:
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
return None
# ** PMC printer
if "PMC" in args.src:
table = viz.unpack_pmc(data[0])
cols = table["cols"]
rows:list = []
for r in table["rows"]:
if args.item is None: rows.append(r[:2])
elif args.item == r[0]:
rows = r[2]["rows"] if len(r) > 2 else [r[:2]]
cols = r[2]["cols"] if len(r) > 2 else cols
from tabulate import tabulate
print(tabulate(rows, headers=cols, tablefmt="github"))
return None
# ** Profiler printer
agg:dict[str, tuple[float, int]] = {}
total = 0

View file

@ -24,6 +24,10 @@ class TestArange(unittest.TestCase):
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
def test_arange_cumsum(self):
np.testing.assert_equal(Tensor.arange(513).cumsum(0).numpy(), np.arange(513).cumsum())
def test_arange_cat(self):
t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])

View file

@ -0,0 +1,150 @@
import unittest
from typing import cast
from tinygrad import Device
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
def ins(op, dt, src, tag=None):
  """Build an `Ops.INS` UOp for one machine instruction.

  `op` is stored as the UOp's `arg`, `dt` as its `dtype`, `src` as its sources
  and `tag` as its tag. The tests below pass an `X86Ops` member as `op` and,
  where a tag is given, a register as `tag`.
  """
  return UOp(Ops.INS, dtype=dt, src=src, arg=op, tag=tag)
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only on x86")
class TestEncodingsX86(unittest.TestCase):
  """Checks the bytes rendered for individual x86 instructions against hand-written encodings.

  Each test builds one `Ops.INS` UOp, renders it with the default device's renderer and
  compares the result (parsed as hex) with the expected hex string.
  """
  # NOTE: x86 supports a single displacement as memory address and index without base memory address
  # these have no use cases so they aren't supported

  def encode(self, u:UOp):
    """Render the single instruction `u` with the default device's X86Renderer."""
    renderer = cast(X86Renderer, Device[Device.DEFAULT].renderer)
    return renderer.render([u])

  def _assert_encoding(self, u:UOp, expected_hex:str):
    """Assert that rendering `u` produces exactly the bytes spelled out in `expected_hex`."""
    self.assertEqual(bytes.fromhex(self.encode(u)), bytes.fromhex(expected_hex))

  # displacement of 0 isn't emitted
  def test_base_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0))
    # mov edi, dword ptr [rdi]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RDI), "8B 3F")

  # rsp/r12 require a sib byte when used as base memory address
  def test_rsp_base_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0))
    # mov esp, dword ptr [rsp]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RSP), "8B 24 24")

  # rbp/r13 require a displacement when used as base memory address
  def test_rbp_base_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0))
    # mov ebp, dword ptr [rbp + 0]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RBP), "8B 6D 00")

  # test [base + index*scale]
  def test_base_index_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0))
    # mov eax, dword ptr [rax + rdx*4]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RAX), "8B 04 90")

  # rsp as index means no index
  def test_rsp_index_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0))
    # mov eax, dword ptr [rax]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RAX), "8B 00")

  # however r12 is a valid index
  def test_r12_index_address(self):
    srcs = (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0))
    # mov eax, dword ptr [rax + r12*4]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RAX), "42 8B 04 A0")

  # test [base + index*scale + 8bit disp]
  def test_complex_address_8bit_disp(self):
    srcs = (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10))
    # mov edi, dword ptr [rdi + rsi*4 + 0xa]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RDI), "8B 7C B7 0A")

  # test [base + index*scale + 32bit disp]
  def test_complex_address_32bit_disp(self):
    srcs = (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000))
    # mov edi, dword ptr [rdi + rsi*4 + 0x2710]
    self._assert_encoding(ins(X86Ops.MOV, dtypes.int32, srcs, RDI), "8B BC B7 10 27 00 00")

  # 8bit variants of legacy instructions subtract 1 from opcode
  def test_8bit_legacy_encoding(self):
    movsx = ins(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDX),), RAX)
    # movsx eax, dl
    self._assert_encoding(movsx, "0F BE C2")

  # accessing lower 8 bits of rsp, rbp, rsi, rdi requires rex prefix
  def test_lower_8bits_reg(self):
    movsx = ins(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDI),), RAX)
    # movsx eax, dil
    self._assert_encoding(movsx, "40 0F BE C7")

  # test 16 bit variant of legacy instruction
  def test_16bit_legacy_encoding(self):
    movsx = ins(X86Ops.MOVSX, dtypes.int16, (def_reg(dtypes.int8, RDX),), RAX)
    # movsx ax, dl
    self._assert_encoding(movsx, "66 0F BE C2")

  # test 64 bit variant of legacy instruction
  def test_64bit_legacy_encoding(self):
    movsx = ins(X86Ops.MOVSX, dtypes.int64, (def_reg(dtypes.int8, RDX),), RAX)
    # movsx rax, dl
    self._assert_encoding(movsx, "48 0F BE C2")

  # test compact vex encoding
  def test_compact_vex_encoding(self):
    lhs, rhs = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1])
    # vaddss xmm0, xmm0, xmm1
    self._assert_encoding(ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0]), "C5 FA 58 C1")

  # test long vex encoding
  def test_long_vex_encoding(self):
    lhs, rhs = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[8])
    # vaddss xmm0, xmm0, xmm8
    self._assert_encoding(ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0]), "C4 C1 7A 58 C0")

  # test ymm encoding
  def test_ymm_encoding(self):
    ymm0, ymm1 = def_reg(dtypes.float32.vec(8), XMM[0]), def_reg(dtypes.float32.vec(8), XMM[1])
    # vaddps ymm0, ymm0, ymm1
    self._assert_encoding(ins(X86Ops.VADDPS, dtypes.float32.vec(8), (ymm0, ymm1), XMM[0]), "C5 FC 58 C1")

  # test encoding where register is in the immediate field
  def test_reg_in_imm_field(self):
    xmm0, xmm1, xmm2 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1]), def_reg(dtypes.float32, XMM[2])
    # vblendvps xmm0, xmm0, xmm1, xmm2
    self._assert_encoding(ins(X86Ops.VBLENDVPS, dtypes.float32, (xmm0, xmm1, xmm2), XMM[0]), "C4 E3 79 4A C1 20")

  # when writing to mem the uop takes the store form where dtype is void and there's no definition
  def test_write_mem(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
    xmm0 = def_reg(dtypes.float32, XMM[0])
    # vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
    self._assert_encoding(ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0))), "C4 E3 79 16 44 B7 0A 00")

  # test two address instruction with fused load works
  def test_two_address_load(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
    # cmove eax, dword ptr [rdi + rsi*4 + 0xa]
    self._assert_encoding(ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX), "0F 44 44 B7 0A")

  # test instruction where displacement and imm have the same value
  def test_disp_imm_same_value(self):
    base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
    # mov byte ptr [rdi + rsi + 0xa], 0xa
    self._assert_encoding(ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp)), "40 C6 44 37 0A 0A")
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
    # imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
    imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp, imm(dtypes.int32, 10)), RDI)
    self._assert_encoding(imul, "69 BC B7 0A 00 00 00 0A 00 00 00")

  # cmoves have the cmp as the last src even though it is not explicitly used, the cmp doesn't define a reg and is ignored in the encoding
  def test_cmove_ignore_cmp(self):
    srcs = (def_reg(dtypes.int32, RAX), UOp(Ops.INS, arg=X86Ops.CMP))
    # cmove edx, eax
    self._assert_encoding(ins(X86Ops.CMOVE, dtypes.int32, srcs, RDX), "0F 44 D0")
if __name__ == "__main__":
unittest.main()

152
test/backend/test_isel.py Normal file
View file

@ -0,0 +1,152 @@
import unittest
from typing import cast
from tinygrad import Device
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
from tinygrad.renderer.isa import IselContext
# these tests are to catch changes that don't cause incorrect codegen but cause worse codegen
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
class TestIselX86(unittest.TestCase):
def isel_rewrite(self, x:UOp):
return graph_rewrite(x, cast(X86Renderer, Device[Device.DEFAULT].renderer).isel_matcher, IselContext(x), bottom_up=True)
def _check_op(self, dt_op, expr):
nargs = expr.__code__.co_argcount
for dt,op in dt_op:
with self.subTest(dtype=dt):
v = [UOp.variable(str(i), 0, 0, dt) for i in range(nargs)]
n = self.isel_rewrite(expr(*v))
self.assertIs(n.arg, op)
def test_cmove(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
b = UOp.variable("b", 0, 0, dtypes.int32)
c = (a < b).where(a, b)
d = (a != b).where(a, b)
f = c + d
n = self.isel_rewrite(f)
self.assertTrue(n.src[0].arg is X86Ops.CMOVL and n.src[1].arg is X86Ops.CMOVNE)
# both comparisons become the same instruction
self.assertTrue(n.src[0].src[2] == n.src[1].src[2] and n.src[0].src[2].arg is X86Ops.CMP)
def test_vmax(self):
dt_op = [(dtypes.float32, X86Ops.VMAXSS), (dtypes.float64, X86Ops.VMAXSD),
(dtypes.float32.vec(4), X86Ops.VMAXPS), (dtypes.float64.vec(4), X86Ops.VMAXPD)]
self._check_op(dt_op, lambda a,b: (a < b).where(b, a))
def test_vmin(self):
dt_op = [(dtypes.float32, X86Ops.VMINSS), (dtypes.float64, X86Ops.VMINSD),
(dtypes.float32.vec(4), X86Ops.VMINPS), (dtypes.float64.vec(4), X86Ops.VMINPD)]
self._check_op(dt_op, lambda a,b: (a < b).where(a, b))
def test_vfmadd(self):
dt_op = [(dtypes.float32, X86Ops.VFMADD213SS), (dtypes.float64, X86Ops.VFMADD213SD),
(dtypes.float32.vec(4), X86Ops.VFMADD213PS), (dtypes.float64.vec(4), X86Ops.VFMADD213PD)]
self._check_op(dt_op, lambda a,b,c: a * b + c)
# TODO: shouldn't match fmadd if var is used multiple times
@unittest.expectedFailure
def test_vfmadd_fail(self):
dt_op = [(dtypes.float32, X86Ops.VADDSS), (dtypes.float64, X86Ops.VADDSD),
(dtypes.float32.vec(4), X86Ops.VADDPS), (dtypes.float64.vec(4), X86Ops.VADDPD)]
self._check_op(dt_op, lambda a,b: a * b + b)
def test_vpbroadcast(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
n = self.isel_rewrite(a.broadcast(4))
# need to move src from gpr to xmm before broadcasting
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
# if we can fuse a load we can skip the move and access memory directly
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
n = self.isel_rewrite(load.broadcast(4))
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
def test_vbroadcastss(self):
a = UOp.variable("a", 0, 0, dtypes.float32)
valid = [UOp.vectorize(a, a, a, a), UOp.vectorize(a, a, a, a, a, a, a, a)]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VBROADCASTSS)
def test_vshufps(self):
a = UOp.variable("a", 0, 0, dtypes.float32.vec(8))
b = UOp.variable("b", 0, 0, dtypes.float32.vec(8))
c = UOp.variable("c", 0, 0, dtypes.float32)
d = UOp.variable("d", 0, 0, dtypes.float32)
valid = [UOp.vectorize(c, c, d, d),
UOp.vectorize(a.gep(0), a.gep(1), c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
def test_vshufpd(self):
a = UOp.variable("a", 0, 0, dtypes.float64.vec(4))
b = UOp.variable("b", 0, 0, dtypes.float64.vec(4))
c = UOp.variable("c", 0, 0, dtypes.float64)
d = UOp.variable("d", 0, 0, dtypes.float64)
valid = [UOp.vectorize(c, d),
UOp.vectorize(a.gep(0), c),
UOp.vectorize(a.gep(1), b.gep(1)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
invalid = [UOp.vectorize(c, c, c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
# this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
def test_vinsertps(self):
a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
d = UOp.variable("e", 0, 0, dtypes.float32)
# pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
self.assertIs(n.arg, X86Ops.VINSERTPS)
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)), # TODO: this should be vunpck
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
# complex address is [base + index*scale + displacement]
def test_complex_address(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
n = self.isel_rewrite(load)
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
def test_fold_load(self):
load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
n = self.isel_rewrite(load1 + load2)
self.assertTrue(len(n.src) == 4)
# don't fold when used multiple times
def test_dont_fold_load(self):
  # a load with more than one use must stay a separate instruction: the consumer keeps 2 srcs
  load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
  # used by multiple users
  n = self.isel_rewrite(load + 1 + load)
  self.assertEqual(len(n.src), 2)
  # used multiple times by same user
  n = self.isel_rewrite(load * load)
  self.assertEqual(len(n.src), 2)
# TODO: might want to check that load isn't part of another range when fusing
# run the tests of this module when the file is executed as a script
if __name__ == "__main__":
  unittest.main()

View file

@ -510,7 +510,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
def test_cumsum_parallel_reduce_fused(self):
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
step, num_steps = 513, 10
t = Tensor.arange(step).float().realize()
phase = t.cumsum()
@ -521,6 +521,12 @@ class TestSchedule(unittest.TestCase):
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
def test_reduce_different_nesting_depth(self):
  # two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
  x = Tensor.arange(768).reshape(3, 256).float()
  # per-row sums plus the sum of those row sums: the row reduce is used at two nesting depths
  got = (x.sum(axis=1) + x.sum(axis=1).sum()).numpy()
  x_np = x.numpy()
  want = x_np.sum(axis=1) + x_np.sum(axis=1).sum()
  np.testing.assert_allclose(got, want)
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()

View file

@ -12,6 +12,51 @@ class TestC(unittest.TestCase):
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
return DLL("test", f.name)
def test_struct_array_init(self):
  # an array field can be initialized from a plain tuple or from a ctypes array
  @record
  class Foo:
    SIZE = 12
    a: Annotated[ctypes.c_int * 3, 0]
  init_records()
  for init in ((1,2,3), (ctypes.c_int * 3)(1,2,3)):
    f = Foo(init)
    for idx, expected in enumerate((1, 2, 3)):
      assert f.a[idx] == expected
def test_field_ranges(self):
  # storing -1 reads back as -1 from the signed 8 bit field and as 255 from the unsigned one
  @record
  class Foo:
    SIZE = 2
    s: Annotated[ctypes.c_int8, 0]
    u: Annotated[ctypes.c_uint8, 1]
  init_records()
  f = Foo()
  for field_name in ("s", "u"):
    setattr(f, field_name, -1)
  assert f.s == -1
  assert f.u == 255
# this syntax is inherited from ctypes, but it seems a bit nonsensical?
def test_voidp_none(self):
  # a void pointer field reads back as None when null and as a plain int otherwise
  @record
  class Foo:
    SIZE = 8
    p: Annotated[ctypes.c_void_p, 0]
  init_records()
  f = Foo(None)
  assert f.p is None
  addr = 0xDEADBEEF
  f.p = ctypes.c_void_p(addr)
  assert f.p == addr
  f.p = None
  assert f.p is None
def test_packed_struct(self):
@record
class Baz:

View file

@ -101,6 +101,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
from tinygrad.renderer.isa import ISARenderer, IselContext
if isinstance(ren, ISARenderer):
linear_sink = graph_rewrite(sink, ren.pre_isel_matcher, name="pre instruction selection", bottom_up=True)
isel_ctx = IselContext(linear_sink)
linear_sink = graph_rewrite(linear_sink, ren.isel_matcher, ctx=isel_ctx, name="instruction selection", bottom_up=True)
sink = linear_sink
# return the rewritten sink
return sink
@ -114,20 +121,35 @@ pm_linearize_cleanups = PatternMatcher([
])
# requires lst be toposorted. like graph rewrite, but for lines
def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
def line_rewrite(lst:list[UOp], pm:PatternMatcher, ctx=None) -> list[UOp]:
newlst = []
replaced: dict[UOp, UOp] = {}
for u in lst:
nu = u.replace(src=tuple([replaced[x] for x in u.src]))
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu])
nu = u.replace(src=tuple([replaced.get(x, x) for x in u.src]))
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu, ctx)) or (nu, [nu])
replaced[u] = ret[0]
newlst.extend(ret[1])
return newlst
def do_linearize(prg:UOp, sink:UOp) -> UOp:
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
if SPEC: type_verify(lst, program_spec)
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
# Linearize `sink` and append the resulting Ops.LINEAR uop to the srcs of `prg`.
# For an ISARenderer the list additionally goes through the register allocation passes;
# every other renderer gets the generic cleanup rewrite.
def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
  from tinygrad.renderer.isa import ISARenderer
  # the generic list is only built when there are no estimates on the sink and the renderer is not an ISA renderer
  generic_lst = line_rewrite(linearize(sink), pm_linearize_cleanups) if sink.arg.estimates is None and not isinstance(ctx, ISARenderer) else None
  if isinstance(ctx, ISARenderer):
    from tinygrad.renderer.isa import PreRegAllocContext
    from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
    lst = linearize(sink)
    # optional renderer specific pass that runs before registers are assigned
    if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
    # the context computes the whole allocation when constructed, the rewrite below applies it line by line
    regalloc_ctx = LinearScanRegallocContext(lst, ctx)
    lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
    # the late (optional) and post passes share the same regalloc context
    if ctx.late_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.late_regalloc_matcher, regalloc_ctx)
    lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)
    if DEBUG >= 4: print(ctx.asm(lst, sink.arg.function_name))
    if SPEC: type_verify(lst, ctx.isa_spec)
  else:
    # NOTE(review): a non ISA renderer with sink.arg.estimates set reaches this assert with generic_lst == None,
    # confirm that combination cannot happen
    assert generic_lst is not None
    lst = generic_lst
    if SPEC: type_verify(lst, program_spec)
  # src[0] of the program is replaced by `sink`, the remaining srcs are kept and the LINEAR uop is appended
  return prg.replace(src=(sink,)+prg.src[1:] + (UOp(Ops.LINEAR, src=tuple(lst)),))
def do_estimates(prg:UOp, sink:UOp, lin:UOp) -> UOp|None:
if sink.arg.estimates is not None: return None

View file

@ -328,11 +328,23 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
return acc.after(end).index(UOp.const(dtypes.int, 0))
def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
# merge ENDs that share the same range (only those created by reduce_to_acc)
# merge ENDs that share the same range and nesting context (only those created by reduce_to_acc)
# ENDs at different nesting depths get cloned RANGEs so each RANGE maps to one END
range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
for u in sink.backward_slice:
if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
subs: dict[UOp, UOp] = {}
next_axis = max((u.arg[0] for u in sink.backward_slice if u.op is Ops.RANGE), default=-1) + 1
for r, ends in range_to_ends.items():
if len(ends) <= 1: continue
by_ctx: dict[frozenset[UOp], list[UOp]] = {}
for e in ends: by_ctx.setdefault(frozenset(e.ranges), []).append(e)
for i, group in enumerate(by_ctx.values()):
tr = r if i == 0 else tuple(rr.replace(arg=(next_axis + j, *rr.arg[1:])) for j, rr in enumerate(r))
if i > 0: next_axis += len(r)
mapped = [e.substitute(dict(zip(r, tr))) if i > 0 else e for e in group]
merged = mapped[0] if len(mapped) == 1 else UOp.group(*(e.src[0] for e in mapped)).end(*tr)
for e in group: subs[e] = merged
return sink.substitute(subs) if subs else None
pm_reduce = PatternMatcher([

View file

@ -2,7 +2,7 @@ import heapq
from typing import Any
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
from tinygrad.helpers import prod, getenv, TUPLE_ORDER, DEV
def linearize(sink:UOp) -> list[UOp]:
# this is a toposort with priority
@ -31,10 +31,18 @@ def linearize(sink:UOp) -> list[UOp]:
case Ops.RANGE: priority = 5 # placing RANGE is good
case Ops.END: priority = -5 # placing END is bad
case _: priority = 0 # everything else has priority 0
# stack pointer needs to be scheduled at the top of the kernel
# TODO: remove once there's a proper isa scheduler
if u.op is Ops.INS:
from tinygrad.renderer.isa.x86 import X86Ops, RSP
match u.arg:
case X86Ops.DEFINE_REG: priority, extra = (-21 if u.tag[0] == RSP else -20), u.tag[0].index
priorities[u] = (run_count, priority, extra)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER and DEV.value.renderer != "X86" else ())))}
# then force them to be toposorted in as close to the ideal order as possible
heap = [(-nkey[sink], sink)]
@ -93,4 +101,4 @@ def do_split_ends(e:UOp):
pm_split_ends = PatternMatcher([
# split the ends
(UPat(Ops.END, name="e"), do_split_ends),
])
])

View file

@ -0,0 +1,132 @@
import itertools
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.isa import ISARenderer, Register
from tinygrad.dtype import dtypes, PtrDType
# ops the allocator skips: they take no part in the live range computation or register assignment
# and regalloc_rewrite leaves them untouched
PSEUDO_OPS = {Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
def _uop_key(u:UOp): return (u.op, u.dtype, u.arg)
# loosely based on: https://bernsteinbear.com/assets/img/register-spilling-range-splitting-ssa.pdf
class LinearScanRegallocContext:
  # Register allocation over a linearized uop list. The whole allocation is computed in __init__ and stored on the
  # context (real_defs, spills, fills, insert_before); regalloc_rewrite then applies it one uop at a time.
  # Virtual registers are read from the tuple tag of the defining uop and from `.reg` of each src.
  def __init__(self, uops:list[UOp], ren:ISARenderer):
    # NOTE: `uops` is mutated in place. The callee saved definitions are inserted at the front of the list and
    # appended to the srcs of the last RET, ret_i is shifted by len(saved) because of that insertion
    if saved:=ren.callee_saved():
      ret_i = next(i for i,u in reversed(tuple(enumerate(uops))) if u.op is Ops.INS and getattr(u.arg, "name", None) == "RET")
      uops[0:0] = saved
      uops[ret_i+len(saved)] = uops[ret_i+len(saved)].replace(src=uops[ret_i+len(saved)].src + saved)
    # program points (indices into uops) where each register is defined or used, in ascending order
    live_range: dict[Register, list[int]] = {}
    # virtual registers currently held in a real register
    live: dict[Register, Register] = {}
    # stack of the registers live at the entry of each open loop
    live_ins: list[dict[Register, Register]] = []
    self.defs: dict[Register, UOp] = {} # mapping from virtual to uop that defines it
    self.real_defs: dict[Register, Register] = {} # mapping from virtual to real at definition
    self.spills: dict[Register, UOp] = {} # mapping from virtual to stack slot
    self.fills: dict[int, dict[int, tuple[Register, Register]]] = {} # mapping from program point to mapping from idx to virtual and real to fill to
    self.insert_before: dict[int, list[tuple[Register, Register]]] = {} # mapping from program point to fills to be inserted
    # counter advanced once per uop by regalloc_rewrite to recover the program point
    self.idx = itertools.count()
    self.ren = ren
    self.stack_size = 0
    # the label associated with each loop NOTE: this is only used post regalloc and should be removed
    self.loop_label: dict[UOp, str] = {}
    # position of each function argument: PARAMs first, then DEFINE_VARs, then SPECIALs, each group sorted by arg
    arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
    self.func_arg_idxs = {_uop_key(u): i for i,u in enumerate(sorted({u for u in uops if u.op in arg_order}, key=lambda k: (arg_order[k.op], k.arg)))}
    # stack offsets of DEFINE_LOCAL / DEFINE_REG buffers, they are placed before the spill slots
    self.local_offsets: dict[tuple, int] = {}
    for u in uops:
      if u.op not in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG): continue
      self.local_offsets.setdefault(_uop_key(u), self.stack_size)
      self.stack_size += u.dtype.nbytes()

    # compute live ranges
    lr = live_range
    # loop registers of the RANGEs seen so far, the walk is backwards so these loops start after the current uop
    ranges: list[Register] = []
    for i,u in enumerate(reversed(uops)):
      if u.op in PSEUDO_OPS: continue
      defs = u.tag if isinstance(u.tag, tuple) else ()
      for v in defs + tuple(s.reg for s in set(u.src)):
        # inserting at the front while walking backwards keeps each list in ascending order
        if isinstance(v, Register): lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
      for v in defs:
        if isinstance(v, Register): self.defs[v] = u
        # if the last use of v falls inside a loop that starts after v is defined, v is kept live until the end of that loop
        if v in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] <= lr[v][-1] < lr[rng][-1]), default=None)): lr[v].append(n)
      if u.op is Ops.RANGE: ranges.append(u.reg)

    def alloc(cons:tuple[Register, ...], i:int) -> Register:
      live_inv = {v:k for k,v in live.items()}
      # allocate the best register. Registers not in live or not used again are free and have priority,
      # otherwise pick the one with the furthest next use. Regs that appear first in cons have priority in case of a tie
      reg,vreg = max(((r,live_inv.get(r)) for r in cons),
                     key=lambda rv: next((j-i for j in ([] if rv[1] is None else live_range[rv[1]]) if j >= i), len(uops)))
      # if the chosen register holds a virtual, that virtual is evicted from live
      return live.pop(vreg) if vreg is not None else reg

    # assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
    def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
      if v not in self.spills:
        dt = self.defs[v].dtype
        sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
        assert sz > 0
        # round the offset up to a multiple of the slot size
        offset = self.stack_size + (sz - self.stack_size % sz) % sz
        self.spills[v] = UOp.const(dtypes.int32, offset)
        self.stack_size = offset + sz
      r = alloc(cons if cons is not None else v.cons, i)
      self.insert_before.setdefault(i, []).append((v, r))
      return r

    for i,u in enumerate(uops):
      if u.op in PSEUDO_OPS: continue
      # allocate uses
      for j,s in enumerate(u.src):
        # HACK: cause of later hacks to lower range
        if u.op is Ops.END: continue
        if not isinstance(v:=s.reg, Register): continue
        # a used virtual that is not in a register must have been evicted, bring it back
        if v not in live: live[v] = fill(v, i)
        # srcs that refer to a spilled virtual are recorded so regalloc_rewrite can replace them with a fill
        if v in self.spills: self.fills.setdefault(i, {})[j] = (v, live[v])
      # allocate defs
      if isinstance(u.tag, tuple):
        for j,v in enumerate(u.tag):
          assert isinstance(v, Register) and v not in live
          cons = v.cons
          # two address instructions (src is reused by def) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
          if ren.is_two_address(u) and j == 0:
            ins = tuple(live.get(s.reg) for s in u.src)
            cons = ((ins[0],) if ins[0] in cons else ()) + tuple(r for r in cons if r not in ins)
          assert cons
          # HACK: cause the range is missing the comparison
          self.real_defs[v] = live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
      # loop prologue, avoid loading inside the loop
      if u.op is Ops.RANGE:
        # we move to registers vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
        used_in_loop = [v for v in live.keys() | self.spills.keys() if any(i <= l < live_range[u.reg][-1] for l in live_range[v])]
        sorted_uses = sorted(used_in_loop, key=lambda k: next(l-i for l in live_range[k] if l >= i))
        live_in: dict[Register, Register] = {}
        for v in sorted_uses:
          # if all the possible registers are already in live_in there's no space for this var
          if set(v.cons).issubset(live_in.values()): continue
          if v not in live: live[v] = fill(v, i)
          assert live[v] not in live_in.values()
          live_in[v] = live[v]
        live_ins.append(live_in)
      # loop epilogue, reload registers that were live at loop entry
      if u.op is Ops.END:
        # TODO: if a uop is in a different reg in live out vs live in move between registers instead of loading
        # TODO: don't reload if first use in loop is a load
        for v,r in live_ins.pop().items():
          # the fill is pinned to the register the virtual had at loop entry
          if v not in live or live[v] != r: live[v] = fill(v, i, (r,))
def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
  # Apply the allocation stored in ctx to one line. Returns None for pseudo ops, otherwise the rewritten uop and
  # the lines that replace it: fills, then the uop itself, then spills of its definitions.
  # the counter advances for every uop, pseudo ops included, so it stays in sync with the program points used in ctx
  pos = next(ctx.idx)
  if x.op in PSEUDO_OPS: return None
  def reload(v, r): return ctx.ren.fill(ctx.spills[v], ctx.defs[v], r)
  # srcs that read a spilled virtual are replaced by a fill into the chosen real register
  src_fills = ctx.fills.get(pos, {})
  new_src = tuple(reload(*src_fills[j]) if j in src_fills else s for j,s in enumerate(x.src))
  # the virtual registers in the tag become the real registers picked at definition
  has_defs = isinstance(x.tag, tuple)
  new_tag = tuple(ctx.real_defs[v] for v in x.tag) if has_defs else x.tag
  nx = x.replace(src=new_src, tag=new_tag)
  before = [reload(v, r) for v,r in ctx.insert_before.get(pos, [])]
  # definitions that own a stack slot are stored right after they are produced
  after = [ctx.ren.spill(ctx.spills[v], nx) for v in x.tag if v in ctx.spills] if has_defs else []
  return nx, before + [nx] + after
# routes every listed op through regalloc_rewrite, do_linearize drives this with line_rewrite and a
# LinearScanRegallocContext as ctx. PSEUDO_OPS are matched too so the program point counter advances on them
pm_regalloc_rewrite = PatternMatcher([
  (UPat({Ops.INS, Ops.CONST, Ops.RANGE, Ops.END, Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER,
         Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"), regalloc_rewrite),
])

View file

@ -319,7 +319,8 @@ def is_dtype_supported(dtype:DType, target:Target|None=None) -> bool:
case "METAL": return not CI or BENCHMARKS
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
case "CPU": return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and target.renderer != "LVP"
case "CPU": return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and \
target.renderer not in ("LVP", "X86")
case "AMD" | "CL" | "PYTHON" | "NULL": return True
case _: return False
if dtype in dtypes.fp8_ocp:

View file

@ -3,7 +3,7 @@ import ctypes
from tinygrad.helpers import ceildiv, round_up
from tinygrad.uop.ops import UOp, Ops
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
from tinygrad.renderer.amd.dsl import Reg, FixedBitField
from tinygrad.renderer.amd.dsl import Inst, Reg, FixedBitField
from tinygrad.runtime.autogen.amd.common import OpType
# instructions used for padding
@ -11,8 +11,9 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as
from tinygrad.runtime.autogen.amd.cdna.ins import s_nop as s_nop_cdna
_arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4"}
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp|None:
insts = [u.arg for u in lin.src]
if not all(isinstance(inst, Inst) for inst in insts): return None
# ** scan for max vgpr/sgpr/accvgpr
max_vgpr, max_sgpr, max_accvgpr = 0, 0, 0

View file

@ -0,0 +1,44 @@
from __future__ import annotations
import itertools
from dataclasses import dataclass, field
from tinygrad.renderer import Renderer
from tinygrad.uop.ops import PatternMatcher, UOp, Ops
@dataclass(frozen=True)
class Register:
  # immutable, hashable description of a register (real or virtual)
  # text returned by __repr__
  name: str
  # integer id stored with the register
  index: int
  # registers this one is allowed to be assigned to, empty means only itself
  _cons: tuple[Register, ...] = field(default_factory=tuple)

  @property
  def cons(self):
    # allocation candidates: the explicit constraint when there is one, otherwise just this register
    if self._cons: return self._cons
    return (self,)

  def __repr__(self):
    return self.name
class IselContext:
  # state shared by the instruction selection rewrite rules
  def __init__(self, sink:UOp):
    # consumer map of the graph under sink
    self.uses = sink.get_consumer_map()
    # counter used to name virtual registers
    self.reg_n = itertools.count()
    # function arguments: PARAMs first, then DEFINE_VARs, then SPECIALs, each group ordered by arg
    rank = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
    args = [u for u in self.uses if u.op in rank]
    args.sort(key=lambda u: (rank[u.op], u.arg))
    self.func_args = args

  def vreg(self, cons:tuple[Register, ...]|Register):
    # create a fresh virtual register named v<n> constrained to `cons`, a single Register is wrapped in a tuple
    allowed = cons if isinstance(cons, tuple) else (cons,)
    return Register(f"v{next(self.reg_n)}", 0, _cons=allowed)
@dataclass
class PreRegAllocContext:
  # mutable state for a renderer's pre_regalloc_matcher, do_linearize creates a fresh instance for each call
  # NOTE(review): presumably `lock` is the uop currently holding a resource and `clobbered` the uops invalidated
  # so far -- the rules that use them are not visible here, confirm
  lock: UOp|None = None
  clobbered: set[UOp] = field(default_factory=set)
class ISARenderer(Renderer):
  # Base class of renderers that go through instruction selection and register allocation.
  # Subclasses provide the pattern matchers below and override the arch specific hooks.
  isa_spec: PatternMatcher
  pre_isel_matcher: PatternMatcher
  isel_matcher: PatternMatcher
  # the pre and late regalloc passes are optional, None skips them
  pre_regalloc_matcher: PatternMatcher|None = None
  late_regalloc_matcher: PatternMatcher|None = None
  post_regalloc_matcher: PatternMatcher

  def callee_saved(self) -> tuple[UOp, ...]:
    # uops for the callee saved registers, none by default
    return ()

  def is_two_address(self, x:UOp) -> bool:
    # whether the definition of x reuses the register of one of its srcs, never by default
    return False

  # the hooks below have no generic implementation
  def copy(self, x:UOp, reg:Register) -> UOp:
    raise NotImplementedError("arch specific")

  def spill(self, disp:UOp, x:UOp) -> UOp:
    raise NotImplementedError("arch specific")

  def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
    raise NotImplementedError("arch specific")

  def asm(self, uops:list[UOp], function_name:str) -> str:
    raise NotImplementedError("arch specific")

View file

@ -0,0 +1,936 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
import sys, struct, functools
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, DType, truncate
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
# ***** X86 Ops *****
class X86Ops(FastEnum):
  # Opcodes used as the arg of Ops.INS uops by the x86 backend. Values come from auto(), so they follow declaration order.
  # NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
  # these aren't real instructions
  DEFINE_REG = auto(); FRAME_INDEX = auto(); LABEL = auto()
  # index
  LEA = auto()
  # register / memory / immediate moves
  MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
  VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
  VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
  # casts
  MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
  VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
  VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
  VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
  VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
  VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
  VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
  VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
  VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
  # bitcasts
  VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
  # comparisons
  VUCOMISS = auto(); VUCOMISD = auto()
  VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
  VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
  VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
  SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
  # where
  CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
  VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
  # jumps
  JNE = auto(); JE = auto(); JL = auto(); JB = auto(); JGE = auto(); JMP = auto()
  # vectorize / gep
  VSHUFPS = auto(); VSHUFPD = auto(); VINSERTPS = auto(); VPSRLDQ = auto()
  VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
  VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
  VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
  VBROADCASTSS = auto()
  # int binary
  IDIV = auto(); DIV = auto()
  ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
  AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
  SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
  # float unary (sometimes not unary)
  VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
  VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
  # float scalar / vector binary
  VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
  VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
  VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
  VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
  VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
  VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
  # int vector binary
  VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
  VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
  VPMULLW = auto(); VPMULLD = auto()
  # packed bitwise
  VPAND = auto(); VPOR = auto(); VPXOR = auto()
  # packed variable shifts
  VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
  # fused multiply add TODO: add other variants to fuse more loads
  VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
  # return
  RET = auto()
# TODO: add commutative groupop to fuse more loads
class X86GroupOp:
  # Sets of X86Ops that share an operand or flag property. An op can be in several sets,
  # the Rm* sets at the bottom are derived from the ones above.
  # X86Ops whose first src is also the destination
  TwoAddress = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
                X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
                X86Ops.IDIV, X86Ops.DIV, X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
                X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}

  # X86Ops whose first src can read from memory
  ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
                X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
                X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
                X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
                X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
                X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
                X86Ops.CMPi, X86Ops.IMULi, X86Ops.LEA}

  # X86Ops whose second src can read from memory NOTE: some of these are TwoAddress so the second src is actually the first
  ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
                X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
                X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
                X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
                X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
                X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
                X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
                X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
                X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
                X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
                X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD, X86Ops.IDIV, X86Ops.DIV,
                X86Ops.VSHUFPD}

  # X86Ops whose third src can read from memory NOTE: these are TwoAddress so the third src is actually the second
  ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}

  # X86Ops that can write to memory
  WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
              X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
              X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}

  # X86Ops that read flags
  ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
               X86Ops.JE, X86Ops.JNE, X86Ops.JGE}

  # X86Ops that write flags or can modify flags to undefined values
  WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
                X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
                X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}

  # X86Ops whose first src is the rm field
  Rm1st = ReadMem1st | (ReadMem2nd & TwoAddress) | {X86Ops.VPSRLDQ}
  # X86Ops whose second src is the rm field
  # NOTE(review): this keeps the TwoAddress members of ReadMem2nd, which are also in Rm1st -- confirm the users expect the overlap
  Rm2nd = ReadMem2nd | (ReadMem3rd & TwoAddress)

  # every X86Ops member
  All = set(X86Ops)
# ***** X86 legalization *****
# rewrites generic uops the x86 backend can't select directly into equivalent uops it can,
# every rule returns a replacement uop or None when it doesn't apply
extra_matcher = PatternMatcher([
  # bool CMPNE is XOR, bool CMPEQ is XOR+XOR, bool CMPLT is XOR+AND
  # TODO: how does this work for vector dtypes?
  (UPat.var('x', dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
  (UPat.var('x', dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
  (UPat.var('x', dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
  # cast to pointer is a noop
  (UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
  # can't cast from float16 to ints/float64 directly and vice versa
  (UPat.var("y", dtypes.float16).cast((dtypes.float64,)+dtypes.ints, name="x"), lambda y,x: y.cast(dtypes.float32).cast(x.dtype)),
  (UPat.var("y", (dtypes.float64,)+dtypes.ints).cast(dtypes.float16, name="x"), lambda y,x: y.cast(dtypes.float32).cast(x.dtype)),
  # can't cast from float to int8/16 directly and vice versa
  (UPat.var("y", dtypes.floats).cast(dtypes.int8s+dtypes.int16s, name="x"), lambda y,x: y.cast(dtypes.int32).cast(x.dtype)),
  (UPat.var("y", (dtypes.bool,)+dtypes.int8s+dtypes.int16s).cast(dtypes.floats, name="x"), lambda y,x: y.cast(dtypes.int32).cast(x.dtype)),
  # int/float casts only for signed int
  (UPat.var("y", dtypes.uint32).cast(dtypes.floats, name="x"), lambda y,x: y.cast(dtypes.int64).cast(x.dtype)),
  # casting uint64 to float requires special handling
  # the value is halved so it fits a signed int64, converted, doubled and the dropped low bit is added back
  (UPat.var("y", dtypes.uint64).cast(dtypes.floats, name="x"), lambda y,x:
   (y >> 1).cast(dtypes.int64).cast(x.dtype) * 2 + (y & 1).cast(dtypes.int64).cast(x.dtype)),
  # no int8 mul or cmove, cast to int16
  (UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)),
  (UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
   lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.dtype.count == 1 else None),
  # float16 alus are done in float32
  (UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.dtype.count),
   tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)),
  (UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"),
   lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)),
  # no cmpne for packed ints, y != x => !(y==x)
  (UPat(Ops.CMPNE, src=(UPat.var("y", dtypes.ints), UPat.var("x")), name="cmp"),
   lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.dtype.count > 1 else None),
  # float where expects a mask TODO: handle float64 cmp to float32 where
  (UPat.var("m", dtypes.bool).where(UPat.var("a", dtypes.floats), UPat.var("b")),
   lambda m,a,b: m.cast(a.dtype).ne(0).where(a, b) if m.src[0].dtype not in dtypes.floats else None),
  # TODO: do we want this? If yes make it general
  #(UPat(Ops.VECTORIZE, dtypes.float16, name="x"), lambda x: x.replace(dtype=dtypes.float32.vec(x.dtype.count),
  # src=tuple(s.src[0] for s in x.src)).cast(x.dtype) if all(s.op is Ops.CAST for s in x.src) else None),
  # rewrite -x -> 0 - x
  (UPat(Ops.NEG, name="x"), lambda x: UOp(Ops.SUB, x.dtype, (x.const_like(0),) + x.src)),
])
# ***** X86 pre instruction selection *****
# these must be done in a separate matcher because they violate the spec
# run over the sink right before instruction selection: casts that need no instruction become Ops.NOOP
# and gated loads / stores are turned into unconditional ones
pre_isel_matcher = PatternMatcher([
  # zero extending scalar 32bit int is a noop
  (UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
  # cast between signed and unsigned int is a noop
  (UPat.var("y", dtypes.ints+(dtypes.bool,)).cast(dtypes.ints, name="x"),
   lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize == y.dtype.itemsize else None),
  # cast to < scalar int is a noop
  (UPat.var("y", dtypes.ints).cast(dtypes.ints, name="x"),
   lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.dtype.count == 1 else None),
  # bitcasts between scalar floats and ints are real, rest are noops
  (UPat.var("y").bitcast().named("x"), lambda y,x: None if y.dtype in dtypes.floats and x.dtype in dtypes.ints or \
   y.dtype in dtypes.ints and x.dtype in dtypes.floats else x.replace(op=Ops.NOOP)),
  # noop of a noop is removed
  (UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
  # moving elements of a single register to another without shuffling is a noop
  (UPat(Ops.VECTORIZE, src=(UPat.var("y"),), allow_any_len=True, name="x"),
   lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
  # gated index becomes a conditional move on the index, the load/store are unconditional
  # when the gate is false the address selected is a DEFINE_LOCAL scratch slot (the load first stores `alt` into it)
  (UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).load(UPat.var("alt"), name="x"), lambda base,idx,gate,alt,x:
   gate.where(base.index(idx, ptr=True), (l:=UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count), arg=0)
   .index(UOp.const(dtypes.int32, 0), ptr=True)).after(l.store(alt))).load(dtype=x.dtype)),
  (UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).store(UPat.var("val")), lambda base,idx,gate,val:
   gate.where(base.index(idx, ptr=True), UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count), arg=0)
   .index(UOp.const(dtypes.int32, 0), ptr=True)).store(val)),
  # TODO: remove this once we allow all flag producing ops in cmove
  # if gate in scalar int cmove is not a comparison need to add one to set the flag
  (UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
   lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.dtype.count == 1 else None),
])
# ***** X86 registers *****
# the first eight general purpose registers by name, index runs from 0 to 7 in this order
RAX = Register("rax", 0)
RCX = Register("rcx", 1)
RDX = Register("rdx", 2)
RBX = Register("rbx", 3)
RSP = Register("rsp", 4)
RBP = Register("rbp", 5)
RSI = Register("rsi", 6)
RDI = Register("rdi", 7)
# all 16 general purpose registers (r8 to r15 appended) and the 16 xmm registers, position in the tuple equals index
GPR = (RAX, RCX, RDX, RBX, RSP, RBP, RSI, RDI) + tuple(Register(f"r{i}", i) for i in range(8, 16))
XMM = tuple(Register(f"xmm{i}", i) for i in range(16))
# gprs you can write to
WGPR = tuple(r for r in GPR if r != RSP)
# rbx, rsp, rbp and r12 to r15 everywhere, on win32 additionally rsi, rdi and xmm6 to xmm15
CALLEE_SAVED = (RBX, RSP, RBP, GPR[12], GPR[13], GPR[14], GPR[15]) + ((RSI, RDI) + XMM[6:16] if sys.platform == "win32" else ())
# alternative names of each register, keyed by register name and then by what looks like the operand size in bytes
# (4/2/1 for the gpr sub registers, 64/32 for zmm/ymm) -- TODO confirm against the renderer that reads this table
reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"}, "rdx": {4:"edx", 2:"dx", 1:"dl"}, "rbx": {4:"ebx", 2:"bx", 1:"bl"},
            "rsp": {4:"esp", 2:"sp", 1:"spl"}, "rbp": {4:"ebp", 2:"bp", 1:"bpl"}, "rsi": {4:"esi", 2:"si", 1:"sil"}, "rdi": {4:"edi", 2:"di", 1:"dil"},
            **{f"r{i}": {4:f"r{i}d", 2:f"r{i}w", 1:f"r{i}b"} for i in range(8, 16)}, **{f"xmm{i}": {64:f"zmm{i}", 32:f"ymm{i}"} for i in range(16)}}
# ***** X86 instruction selection *****
# if the load is used multiple times we don't fold
def is_foldable_load(ctx:IselContext, x:UOp, s:UOp) -> bool:
  """Return True when ``s`` is a LOAD whose single recorded use is one occurrence in ``x.src``."""
  if s.op is not Ops.LOAD: return False
  total_uses = len(ctx.uses[s])
  return total_uses == 1 and x.src.count(s) == 1
def base(x:UOp, i:int) -> UOp:
  """Return the uop that source ``i`` of ``x`` reads from: the GEP's input if it is a GEP, else the source itself."""
  src = x.src[i]
  if src.op is Ops.GEP: return src.src[0]
  return src
def lane(x:UOp, i:int) -> int:
  """Return the element index selected by source ``i`` of ``x``: the GEP's first arg, or 0 when it is not a GEP."""
  src = x.src[i]
  if src.op is Ops.GEP: return src.arg[0]
  return 0
def to_int(dt:DType):
  """Map float16/float32/float64 to int16/int32/int64; any other dtype raises KeyError."""
  same_width_int = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  return same_width_int[dt]
def def_reg(dt:DType, reg:Register|None=None) -> UOp:
  """Build a DEFINE_REG instruction of dtype ``dt``; when ``reg`` is given it is stored as the one-element tag."""
  tag = (reg,) if reg is not None else None
  return UOp(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dt, tag=tag)
def imm(dt:DType, v:int) -> UOp:
  """Build a CONST of dtype ``dt`` holding ``v`` truncated to ``dt`` and tagged as an x86 immediate."""
  value = truncate[dt](v)
  return UOp(Ops.CONST, dt, arg=value, tag="__x86_imm__")
def _uop_key(u:UOp):
  """Return the (op, dtype, arg) triple of ``u``, used as a dictionary key."""
  op, dtype, arg = u.op, u.dtype, u.arg
  return (op, dtype, arg)
def to_imm(c:UOp) -> UOp|None:
  """Return ``c`` as an immediate operand, or None when it can't be encoded as one.

  Only integer and bool constants qualify. 64-bit constants are narrowed to a 32-bit immediate,
  which x86 sign-extends back to 64 bits, so they are only accepted when they fit in int32.
  """
  if c.op is not Ops.CONST: return None
  if c.dtype is dtypes.int64: return imm(dtypes.int32, c.arg) if not c.overflows(dtypes.int32) else None
  # imm32 is sign-extended for 64-bit operands: an unsigned value above int32.max would get its upper 32 bits set,
  # so it has to stay a register constant instead
  if c.dtype is dtypes.uint64: return imm(dtypes.uint32, c.arg) if not c.overflows(dtypes.int32) else None
  if c.dtype in dtypes.ints+(dtypes.bool,): return imm(c.dtype, c.arg)
  return None
def cmp(x:UOp) -> UOp:
  """Turn comparison ``x`` into a void-typed compare instruction, picked from the dtype of its first source.

  float32/float64 use VUCOMISS/VUCOMISD; everything else uses CMP, or CMPi when the second source is an immediate.
  """
  lhs_dtype = x.src[0].dtype
  if lhs_dtype is dtypes.float32: return x.ins(X86Ops.VUCOMISS, dtype=dtypes.void)
  if lhs_dtype is dtypes.float64: return x.ins(X86Ops.VUCOMISD, dtype=dtypes.void)
  rhs_imm = to_imm(x.src[1])
  if rhs_imm is None: return x.ins(X86Ops.CMP, dtype=dtypes.void)
  return x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], rhs_imm))
def vcmp(x:UOp) -> UOp:
  """Turn comparison ``x`` into a VCMP* instruction with the comparison predicate appended as a uint8 immediate.

  The scalar type of ``x.dtype`` picks the single (float32) or double form, its count picks scalar or packed.
  """
  predicate = imm(dtypes.uint8, {Ops.CMPLT: 1, Ops.CMPNE: 4, Ops.CMPEQ: 0}[x.op])
  is_scalar = x.dtype.count == 1
  if x.dtype.scalar() is dtypes.float32: op = X86Ops.VCMPSS if is_scalar else X86Ops.VCMPPS
  else: op = X86Ops.VCMPSD if is_scalar else X86Ops.VCMPPD
  return x.ins(op, src=x.src + (predicate,))
# vshufps xmm2, xmm0, xmm1, imm
# for 128 bit, xmm2 takes its lower two 32 bit elements from xmm0 and its upper two 32 bit elements from xmm1, as selected by imm
# for 256 bit ymm2 repeats the shuffle for its upper 128 bits selecting from the upper 128 bits of ymm0 and ymm1
def vshufps(x:UOp) -> UOp|None:
  """Select VSHUFPS for a float32 VECTORIZE, or return None when the sources don't fit the instruction.

  Elements 0-1 must come from one vector and elements 2-3 from another, all from lanes 0-3.
  With 8 sources, the upper four must repeat the lower four's pattern shifted up by 4 lanes.
  """
  lo_vec, hi_vec = base(x, 0), base(x, 2)
  if base(x, 1) is not lo_vec or base(x, 3) is not hi_vec: return None
  if any(lane(x, i) > 3 for i in range(4)): return None
  if len(x.src) == 8:
    if any(base(x, i) is not lo_vec for i in (4, 5)) or any(base(x, i) is not hi_vec for i in (6, 7)): return None
    if any(lane(x, i+4) != lane(x, i)+4 for i in range(4)): return None
  # two selector bits per result element
  control = sum(lane(x, i) << 2*i for i in range(4))
  return x.ins(X86Ops.VSHUFPS, src=(lo_vec, hi_vec, imm(dtypes.uint8, control)))
# vshufpd xmm2, xmm0, xmm1, imm
# for 128 bit xmm2 selects its lower 64 bits from xmm0 and its upper 64 bits from xmm1 according to imm
# for 256 bit ymm2 also selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 following the same constraint
def vshufpd(x:UOp) -> UOp|None:
  """Select VSHUFPD for a float64 VECTORIZE, or return None when the sources don't fit the instruction.

  Even result elements must come from one vector ``a`` and odd ones from one vector ``b``.
  Result elements 0-1 may only read lanes 0-1, and with 4 sources elements 2-3 may only read lanes 2-3.
  """
  a, b = base(x, 0), base(x, 1)
  if lane(x, 0) > 1 or lane(x, 1) > 1: return None
  if len(x.src) == 4 and not (a is base(x, 2) and b is base(x, 3) and lane(x, 2) > 1 and lane(x, 3) > 1): return None
  # imm bit i picks the low/high element *within* the 128 bit half, so lanes 2 and 3 of the upper half must map to 0 and 1
  return x.ins(X86Ops.VSHUFPD, src=(a, b, imm(dtypes.uint8, sum((lane(x, i) & 1) << i for i in range(len(x.src))))))
# vinsertps xmm2, xmm0, xmm1, imm
# inserts any 32 bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2
# this is the slow fallback for when you can't match a more powerful shuffle
def vinsertps(x:UOp) -> UOp:
  """Build a float32 vector by chaining one VINSERTPS per source of ``x``, starting from a fresh DEFINE_REG."""
  vec = def_reg(x.dtype)
  for i in range(len(x.src)):
    s, v = base(x, i), lane(x, i)
    if i == 0 and v == 0:
      # moving the 0th element into the 0th position does nothing, so the source itself becomes the running vector
      vec = s
    else:
      # imm: source lane in bits 7:6, destination position in bits 5:4
      vec = x.ins(X86Ops.VINSERTPS, src=(vec, s, imm(dtypes.uint8, v << 6 | i << 4)))
  return vec
# vpinsrq xmm2, xmm0, rax, imm
# inserts the element in rax into the position of xmm0 selected by imm, result is written to xmm2
def vpins(x:UOp) -> UOp:
  """Build a vector by chaining one VPINSR{B,W,D,Q} per source of ``x``, inserting source ``i`` at position ``i``."""
  # the instruction variant follows the element size in bytes
  op = {1: X86Ops.VPINSRB, 2: X86Ops.VPINSRW, 4: X86Ops.VPINSRD, 8: X86Ops.VPINSRQ}[x.dtype.scalar().itemsize]
  vec = def_reg(x.dtype)
  for i, elem in enumerate(x.src):
    vec = x.ins(op, src=(vec, elem, imm(dtypes.uint8, i)))
  return vec
# vpbroadcastd xmm1, xmm0
# inserts scalar int in xmm0 into all lanes of xmm1
def vpbroadcast(ctx:IselContext, x:UOp, y:UOp) -> UOp:
  """Lower a broadcast of scalar int ``y`` to VPBROADCAST{B,W,D,Q}, chosen by ``y``'s size in bytes."""
  opcode = {1: X86Ops.VPBROADCASTB, 2: X86Ops.VPBROADCASTW, 4: X86Ops.VPBROADCASTD, 8: X86Ops.VPBROADCASTQ}[y.dtype.itemsize]
  bcast = x.ins(opcode, src=(y,))
  # a foldable load can be used as the source directly
  if is_foldable_load(ctx, bcast, y): return bcast
  # if there isn't a load we can fold we need to move y from gpr to xmm
  # this is hacky but required because int.vec(1) isn't supported
  if y.dtype.itemsize <= 1: y = y.cast(dtypes.int16)
  float_of_size = {2:dtypes.float16, 4:dtypes.float32, 8:dtypes.float64}
  return bcast.replace(src=(y.bitcast(float_of_size[y.dtype.itemsize]),))
# we don't call ctx.vreg on the srcs to avoid duplicates, a rewrite will assign the tuple of valid registers to a vreg
def idiv(ctx:IselContext, x:UOp) -> UOp:
  """Lower scalar integer division ``x`` to DIV/IDIV with the dividend pinned to rax (and rdx for >8bit).

  Returns a MOV of the division result so users are not bound to the register constraint of the division.
  """
  dt, lhs, rhs = x.dtype, x.src[0], x.src[1]
  is_8bit, is_unsigned = dt in dtypes.int8s, dt in dtypes.uints
  # for >8bit need to zero/sign extend rax to rdx
  if is_8bit: high = ()
  elif is_unsigned: high = (x.ins(X86Ops.MOVi, src=(imm(min(dtypes.uint32, dt), 0),), tag=(RDX,)),)
  else: high = (x.ins(X86Ops.SARi, src=(lhs, imm(dtypes.uint8, dt.itemsize * 8 - 1)), tag=(RDX,)),)
  # for 8bit need to zero/sign extend al to ah
  if dt is dtypes.uint8: dividend = UOp(Ops.INS, arg=X86Ops.MOVZX, dtype=dtypes.int16, src=(lhs,), tag=(RAX,))
  elif dt is dtypes.int8: dividend = UOp(Ops.INS, arg=X86Ops.MOVSX, dtype=dtypes.int16, src=(lhs,), tag=(RAX,))
  else: dividend = x.ins(X86Ops.MOV, src=(lhs,), tag=(RAX,))
  # divisor can't be in rax or rdx
  allowed = tuple(r for r in WGPR if r not in (RAX, RDX))
  divisor = x.ins(X86Ops.MOV, src=(rhs,), tag=allowed)
  # for >8bit both rax and rdx are written to
  # we don't call ctx.vreg on the srcs to avoid duplicates, only on the definitions of the division itself
  defs = (ctx.vreg(RAX),) if is_8bit else (ctx.vreg(RAX), ctx.vreg(RDX))
  div_op = X86Ops.DIV if is_unsigned else X86Ops.IDIV
  division = x.ins(div_op, src=(dividend, divisor) + high, tag=defs)
  # this move "cleanses" the register constraint (rax) of idiv as it only applies on definition and not on the uses of idiv
  return x.ins(X86Ops.MOV, src=(division,))
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
  """Split address ``x`` into a (base, index, displacement) triple.

  A missing index is a NOOP uop. The displacement is an int8 immediate when its absolute value fits, else int32.
  A constant index, or the constant right operand of an ADD index, is folded into the displacement.
  """
  def _displacement(v:int) -> UOp:
    width = dtypes.int8 if abs(v) <= dtypes.int8.max else dtypes.int32
    return imm(width, v)
  def _index(v:UOp) -> UOp:
    # an index that can be negative is cast to int64
    if v.vmin < 0: return v.cast(dtypes.int64)
    return v
  # anything that isn't an INDEX is used as the base on its own
  if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _displacement(0))
  ptr, idx = x.src
  # constant offsets are scaled by the element size of the pointer
  scale = ptr.dtype.itemsize if isinstance(ptr.dtype, PtrDType) else 1
  if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST:
    return (ptr, _index(idx.src[0]), _displacement(idx.src[1].arg * scale))
  if idx.op is Ops.CONST:
    return (ptr, UOp(Ops.NOOP), _displacement(idx.arg * scale))
  return (ptr, _index(idx), _displacement(0))
def alloc_defs(ctx:IselContext, x:UOp) -> UOp|None:
  """Tag definition ``x`` with a virtual register; returns None for void uops and ones that already carry a tuple tag.

  Function arguments that are passed in a register get a vreg constrained to that register.
  """
  if x.dtype is dtypes.void: return None
  if isinstance(x.tag, tuple): return None
  if x.op in {Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL}:
    arg_idx = ctx.func_args.index(x)
    if sys.platform == "win32": arg_regs = (RCX, RDX, GPR[8], GPR[9])
    else: arg_regs = (RDI, RSI, RDX, RCX, GPR[8], GPR[9])
    if arg_idx < len(arg_regs): return x.replace(tag=(ctx.vreg(arg_regs[arg_idx]),))
  # ints, bools and pointers live in writable gprs, everything else in xmm
  in_gpr = x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType)
  return x.replace(tag=(ctx.vreg(WGPR if in_gpr else XMM),))
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
  """Tag instruction ``x`` with the virtual register it defines, or return None when nothing has to be allocated."""
  # immediates and real registers
  if x.op is Ops.CONST: return None
  if x.tag is not None and x.arg in (X86Ops.FRAME_INDEX, X86Ops.DEFINE_REG): return None
  # no register definition
  if x.dtype is dtypes.void: return None
  has_tuple_tag = isinstance(x.tag, tuple)
  # already allocated vregs
  if has_tuple_tag and x.tag[0]._cons: return None
  # allocate vreg definitions, a tuple tag is the set of registers the definition is constrained to
  if has_tuple_tag: tag = (ctx.vreg(x.tag),)
  elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): tag = (ctx.vreg(WGPR),)
  elif x.dtype in dtypes.floats or x.dtype.count > 1: tag = (ctx.vreg(XMM),)
  else: tag = ()
  # TODO: add this once the scheduler can track register pressure
  # if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
  return x.replace(tag=tag)
def lower_abi(ctx, x:UOp):
  """Lower function argument ``x`` to the register or stack slot it is passed in.

  Returns (replacement, uops), where uops lists the new uops together with the replacement.
  win32 passes arguments in rcx, rdx, r8, r9 with a stack base of 32; otherwise rdi, rsi, rdx, rcx, r8, r9 with 0.
  """
  arg_idx = ctx.func_arg_idxs[_uop_key(x)]
  if sys.platform == "win32": arg_regs, stack_base = (RCX, RDX, GPR[8], GPR[9]), 32
  else: arg_regs, stack_base = (RDI, RSI, RDX, RCX, GPR[8], GPR[9]), 0
  if arg_idx >= len(arg_regs):
    # argument is on the stack: load it relative to rsp, the FRAME_INDEX is resolved once the stack size is known
    slot = UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=(arg_idx-len(arg_regs)+1)*8+stack_base)
    loaded = x.ins(X86Ops.MOV, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), slot))
    return loaded, [slot, loaded]
  incoming = def_reg(x.dtype, arg_regs[arg_idx])
  # no move is needed if x was assigned the register the argument arrives in
  if x.reg == incoming.reg: return incoming, [incoming]
  moved = x.ins(X86Ops.MOV, src=(incoming,))
  return (moved, [incoming, moved])
def lower_stack_define(ctx, x:UOp):
  """Lower a stack backed definition to an LEA of rsp plus its offset from ``ctx.local_offsets``.

  Returns (replacement, uops), where uops lists the displacement immediate and the LEA.
  """
  offset = imm(dtypes.int32, ctx.local_offsets[_uop_key(x)])
  stack_ptr = def_reg(dtypes.uint64, RSP)
  addr = x.ins(X86Ops.LEA, src=(stack_ptr, UOp(Ops.NOOP), offset))
  return addr, [offset, addr]
# scalar dtypes the copy/load/store rules are built from
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
# dt_Nbit holds every scalar/vector of dts whose total size is N bits (count * itemsize == N/8 bytes)
# the 16, 32 and 64 bit groups leave out the scalar int of that exact width
dt_16bit = tuple(dt.vec(l) for dt in dts for l in [2,1] if l*dt.itemsize == 2 and dt not in dtypes.int16s)
dt_32bit = tuple(dt.vec(l) for dt in dts for l in [4,2,1] if l*dt.itemsize == 4 and dt not in dtypes.int32s)
dt_64bit = tuple(dt.vec(l) for dt in dts for l in [8,4,2,1] if l*dt.itemsize == 8 and dt not in dtypes.int64s)
dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize == 16)
# instruction selection rules: rewrites generic Ops into X86Ops instructions (Ops.INS) and tags them with virtual registers
# NOTE: the order of the rules matters, more specific patterns (immediates, packed forms) are listed before the generic ones
isel_matcher = PatternMatcher([
# **** Op -> Op ****
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
# this is done here to not interfere with shuffles
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
# range is lowered to acc, cmp, jmp after regalloc
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
# **** Op -> X86Op ****
# append return, callee-saved live ranges are inserted in regalloc
(UPat(Ops.SINK, name="x"), lambda x:
x.replace(src=(x.ins(X86Ops.RET, src=x.src),))
if not (len(x.src) == 1 and x.src[0].op is Ops.INS and x.src[0].arg is X86Ops.RET) else None),
# late lowered function args and stack backed locals still need virtual registers
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_defs),
# constants that can't be immediates, move them to registers
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if x.tag is None else None),
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if x.tag is None else None),
(UPat.cvar("x", dtypes.floats), lambda x:
UOp.const(dt:=to_int(x.dtype), struct.unpack(dt.fmt, struct.pack(x.dtype.fmt, x.arg))[0]).bitcast(x.dtype) if x.tag is None else None),
# TODO: these should use a.maximum(b) / a.minimum(b)
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float32), UPat.var("a")), lambda a,b:
a.ins(X86Ops.VMAXSS if a.dtype.count == 1 else X86Ops.VMAXPS, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float64), UPat.var("a")), lambda a,b:
a.ins(X86Ops.VMAXSD if a.dtype.count == 1 else X86Ops.VMAXPD, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda a,b:
a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda a,b:
a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))),
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None),
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
(UPat(GroupOp.Comparison, dtypes.bool, (UPat.var("y", (dtypes.float32, dtypes.float64)), UPat()), name="x"), lambda y,x:
x.replace(dtype=y.dtype).bitcast(to_int(y.dtype)).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)),
# conditional moves that use flags
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))),
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVB, src=(b, a, cmp(m)))),
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVE, src=(b, a, cmp(m)))),
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVNE, src=(b, a, cmp(m)))),
# jumps, use flags
# NOTE(review): the cmov rules above pick the signed condition only for sints, here the unsigned one is picked only for uints,
# so a float CMPLT feeding an IF would get JL on top of the VUCOMIS* that cmp() emits -- confirm float compares never reach IF
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JB, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JL, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JE, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JNE, src=(cmp(y),))),
# comparisons whose user doesn't use the flag, move flag result to register
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: x.ins(X86Ops.SETB, src=(cmp(x),))),
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETL, src=(cmp(x),))),
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETE, src=(cmp(x),))),
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETNE, src=(cmp(x),))),
# comparisons that produce masks (these aren't bool dtype)
(UPat(GroupOp.Comparison, src=(UPat(dtype=(dtypes.float32, dtypes.float64)), UPat()), name="x"), vcmp),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQB)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQW)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQD)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQQ)),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTB, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTW, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
# float unary
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)),
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)),
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x:
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x:
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
# shuffles
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
# for float16 we route the srcs through gprs unless we can fold them, this is suboptimal for values in xmms, in that case we want vpunpcklwd
(UPat(Ops.VECTORIZE, dtypes.float16, name="x"), lambda ctx,x:
vpins(x.replace(src=tuple(s if is_foldable_load(ctx, x, s) else s.bitcast(dtypes.int16) for s in x.src)))),
(UPat(Ops.VECTORIZE, (dtypes.float32.vec(4), dtypes.float32.vec(8)), name="x"), vshufps),
(UPat(Ops.VECTORIZE, (dtypes.float64.vec(2), dtypes.float64.vec(4)), name="x"), vshufpd),
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vinsertps),
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
(UPat(Ops.VECTORIZE, dtypes.ints+(dtypes.bool,), name="x"), vpins),
# gep
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
# fused multiply add TODO: don't fuse if mul used several times
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c:
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(a, b, c))),
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c:
a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(a, b, c))),
# packed bitwise
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
# packed int binary
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
# scalar int binary
((UPat(dtype=dtypes.ints) // UPat()).named("x"), idiv),
# scalar int binary with immediate
(UPat.var("a", dtypes.ints) << UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SHLi, src=(a, imm(dtypes.uint8, c.arg)))),
(UPat.var("a", dtypes.uints) >> UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SHRi, src=(a, imm(dtypes.uint8, c.arg)))),
(UPat.var("a", dtypes.sints) >> UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SARi, src=(a, imm(dtypes.uint8, c.arg)))),
(UPat.var("a", dtypes.ints) + UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ADDi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
(UPat.var("a", dtypes.ints) * UPat.cvar("c"), lambda a,c: a.ins(X86Ops.IMULi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ANDi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ORi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.cvar("c"), lambda a,c: a.ins(X86Ops.XORi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.cvar("c"))), lambda a,c: a.ins(X86Ops.SUBi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
# scalar int binary with register
(UPat.var("a", dtypes.ints) << UPat.var("b"), lambda a,b: a.ins(X86Ops.SHL, src=(a, b))),
(UPat.var("a", dtypes.uints) >> UPat.var("b"), lambda a,b: a.ins(X86Ops.SHR, src=(a, b))),
(UPat.var("a", dtypes.sints) >> UPat.var("b"), lambda a,b: a.ins(X86Ops.SAR, src=(a, b))),
(UPat.var("a", dtypes.ints) + UPat.var("b"), lambda a,b: a.ins(X86Ops.ADD, src=(a, b))),
(UPat.var("a", dtypes.ints) * UPat.var("b"), lambda a,b: a.ins(X86Ops.IMUL, src=(a, b))),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b"), lambda a,b: a.ins(X86Ops.AND, src=(a, b))),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b"), lambda a,b: a.ins(X86Ops.OR, src=(a, b))),
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b"), lambda a,b: a.ins(X86Ops.XOR, src=(a, b))),
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b"))), lambda a,b: a.ins(X86Ops.SUB, src=(a, b))),
# float binary
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
# casts
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSD2SI)),
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSS2SD, src=(y, y))),
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.dtype.count == 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.dtype.count == 1 else None),
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.dtype.count == 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW)),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD)),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ)),
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWD)),
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWQ)),
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXDQ)),
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBW)),
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBD)),
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBQ)),
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWD)),
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWQ)),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXDQ)),
# bitcasts
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, 0)))),
(UPat(dtype=dtypes.int16s).bitcast(dtypes.float16).named("x"), vpins),
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.ins(X86Ops.VMOVD)),
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
# index
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# copy, load, store
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
(UPat(Ops.COPY, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
(UPat(Ops.COPY, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
(UPat(Ops.COPY, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
(UPat(Ops.COPY, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS, src=fold_address(x.src[0]))),
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD, src=fold_address(x.src[0]))),
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS, src=fold_address(x.src[0]))),
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x:
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(x.src[0]) + (imm(dtypes.uint8, 0),))),
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV, src=fold_address(x.src[0]))),
(UPat.var("a").store(UPat.var("b", dt_128bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_64bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_32bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_16bit), name="x"), lambda a,b,x: x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))),
(UPat.var("a").store(UPat.var("b", dtypes.ints+(dtypes.bool,)), name="x"), lambda a,b,x:
x.ins(X86Ops.MOVm, src=fold_address(a) + (b,)) if (i:=to_imm(b)) is None else x.ins(X86Ops.MOVi, src=fold_address(a) + (i,))),
# **** X86Op -> X86Op ****
# fold loads into X86Ops that allow it, if beneficial
# the load in src position 1, 2 or 3 is replaced in place by its (base, index, displacement) triple
(UPat(Ops.INS, src=(UPat(Ops.LOAD, name="y"),), allow_any_len=True, name="x"), lambda ctx,y,x:
x.replace(src=fold_address(y.src[0]) + x.src[1:]) if x.arg in X86GroupOp.ReadMem1st and is_foldable_load(ctx, x, y) else None),
(UPat(Ops.INS, src=(UPat(), UPat(Ops.LOAD, name="y")), allow_any_len=True, name="x"), lambda ctx,y,x:
x.replace(src=x.src[:1] + fold_address(y.src[0]) + x.src[2:]) if x.arg in X86GroupOp.ReadMem2nd and is_foldable_load(ctx, x, y) else None),
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, name="y")), allow_any_len=True, name="x"), lambda ctx,y,x:
x.replace(src=x.src[:2] + fold_address(y.src[0]) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable_load(ctx, x, y) else None),
# allocate virtual registers
(UPat(Ops.INS, name="x"), alloc_vregs),
])
# ***** pre register allocation *****
# this handles flag clobbers. Unfortunately x86 doesn't have a good way to store/restore the flag register (then regalloc would handle it)
# so we rematerialize. This is different from rematerialization you might want to do in regalloc because it is not optional,
# regalloc shouldn't rematerialize if a src of the instruction is dead, but here you need to as there's no fallback load from stack
def flag_rematerialize(ctx:PreRegAllocContext, x:UOp):
  """Track which uop currently owns the flags and re-emit a flag producer in front of ``x`` if it was clobbered.

  Returns None when ``x`` doesn't touch the flags or its flag producer is still intact,
  otherwise (x, [flag producer, x]).
  """
  # flag writers (and RANGE) define the flags themselves, flag readers take them from their last src
  if x.arg in X86GroupOp.WriteFlags or x.op is Ops.RANGE: producer = x
  elif x.arg in X86GroupOp.ReadFlags: producer = x.src[-1]
  else: return None
  if producer is None: return None
  # a different producer taking over the flags clobbers the previous one
  previous = ctx.lock
  if previous is not None and previous is not producer: ctx.clobbered.add(previous)
  ctx.lock = producer
  if producer in ctx.clobbered:
    ctx.clobbered.remove(producer)
    return (x, [producer, x])
  return None
# runs flag_rematerialize on every instruction and RANGE
pre_regalloc_matcher = PatternMatcher([
(UPat((Ops.INS, Ops.RANGE), name="x"), flag_rematerialize),
])
# lowers the definitions that were kept as generic Ops during isel
late_regalloc_matcher = PatternMatcher([
# function arguments are read from their argument register or stack slot
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lower_abi),
# stack backed definitions become an address relative to rsp
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), lower_stack_define),
])
# ***** post register allocation *****
# TODO: control flow should be overhauled so that this isn't necessary
def lower_range(ctx, x:UOp) -> tuple[UOp, list[UOp]]:
  """Lower RANGE ``x`` to: counter = 0, loop label, compare against the bound, jump out if counter >= bound.

  The label name is recorded in ``ctx.loop_label`` under the counter.
  Returns (counter, [counter, label, compare, jump]).
  """
  name = "_".join(map(str, x.arg[:-1]))
  bound = x.src[0]
  counter = x.ins(X86Ops.MOVi, src=(imm(x.dtype, 0),) + x.src[1:])
  head = UOp(Ops.INS, arg=X86Ops.LABEL, tag=f".LOOP_{name}")
  # a constant bound is compared as an immediate
  compare_op = X86Ops.CMPi if bound.op is Ops.CONST else X86Ops.CMP
  compare = UOp(Ops.INS, arg=compare_op, src=(counter, bound))
  leave = UOp(Ops.INS, arg=X86Ops.JGE, src=(compare,), tag=f".LOOP_OUT_{name}")
  ctx.loop_label[counter] = name
  return (counter, [counter, head, compare, leave])
# final rewrite to match the isa spec
# each rule returns (replacement, list of uops) or None when it doesn't apply
post_regalloc_matcher = PatternMatcher([
# alloc stack space
# a SUBi of the stack size on rsp is added after the uint64 DEFINE_REG of rsp, only if the stack is used
(UPat(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dtypes.uint64, name="x"), lambda ctx,x:
(x, [x, x.ins(X86Ops.SUBi, src=(imm(dtypes.uint32, ctx.stack_size),), tag=(RSP,))]) if ctx.stack_size > 0 and x.reg is RSP else None),
# dealloc stack space
# the matching ADDi on rsp is added in front of RET
(UPat(Ops.INS, arg=X86Ops.RET, name="x"), lambda ctx,x: (x, [UOp(Ops.INS, arg=X86Ops.ADDi, dtype=dtypes.uint64,
src=(imm(dtypes.uint32, ctx.stack_size),), tag=(RSP,)), x]) if ctx.stack_size > 0 else None),
# rewrite FRAME_INDEX to CONST now that the stack size is known
(UPat(Ops.INS, arg=X86Ops.FRAME_INDEX, name="x"), lambda ctx,x: (nx:=UOp.const(x.dtype, ctx.stack_size + x.tag), [nx])),
# rewrite RANGE to ACC = 0 -> LABEL -> JUMP if ACC >= loop bound
(UPat(Ops.RANGE, name="x"), lambda ctx,x: lower_range(ctx, x)),
# rewrite END to ACC + 1 -> JUMP -> LABEL, also add the out of loop JUMP to the src so this becomes the jump target
(UPat(Ops.END, name="x"), lambda ctx,x: (jmp:=UOp(Ops.INS, arg=X86Ops.JMP, tag=f".LOOP_{ctx.loop_label[x.src[1]]}"),
[x.src[1].ins(X86Ops.ADDi, src=(imm(x.src[1].dtype, 1),)), jmp, UOp(Ops.INS, arg=X86Ops.LABEL, tag=f".LOOP_OUT_{ctx.loop_label[x.src[1]]}")])),
# rewrite two address instructions to two address form, if reused src wasn't coalesced insert a move
# the first src is dropped as it shares the destination register, the move copies it there when the registers differ
(UPat(Ops.INS, name="x"), lambda ctx,x: (nx:=x.replace(src=x.src[1:]),
[ctx.ren.copy(x.src[0], x.reg), nx] if x.reg != x.src[0].reg else [nx]) if x.arg in X86GroupOp.TwoAddress else None),
])
# ***** X86 spec *****
# TODO: do we even want this?
# every rule returns whether the matched uop is allowed in the final program
isa_spec = PatternMatcher([
# these are the only non X86Ops allowed
(UPat(Ops.CONST), lambda: True),
(UPat((Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER, Ops.SINK)), lambda: True),
# an instruction is valid if its opcode is a known X86Op
(UPat(Ops.INS, name="x"), lambda x: x.arg in X86GroupOp.All),
])
# ***** X86 instruction encoding *****
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
  """Encode the instruction uop `x` into its machine code bytes.

  opc -- opcode of the instruction, emitted big endian using as many bytes as its value needs
  reg -- opcode extension stored in the ModRM reg field, None when the reg field holds a real register
  pp, sel, we -- fields of the VEX prefix (pp, map select, W), sel == 0 selects the legacy (PREFIX/REX) encoding

  Returns None when x.arg is in none of the WriteMem, Rm1st or Rm2nd groups.
  """
  def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
    # reg_uop/rm_uop/idx_uop/vvvv_uop supply the registers of the matching fields, disp_uop being present means rm is a memory operand
    nonlocal reg, opc
    # get the encoding values of the different fields
    reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
    rm = cast(Register, rm_uop.reg).index
    # index 4 (rsp) encodes "no index register"
    idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
    rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
    # reg_sz is 0 when the reg field is an opcode extension rather than a register
    reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
    sz = reg_sz or rm_sz
    # encode instruction
    inst = bytes([])
    assert 0 <= reg <= 15 and 0 <= idx <= 15 and 0 <= rm <= 15
    # r extends reg field, x extends index field, b extends rm or base field
    r, _x, b = reg >> 3, idx >> 3, rm >> 3
    if sel: # VEX bytes
      vvvv = cast(Register, vvvv_uop.reg).index if vvvv_uop is not None else 0
      # l is set when an operand is wider than 16 bytes
      l = (max(reg_sz, rm_sz) > 16) & 0b1
      # the 2 byte VEX form can only be used with map select 1 and when the x, b and w bits are all clear
      if sel == 1 and _x == b == we == 0: inst += bytes([0xC5, (~r & 0b1) << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp])
      else: inst += bytes([0xC4, (~r & 0b1) << 7 | (~_x & 0b1) << 6 | (~b & 0b1) << 5 | sel, we << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp])
    else: # optional PREFIX and REX bytes
      # PREFIX byte signaling 16 bit variant of instruction
      if sz == 2: inst += bytes([0x66])
      # bit signaling 64 bit variant of instruction
      w = sz == 8
      # REX byte is required when 64 bit or an extended reg is used (index 8 - 15) or lower 8 bits of (rsp, rbp, rsi, rdi) are accessed
      # NOTE: this is spelled with explicit boolean tests on purpose, `reg_sz == 1 & reg >> 2` parses as `reg_sz == (1 & (reg >> 2))`
      # which is also true for an opcode extension (reg_sz == 0) with bit 2 clear and would emit a REX byte nobody asked for
      low_byte_reg = (reg_sz == 1 and reg >= 4) or (rm_sz == 1 and rm >= 4)
      if w or r or _x or b or low_byte_reg: inst += bytes([0b0100 << 4 | w << 3 | r << 2 | _x << 1 | b])
      # legacy 8bit opcode is 1 less than 16-64bit variants
      if (rm_sz == 1 or reg_sz == 1) and x.arg not in X86GroupOp.ReadFlags | {X86Ops.LEA}: opc -= 1
    # OPCODE byte
    inst += opc.to_bytes((opc.bit_length() + 7) // 8, 'big')
    # MODRM byte
    # now we only care about the lower 3 bits
    idx, rm, reg = idx & 0b111, rm & 0b111, reg & 0b111
    # 0b00 -- signals memory access with no displacement
    # 0b01 -- signals memory access with 8bit displacement
    # 0b10 -- signals memory access with 32bit displacement
    # 0b11 -- signals no memory access
    if disp_uop is not None:
      assert disp_uop.dtype in (dtypes.int8, dtypes.int32), "displacement can only be 1 or 4 byte signed int"
      # rbp/r13 always require a displacement
      if disp_uop.arg != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
      else: mod = 0b00
    else: mod = 0b11
    # x 0b0 and idx 0b100 means rsp which means no index exists
    # rm 0b100 (rsp/r12) signals a sib byte is required, rm then is encoded in the base field of SIB
    _rm = rm if idx == 0b100 and _x == 0b0 else 0b100
    inst += bytes([mod << 6 | reg << 3 | _rm])
    # SIB byte
    if _rm == 0b100 and mod != 0b11:
      # without an index register the scale is irrelevant and encoded as 1, otherwise the index is scaled by the element size
      scale = {1: 0b00, 2: 0b01, 4: 0b10, 8: 0b11}[1 if idx == 0b100 and _x == 0b0 else rm_sz]
      inst += bytes([scale << 6 | idx << 3 | rm])
    # DISP byte
    if mod == 0b01 or mod == 0b10:
      assert disp_uop is not None
      inst += struct.pack(unwrap(disp_uop.dtype.fmt), disp_uop.arg)
    # IMM byte
    if imm_uop is not None:
      if imm_uop.op is Ops.CONST: inst += struct.pack(unwrap(imm_uop.dtype.fmt), imm_uop.arg)
      # a register passed as immediate is encoded in the upper 4 bits of the immediate byte
      elif isinstance(imm_uop.reg, Register): inst += bytes([(imm_uop.reg.index & 0b1111) << 4 | 0b0000])
    return inst
  # get the encoding structure of the uop
  # when a uop writes to memory it takes the form of a store, dtype is void, no definition
  address:tuple[UOp|None, ...]
  if x.arg in X86GroupOp.WriteMem:
    # more than 3 srcs means the first 3 are the memory address (base, index, displacement), otherwise x itself is the rm operand
    if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
    else: address, rest = (x, None, None), x.src
    return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
  if x.arg in X86GroupOp.Rm1st:
    if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
    else: address, rest = (x.src[0], None, None), x.src[1:]
    # a trailing src is only encoded as immediate if it is a constant or lives in a register
    imm_uop = rest[:1] if rest and (rest[0].op is Ops.CONST or isinstance(rest[0].reg, Register)) else (None,)
    return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
  if x.arg in X86GroupOp.Rm2nd:
    if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
    else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
    # cmp/vucomiss reg, rm don't define a new register
    return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
  return None
# https://www.felixcloutier.com/x86/
# NOTE: LEGACY prefix == VEX prefix
# pp field: None == 0, 66 == 1, F3 == 2, F2 == 3
# map select: 0F == 1, 0F38 == 2, 0F3A == 3
# maps every encodable X86Ops member to a function that takes the instruction uop and returns its bytes
# most entries defer to encode(x, opcode, ...), which may return None when the uop doesn't fit a known operand layout
# reg= is the opcode extension (the "/digit" of the reference), pp/sel/we select the VEX fields described above
encodings = {
  # moves
  # MOVABS is built by hand: REX.W (plus the extension bit of the register), opcode 0xB8 + low 3 bits of the register, then the immediate
  X86Ops.MOVABS: lambda x:
    bytes([0b0100 << 4 | 0b1 << 3 | 0b00 << 2 | x.tag[0].index >> 3, 0xB8 + (x.tag[0].index & 0b111)]) + struct.pack(x.dtype.fmt, x.src[0].arg),
  X86Ops.MOV: lambda x: encode(x, 0x8B), X86Ops.MOVi: lambda x: encode(x, 0xC7, reg=0),
  X86Ops.MOVm: lambda x: encode(x, 0x89), X86Ops.LEA: lambda x: encode(x, 0x8D),
  X86Ops.VMOVSS: lambda x: encode(x, 0x10, pp=2, sel=1), X86Ops.VMOVSSm: lambda x: encode(x, 0x11, pp=2, sel=1),
  X86Ops.VMOVSD: lambda x: encode(x, 0x10, pp=3, sel=1), X86Ops.VMOVSDm: lambda x: encode(x, 0x11, pp=3, sel=1),
  X86Ops.VMOVUPS: lambda x: encode(x, 0x10, pp=0, sel=1), X86Ops.VMOVUPSm: lambda x: encode(x, 0x11, pp=0, sel=1),
  X86Ops.VMOVD: lambda x: encode(x, 0x6E, pp=1, sel=1), X86Ops.VMOVQ: lambda x: encode(x, 0x6E, pp=1, sel=1, we=1),
  X86Ops.VMOVDm: lambda x: encode(x, 0x7E, pp=1, sel=1), X86Ops.VMOVQm: lambda x: encode(x, 0x7E, pp=1, sel=1, we=1),
  # casts
  X86Ops.MOVZX: lambda x: encode(x, 0x0FB7),
  X86Ops.MOVSX: lambda x: encode(x, 0x0FBF), X86Ops.MOVSXD: lambda x: encode(x, 0x63),
  X86Ops.VPMOVZXBW: lambda x: encode(x, 0x30, pp=1, sel=2), X86Ops.VPMOVZXBD: lambda x: encode(x, 0x31, pp=1, sel=2),
  X86Ops.VPMOVZXBQ: lambda x: encode(x, 0x32, pp=1, sel=2), X86Ops.VPMOVZXWD: lambda x: encode(x, 0x33, pp=1, sel=2),
  X86Ops.VPMOVZXWQ: lambda x: encode(x, 0x34, pp=1, sel=2), X86Ops.VPMOVZXDQ: lambda x: encode(x, 0x35, pp=1, sel=2),
  X86Ops.VPMOVSXBW: lambda x: encode(x, 0x20, pp=1, sel=2), X86Ops.VPMOVSXBD: lambda x: encode(x, 0x21, pp=1, sel=2),
  X86Ops.VPMOVSXBQ: lambda x: encode(x, 0x22, pp=1, sel=2), X86Ops.VPMOVSXWD: lambda x: encode(x, 0x23, pp=1, sel=2),
  X86Ops.VPMOVSXWQ: lambda x: encode(x, 0x24, pp=1, sel=2), X86Ops.VPMOVSXDQ: lambda x: encode(x, 0x25, pp=1, sel=2),
  X86Ops.VCVTSS2SD: lambda x: encode(x, 0x5A, pp=2, sel=1), X86Ops.VCVTSD2SS: lambda x: encode(x, 0x5A, pp=3, sel=1),
  X86Ops.VCVTPH2PS: lambda x: encode(x, 0x13, pp=1, sel=2), X86Ops.VCVTPS2PH: lambda x: encode(x, 0x1D, pp=1, sel=3),
  X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
  X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
  X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
  # the W bit of the int <-> float conversions follows the size of the integer operand (8 bytes sets it)
  X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
  X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
  X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
  X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
  # int division
  X86Ops.IDIV: lambda x: encode(x, 0xF7, reg=7), X86Ops.DIV: lambda x: encode(x, 0xF7, reg=6),
  # scalar int binary
  X86Ops.SHLi: lambda x: encode(x, 0xC1, reg=4),
  X86Ops.SHRi: lambda x: encode(x, 0xC1, reg=5), X86Ops.SARi: lambda x: encode(x, 0xC1, reg=7),
  X86Ops.ADD: lambda x: encode(x, 0x03), X86Ops.ADDi: lambda x: encode(x, 0x81, reg=0),
  X86Ops.SUB: lambda x: encode(x, 0x2B), X86Ops.SUBi: lambda x: encode(x, 0x81, reg=5),
  X86Ops.AND: lambda x: encode(x, 0x23), X86Ops.ANDi: lambda x: encode(x, 0x81, reg=4),
  X86Ops.XOR: lambda x: encode(x, 0x33), X86Ops.XORi: lambda x: encode(x, 0x81, reg=6),
  X86Ops.OR: lambda x: encode(x, 0x0B), X86Ops.ORi: lambda x: encode(x, 0x81, reg=1),
  X86Ops.CMP: lambda x: encode(x, 0x3B), X86Ops.CMPi: lambda x: encode(x, 0x81, reg=7),
  X86Ops.IMUL: lambda x: encode(x, 0x0FAF), X86Ops.IMULi: lambda x: encode(x, 0x69),
  X86Ops.SETB: lambda x: encode(x, 0x0F92, reg=0), X86Ops.SETL: lambda x: encode(x, 0x0F9C, reg=0),
  X86Ops.SETE: lambda x: encode(x, 0x0F94, reg=0), X86Ops.SETNE: lambda x: encode(x, 0x0F95, reg=0),
  # packed bitwise NOTE: only bitwise and packed
  X86Ops.VPAND: lambda x: encode(x, 0xDB, pp=1, sel=1), X86Ops.VPXOR: lambda x: encode(x, 0xEF, pp=1, sel=1),
  X86Ops.VPOR: lambda x: encode(x, 0xEB, pp=1, sel=1),
  # unary
  X86Ops.VSQRTSS: lambda x: encode(x, 0x51, pp=2, sel=1), X86Ops.VSQRTPS: lambda x: encode(x, 0x51, pp=0, sel=1),
  X86Ops.VSQRTSD: lambda x: encode(x, 0x51, pp=3, sel=1), X86Ops.VSQRTPD: lambda x: encode(x, 0x51, pp=1, sel=1),
  X86Ops.VROUNDSS: lambda x: encode(x, 0x0A, pp=1, sel=3), X86Ops.VROUNDPS: lambda x: encode(x, 0x08, pp=1, sel=3),
  X86Ops.VROUNDSD: lambda x: encode(x, 0x0B, pp=1, sel=3), X86Ops.VROUNDPD: lambda x: encode(x, 0x09, pp=1, sel=3),
  # packed int binary
  X86Ops.VPSLLVD: lambda x: encode(x, 0x47, pp=1, sel=2), X86Ops.VPSLLVQ: lambda x: encode(x, 0x47, pp=1, sel=2, we=1),
  X86Ops.VPSRLVD: lambda x: encode(x, 0x45, pp=1, sel=2), X86Ops.VPSRLVQ: lambda x: encode(x, 0x45, pp=1, sel=2, we=1),
  X86Ops.VPCMPGTB: lambda x: encode(x, 0x64, pp=1, sel=1), X86Ops.VPCMPGTW: lambda x: encode(x, 0x65, pp=1, sel=1),
  X86Ops.VPCMPGTD: lambda x: encode(x, 0x66, pp=1, sel=1), X86Ops.VPCMPGTQ: lambda x: encode(x, 0x37, pp=1, sel=2),
  X86Ops.VPCMPEQB: lambda x: encode(x, 0x74, pp=1, sel=1), X86Ops.VPCMPEQW: lambda x: encode(x, 0x75, pp=1, sel=1),
  X86Ops.VPCMPEQD: lambda x: encode(x, 0x76, pp=1, sel=1), X86Ops.VPCMPEQQ: lambda x: encode(x, 0x29, pp=1, sel=2),
  X86Ops.VPMULLW: lambda x: encode(x, 0xD5, pp=1, sel=1), X86Ops.VPMULLD: lambda x: encode(x, 0x40, pp=1, sel=2),
  X86Ops.VPADDB: lambda x: encode(x, 0xFC, pp=1, sel=1), X86Ops.VPADDW: lambda x: encode(x, 0xFD, pp=1, sel=1),
  X86Ops.VPADDD: lambda x: encode(x, 0xFE, pp=1, sel=1), X86Ops.VPADDQ: lambda x: encode(x, 0xD4, pp=1, sel=1),
  X86Ops.VPSUBB: lambda x: encode(x, 0xF8, pp=1, sel=1), X86Ops.VPSUBW: lambda x: encode(x, 0xF9, pp=1, sel=1),
  X86Ops.VPSUBD: lambda x: encode(x, 0xFA, pp=1, sel=1), X86Ops.VPSUBQ: lambda x: encode(x, 0xFB, pp=1, sel=1),
  X86Ops.VPSRAVD: lambda x: encode(x, 0x46, pp=1, sel=2),
  # float cmp
  X86Ops.VUCOMISS: lambda x: encode(x, 0x2E, pp=0, sel=1), X86Ops.VUCOMISD: lambda x: encode(x, 0x2E, pp=1, sel=1),
  # scalar / packed float binary
  X86Ops.VADDSS: lambda x: encode(x, 0x58, pp=2, sel=1), X86Ops.VADDPS: lambda x: encode(x, 0x58, pp=0, sel=1),
  X86Ops.VADDSD: lambda x: encode(x, 0x58, pp=3, sel=1), X86Ops.VADDPD: lambda x: encode(x, 0x58, pp=1, sel=1),
  X86Ops.VSUBSS: lambda x: encode(x, 0x5C, pp=2, sel=1), X86Ops.VSUBPS: lambda x: encode(x, 0x5C, pp=0, sel=1),
  X86Ops.VSUBSD: lambda x: encode(x, 0x5C, pp=3, sel=1), X86Ops.VSUBPD: lambda x: encode(x, 0x5C, pp=1, sel=1),
  X86Ops.VMULSS: lambda x: encode(x, 0x59, pp=2, sel=1), X86Ops.VMULPS: lambda x: encode(x, 0x59, pp=0, sel=1),
  X86Ops.VMULSD: lambda x: encode(x, 0x59, pp=3, sel=1), X86Ops.VMULPD: lambda x: encode(x, 0x59, pp=1, sel=1),
  X86Ops.VDIVSS: lambda x: encode(x, 0x5E, pp=2, sel=1), X86Ops.VDIVPS: lambda x: encode(x, 0x5E, pp=0, sel=1),
  X86Ops.VDIVSD: lambda x: encode(x, 0x5E, pp=3, sel=1), X86Ops.VDIVPD: lambda x: encode(x, 0x5E, pp=1, sel=1),
  X86Ops.VCMPSS: lambda x: encode(x, 0xC2, pp=2, sel=1), X86Ops.VCMPPS: lambda x: encode(x, 0xC2, pp=0, sel=1),
  X86Ops.VCMPSD: lambda x: encode(x, 0xC2, pp=3, sel=1), X86Ops.VCMPPD: lambda x: encode(x, 0xC2, pp=1, sel=1),
  X86Ops.VMAXSS: lambda x: encode(x, 0x5F, pp=2, sel=1), X86Ops.VMAXPS: lambda x: encode(x, 0x5F, pp=0, sel=1),
  X86Ops.VMAXSD: lambda x: encode(x, 0x5F, pp=3, sel=1), X86Ops.VMAXPD: lambda x: encode(x, 0x5F, pp=1, sel=1),
  X86Ops.VMINSS: lambda x: encode(x, 0x5D, pp=2, sel=1), X86Ops.VMINPS: lambda x: encode(x, 0x5D, pp=0, sel=1),
  X86Ops.VMINSD: lambda x: encode(x, 0x5D, pp=3, sel=1), X86Ops.VMINPD: lambda x: encode(x, 0x5D, pp=1, sel=1),
  # ternary
  X86Ops.CMOVB: lambda x: encode(x, 0x0F42), X86Ops.CMOVL: lambda x: encode(x, 0x0F4C),
  X86Ops.CMOVE: lambda x: encode(x, 0x0F44), X86Ops.CMOVNE: lambda x: encode(x, 0x0F45),
  X86Ops.VFMADD213SS: lambda x: encode(x, 0xA9, pp=1, sel=2), X86Ops.VFMADD213SD: lambda x: encode(x, 0xA9, pp=1, sel=2, we=1),
  X86Ops.VFMADD213PS: lambda x: encode(x, 0xA8, pp=1, sel=2), X86Ops.VFMADD213PD: lambda x: encode(x, 0xA8, pp=1, sel=2, we=1),
  X86Ops.VBLENDVPS: lambda x: encode(x, 0x4A, pp=1, sel=3), X86Ops.VBLENDVPD: lambda x: encode(x, 0x4B, pp=1, sel=3),
  X86Ops.VPBLENDVB: lambda x: encode(x, 0x4C, pp=1, sel=3),
  # shuffles
  X86Ops.VPBROADCASTB: lambda x: encode(x, 0x78, pp=1, sel=2), X86Ops.VPBROADCASTW: lambda x: encode(x, 0x79, pp=1, sel=2),
  X86Ops.VPBROADCASTD: lambda x: encode(x, 0x58, pp=1, sel=2), X86Ops.VPBROADCASTQ: lambda x: encode(x, 0x59, pp=1, sel=2),
  X86Ops.VBROADCASTSS: lambda x: encode(x, 0x18, pp=1, sel=2), X86Ops.VPSRLDQ: lambda x: encode(x, 0x73, reg=3, pp=1, sel=1),
  X86Ops.VPINSRB: lambda x: encode(x, 0x20, pp=1, sel=3), X86Ops.VPINSRW: lambda x: encode(x, 0xC4, pp=1, sel=1),
  X86Ops.VPINSRD: lambda x: encode(x, 0x22, pp=1, sel=3), X86Ops.VPINSRQ: lambda x: encode(x, 0x22, pp=1, sel=3, we=1),
  X86Ops.VSHUFPS: lambda x: encode(x, 0xC6, pp=0, sel=1), X86Ops.VSHUFPD: lambda x: encode(x, 0xC6, pp=1, sel=1),
  X86Ops.VINSERTPS: lambda x: encode(x, 0x21, pp=1, sel=3),
  # extract
  X86Ops.VPEXTRB: lambda x: encode(x, 0x14, pp=1, sel=3), X86Ops.VPEXTRW: lambda x: encode(x, 0x15, pp=1, sel=3),
  X86Ops.VPEXTRD: lambda x: encode(x, 0x16, pp=1, sel=3), X86Ops.VPEXTRQ: lambda x: encode(x, 0x16, pp=1, sel=3, we=1),
  # jumps are encoded with a placeholder which gets patched later once the real offset is known
  # the placeholder is always the trailing 4 bytes of the instruction
  X86Ops.JE: lambda x: bytes([0x0F, 0x84]) + int(0).to_bytes(4),
  X86Ops.JNE: lambda x: bytes([0x0F, 0x85]) + int(0).to_bytes(4),
  X86Ops.JL: lambda x: bytes([0x0F, 0x8C]) + int(0).to_bytes(4),
  X86Ops.JB: lambda x: bytes([0x0F, 0x82]) + int(0).to_bytes(4),
  X86Ops.JGE: lambda x: bytes([0x0F, 0x8D]) + int(0).to_bytes(4),
  X86Ops.JMP: lambda x: bytes([0xE9]) + int(0).to_bytes(4),
  X86Ops.RET: lambda x: bytes([0xC3]),
}
class X86Renderer(ISARenderer):
  """ISA renderer for x86: wires up the matchers of this file and turns the final uop list into assembly text or encoded bytes."""
  device = "CPU"
  has_local = False
  has_threads = bool(getenv("THREADS", 1))
  global_max = (CPU_COUNT.value, 0, 0)
  extra_matcher = extra_matcher
  pre_isel_matcher = pre_isel_matcher
  isel_matcher = isel_matcher
  pre_regalloc_matcher = pre_regalloc_matcher
  late_regalloc_matcher = late_regalloc_matcher
  post_regalloc_matcher = post_regalloc_matcher
  isa_spec = isa_spec
  # only the keys matter here, the callables are placeholders that are never expected to render anything
  code_for_op = {x: lambda: None for x in (Ops.SQRT, Ops.AND, Ops.OR, Ops.SHL, Ops.SHR, Ops.NEG, Ops.SUB, Ops.FDIV, Ops.CMPLT, Ops.CMPEQ)}

  def __init__(self, target:Target):
    super().__init__(target)
    # imported here rather than at the top of the file
    from tinygrad.runtime.support.compiler_cpu import X86Compiler
    self.compiler = X86Compiler()

  def callee_saved(self):
    """Return one register definition per callee saved register, RSP always comes first."""
    ordered = (RSP,) + tuple(r for r in CALLEE_SAVED if r is not RSP)
    # general purpose registers are defined as uint64, every other register as a 2 wide float64 vector
    return tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in ordered)

  def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress

  # nasty hacks to deal with pointers TODO: rm pointers
  # copy/spill/fill run instruction selection with pointers viewed as uint64 and put the original dtype/uop back afterwards
  def copy(self, x:UOp, reg:Register):
    """Return the instruction copying x into register reg."""
    dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
    ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
    assert ret is not None
    return ret.replace(dtype=x.dtype)

  def spill(self, disp:UOp, x:UOp) -> UOp:
    """Return the instruction storing x to the stack at displacement disp from RSP."""
    nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
    ret = isel_matcher.rewrite(def_reg(dtypes.uint64, RSP).index(disp).store(nx))
    assert ret is not None
    # src has to be a real tuple, it is indexed and sliced by later passes and a bare generator would break them
    return ret.replace(src=tuple(s if s is not nx else x for s in ret.src))

  def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
    """Return the instruction loading x back into register reg from the stack at displacement disp from RSP."""
    ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
    ret = isel_matcher.rewrite(def_reg(dtypes.uint64, RSP).index(disp).load(dtype=ndt, tag=reg))
    assert ret is not None
    return ret.replace(dtype=x.dtype)

  def asm(self, uops:list[UOp], function_name:str) -> str:
    """Format the instructions in uops as assembly text, one instruction per line under a `.function_name:` label."""
    # mnemonic is the lowercased op name without its first 7 characters and without a trailing 'i'/'m' variant suffix
    # NOTE(review): the 7 characters are presumably the "X86Ops." prefix of str(x.arg) -- confirm
    def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
    def _format_operands(x:UOp) -> str:
      # constants print their value, registers print the name matching the operand size when reg_strs knows one
      def _format(src:tuple[UOp, ...]) -> list[str]:
        return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
                (o:=str(s.reg)) in reg_strs else o for s in src if s.op is Ops.CONST or s.reg is not None]
      def _mem_address(base:UOp, idx:UOp, disp:UOp) -> str:
        return f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"
      # the position of the memory operand depends on the group of the instruction
      if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: return ", ".join([_mem_address(*x.src[:3])] + _format(x.src[3:]))
      elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: return ", ".join(_format((x,)) + [_mem_address(*x.src[:3])] + _format(x.src[3:]))
      elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: return ", ".join(_format((x, x.src[0])) + [_mem_address(*x.src[1:4])] + _format(x.src[4:]))
      return ", ".join(_format((x,) + x.src))
    asm = [f".{function_name}:"]
    for u in uops:
      # only instructions are printed, register definitions produce no text
      if u.op is not Ops.INS: continue
      if u.arg is X86Ops.DEFINE_REG: continue
      if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
      elif u.arg is X86Ops.RET: asm.append(_format_op(u))
      else: asm.append(_format_op(u) + " " + _format_operands(u))
    return "\n".join(asm)

  def render(self, uops:list[UOp]) -> str:
    """Encode the instructions in uops and return the machine code as a hex string.

    Raises RuntimeError if an instruction has no encoding or fails to encode.
    """
    targets: dict[str, int] = {}   # label -> offset of the label in the binary
    jumps: dict[UOp, int] = {}     # jump -> offset of the end of the jump instruction
    binary = bytearray()
    for u in uops:
      if u.op is not Ops.INS: continue
      if u.arg is X86Ops.DEFINE_REG: continue
      if u.arg is X86Ops.LABEL:
        targets[u.tag] = len(binary)
        continue
      if u.arg not in encodings or (enc:=encodings[u.arg](u)) is None:
        raise RuntimeError(f"failed to encode {u.arg} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      binary.extend(enc)
      if u.arg in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE, X86Ops.JGE, X86Ops.JMP): jumps[u] = len(binary)
    # fixup jump targets now that encoding size is known
    # the offset is relative to the end of the jump and overwrites the 4 placeholder bytes at its tail
    for u, end in jumps.items(): binary[end-4:end] = (targets[u.tag] - end).to_bytes(4, 'little', signed=True)
    return binary.hex()

View file

@ -8,6 +8,7 @@ from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import CPULLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.renderer.isa.x86 import X86Renderer
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
@ -136,5 +137,5 @@ class CPUDevice(HCQCompiled):
def __init__(self, device:str=""):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
renderers:list[type[Renderer]] = [ClangJITRenderer, CPULLVMRenderer, LVPRenderer]
renderers:list[type[Renderer]] = [ClangJITRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer]
super().__init__(device, CPUAllocator(self), renderers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View file

@ -1,7 +1,6 @@
from __future__ import annotations
import ctypes, functools, os, pathlib, re, sys, sysconfig
from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN
from _ctypes import Array as _CArray, _SimpleCData, _Pointer
from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar
def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs):
@ -34,22 +33,22 @@ if TYPE_CHECKING:
from _ctypes import _CData
class Array(Generic[T, U], _CData):
@overload
def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ...
def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ...
@overload
def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ...
@overload
def __getitem__(self: Array[T, Any], key: int) -> T: ...
def __getitem__(self, key) -> Any: ...
@overload
def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ...
def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ...
@overload
def __setitem__(self: Array[T, Any], key: int, val: T): ...
@overload
def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ...
def __setitem__(self, key, val): ...
class POINTER(Generic[T], _Pointer): ...
class POINTER(Generic[T], ctypes._Pointer): ...
class CFUNCTYPE(Generic[T, P], _CFunctionType): ...
class Enum(_SimpleCData):
class Enum(ctypes._SimpleCData):
@classmethod
def get(cls, val:int, default="unknown") -> str: ...
@classmethod
@ -80,14 +79,9 @@ else:
return val
def pointer(obj): return ctypes.pointer(obj)
def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder)
def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder)
def mv(st) -> memoryview: return memoryview(st).cast('B')
class Struct(ctypes.Structure):
def __init__(self, *args, **kwargs):
ctypes.Structure.__init__(self)
self._objects_ = {}
for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v)
def record(cls) -> type[Struct]:
@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
def init_records() -> None:
for cls, struct, ns in _pending_records:
setattr(struct, '_real_fields_', [])
for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items():
if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__)))
else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__))))
struct._real_fields_.append((nm,) + f) # type: ignore
for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()):
struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore
setattr(struct, nm, Field(nm, i, *f))
_pending_records.clear()
class Field(property):
def __init__(self, typ, off:int, bit_width=None, bit_off=0):
if bit_width is not None:
sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off)
class Field:
def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0):
self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off
# lazily resolve field descriptors
def _resolve(self, cls):
if self.bit_width: # handle bitfields ourselves
sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off)
def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder)
def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder)
# FIXME: signedness
super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask,
lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz)))
else:
sl = slice(off, off + ctypes.sizeof(typ))
def set_with_objs(f):
def wrapper(self, v):
if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})}
mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v)))
return wrapper
if issubclass(typ, _CArray):
getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl]))
super().__init__(getter, set_with_objs(lambda v: typ(*v)))
else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ))
self.offset = off
cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset)
# pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] +
[("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore
setattr(cls, self.nm, cf)
return cf
def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self
def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value)
@functools.cache
def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []})
for nm,ty,*args in fields:
setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args))))
CStruct._real_fields_.append((nm,) + f) # type: ignore
for i,(nm,ty,*args) in enumerate(fields):
CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore
setattr(CStruct, nm, Field(nm, i, *f))
return CStruct
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]

View file

@ -91,3 +91,8 @@ class CPULLVMCompiler(LLVMCompiler):
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
super().__init__(cpu.decode(), feats.decode(), cache_key)
class X86Compiler(Compiler):
  """Compiler whose source is already machine code, written out as a hex string."""
  def __init__(self):
    # NOTE(review): None is presumably the cache key of the base Compiler -- confirm
    super().__init__(None)

  def compile(self, src:str) -> bytes:
    # nothing to compile, decoding the hex text gives the binary
    return bytes.fromhex(src)

  def disassemble(self, lib:bytes):
    return capstone_flatdump(lib)

View file

@ -473,6 +473,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def ins(self, arg, **kwargs):
  # build an Ops.INS uop with the given arg, dtype/src/tag are taken from self unless passed as keyword
  dtype = kwargs.pop("dtype", self.dtype)
  src = kwargs.pop("src", self.src)
  tag = kwargs.pop("tag", self.tag)
  return UOp(Ops.INS, dtype, src, arg, tag)
def contract(self, *rngs:UOp):
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
@ -537,6 +538,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
for s in self.src: yield from s.split_uop(sep)
else: yield self
@property
def reg(self:UOp):
  # TODO: add a way to access the nth element in src, sea of nodes call this a projection
  # NOOP and AFTER with sources report the register of their first source
  forwards = self.op in (Ops.NOOP, Ops.AFTER) and self.src
  if forwards: return self.src[0].reg
  # otherwise the register is the tag, or its first element when the tag is a tuple
  tag = self.tag
  return tag[0] if isinstance(tag, tuple) else tag
# *** multi-device helpers ***
def multi(self, axis:int|None):
@ -1010,10 +1018,13 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
# ***** uop helpers *****
def print_uops(uops:list[UOp]):
def format_tag(u:UOp) -> str: return "" if u.tag is None else str(u.tag)
uops_index = {u:i for i,u in enumerate(uops)}
tag_width = max((len(format_tag(u)) for u in uops), default=0)
for i,u in enumerate(uops):
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} "
f"{str(formatted_srcs):32s} {format_tag(u):{tag_width}s} {u.arg}")
# ***** pattern matcher *****

View file

@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.RANGE: "#76349c", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
@ -104,7 +104,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
if u.op in {Ops.VECTORIZE, Ops.NOOP} and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
for u in toposort:
@ -288,18 +288,22 @@ metrics:dict[str, Callable[[dict[str, tuple[int, int, int]]], str]] = {
def unpack_pmc(e) -> dict:
agg_cols = ["Name", "Sum"]
sample_cols = ["XCC", "INST", "SE", "SA", "WGP", "Value"]
rows:list[list] = []
stats:dict[str, tuple[int, int, int]] = {} # name -> (sum, max, count)
view, ptr = memoryview(e.blob).cast('Q'), 0
for s in e.sched:
sample_cols = ["XCC", "INST", "SE", "SA"] + [f"WGP:{i}" for i in range(s.wgp)]
row:list = [s.name, 0, {"cols":sample_cols, "rows":[]}]
max_val, cnt = 0, 0
for sample in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
row[1] += (val:=int(view[ptr]))
max_val, cnt = max(max_val, val), cnt + 1
row[2]["rows"].append(sample+(val,))
ptr += 1
for sample in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa)):
vals:list[int] = []
# pack work group processors on the same se
for _ in range(s.wgp):
row[1] += (val:=int(view[ptr]))
max_val, cnt = max(max_val, val), cnt + 1
vals.append(val)
ptr += 1
row[2]["rows"].append(sample+tuple(vals))
stats[s.name] = (row[1], max_val, cnt)
rows.append(row)
for name, fn in metrics.items():