mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
158 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1568c92f7d | ||
|
|
5d09363b5f | ||
|
|
c012f9c5a7 | ||
|
|
934c0c5797 | ||
|
|
c559c29d0b | ||
|
|
e528bc389b | ||
|
|
bd6d7e22ce | ||
|
|
fb40e711dd | ||
|
|
62a7b84aba | ||
|
|
a17988a52d | ||
|
|
12f073e137 |
||
|
|
29582199c1 |
||
|
|
14be3279c1 |
||
|
|
2e61817001 | ||
|
|
8868abe830 | ||
|
|
8346332061 | ||
|
|
91bf07e702 |
||
|
|
891807a1b9 | ||
|
|
0b0ea63439 | ||
|
|
53ef0d36ec | ||
|
|
d333ac1242 |
||
|
|
85e6b77c13 |
||
|
|
a2b32b3abf | ||
|
|
1fb940f762 |
||
|
|
cd0152efec |
||
|
|
35fc12b839 | ||
|
|
292e1745b2 | ||
|
|
e81878abd9 | ||
|
|
acdc232d65 | ||
|
|
0dc615b588 | ||
|
|
449c79ada2 | ||
|
|
221eafcd8d | ||
|
|
7115ed0c22 | ||
|
|
1a52341196 | ||
|
|
037c5e6f82 | ||
|
|
aaab4407af | ||
|
|
465be0d333 |
||
|
|
dd5076529b | ||
|
|
41f2bd8a05 | ||
|
|
255a788dea | ||
|
|
b172b5d72c |
||
|
|
64f574572c | ||
|
|
3dc9bbd831 | ||
|
|
cafa3b74d4 | ||
|
|
393e591f49 |
||
|
|
82954c7ca4 | ||
|
|
eb4ad1ebf0 | ||
|
|
ce2c690721 |
||
|
|
a982a8709e | ||
|
|
94c317a437 | ||
|
|
13c4f2fb04 | ||
|
|
ee4455952d |
||
|
|
c247bdd9b9 | ||
|
|
e29915e3ff | ||
|
|
b4152bff20 |
||
|
|
0af6d94422 | ||
|
|
a9b5a368da | ||
|
|
24bff881e8 | ||
|
|
563a31f791 | ||
|
|
510ef99411 | ||
|
|
8b83cc3aeb | ||
|
|
59603f4d93 |
||
|
|
84b361df95 | ||
|
|
82e42ec061 | ||
|
|
dacfa01c0d | ||
|
|
6b7d75683f | ||
|
|
18cf8c57e8 | ||
|
|
1e316e025a | ||
|
|
a9f8c06f84 | ||
|
|
cb3c4b8b47 | ||
|
|
999483490a | ||
|
|
ad3882bf08 | ||
|
|
250b1b2520 | ||
|
|
b5db91bfdf | ||
|
|
3f01b9970a | ||
|
|
9b3b425518 | ||
|
|
dd558ecfae | ||
|
|
1f140d9d53 | ||
|
|
194d498d28 | ||
|
|
d1c28c2692 | ||
|
|
72f341a534 | ||
|
|
b32bafe1ae | ||
|
|
0681e3311c | ||
|
|
9fbf64e339 | ||
|
|
ce31a4fbec | ||
|
|
86b5441781 | ||
|
|
878557004c | ||
|
|
80e68f3706 | ||
|
|
f0565ed5dc | ||
|
|
78171c4f70 | ||
|
|
e1bf9c9e02 | ||
|
|
6ff67781f1 |
||
|
|
fe2b08bee3 | ||
|
|
5c2b0b2363 | ||
|
|
733789e294 | ||
|
|
ef76bfa081 | ||
|
|
fdaad71b6a | ||
|
|
e2d49fa578 | ||
|
|
74e24d53c9 | ||
|
|
c4c69d8276 | ||
|
|
a3d1f8435a | ||
|
|
1d8a277928 | ||
|
|
4d6ed29af3 | ||
|
|
93022ac35a | ||
|
|
6de9da8b9c |
||
|
|
74e3d9faf3 | ||
|
|
bbe012ac86 | ||
|
|
4c3081b613 | ||
|
|
69a27d9a5c | ||
|
|
77a28ac3f2 | ||
|
|
0ae5c5e4f9 | ||
|
|
e9f2e89f8f |
||
|
|
983f7a2155 | ||
|
|
f0234b9da3 | ||
|
|
a198cb54e2 | ||
|
|
b53bcb3319 | ||
|
|
f1327ebff6 | ||
|
|
6f977100ff | ||
|
|
db3ed92ae3 | ||
|
|
3fcde08b20 | ||
|
|
c1b2816d8b | ||
|
|
f8ade82553 | ||
|
|
dd48f6a111 | ||
|
|
1fe4185e89 | ||
|
|
037c824f9d | ||
|
|
609d9385d8 | ||
|
|
5a61a10547 | ||
|
|
ff5f071ba2 | ||
|
|
7864067e34 | ||
|
|
a5e189794a | ||
|
|
7bafe52335 | ||
|
|
c133d3b1d0 | ||
|
|
0daa1d94d0 | ||
|
|
423f7e66ca | ||
|
|
7ab99089fc | ||
|
|
b4f8d64d2b | ||
|
|
f9b2f51554 |
||
|
|
0fe5d75982 | ||
|
|
138e20adcf | ||
|
|
d0d3272df1 | ||
|
|
243f6c85b9 | ||
|
|
c005ab0122 | ||
|
|
f92e2d259a | ||
|
|
bcd8b2b5cc | ||
|
|
f4309a3b1a | ||
|
|
587259976d | ||
|
|
8d4a48fcd3 | ||
|
|
885172f4bc | ||
|
|
1eca96ea44 | ||
|
|
12714337f0 | ||
|
|
32942f12b7 | ||
|
|
8365bc84ee | ||
|
|
54396f5cb3 | ||
|
|
b8f06970fa | ||
|
|
edb592f314 | ||
|
|
678a6b3689 | ||
|
|
51e1292200 | ||
|
|
98f0ba7fb8 |
21 changed files with 1620 additions and 72 deletions
10
.github/workflows/test.yml
vendored
10
.github/workflows/test.yml
vendored
|
|
@ -793,7 +793,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, cpu, opencl, lvp]
|
||||
backend: [llvm, cpu, opencl, lvp, x86]
|
||||
|
||||
name: Linux (${{ matrix.backend }})
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
@ -810,7 +810,7 @@ jobs:
|
|||
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
|
||||
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' }}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'DEV=CL' || matrix.backend == 'lvp' && 'DEV=CPU:LVP' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
|
||||
|
|
@ -960,7 +960,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, cpu, webgpu]
|
||||
backend: [llvm, cpu, webgpu, x86]
|
||||
|
||||
name: Windows (${{ matrix.backend }})
|
||||
runs-on: windows-latest
|
||||
|
|
@ -976,7 +976,7 @@ jobs:
|
|||
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
|
||||
- name: Set env
|
||||
shell: bash
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU'}}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'DEV=CPU:LLVM' || matrix.backend == 'cpu' && 'DEV=CPU\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'DEV=WEBGPU' || matrix.backend == 'x86' && 'DEV=CPU:X86' }}" >> $GITHUB_ENV
|
||||
- name: Run unit tests
|
||||
if: matrix.backend=='llvm'
|
||||
# test_newton_schulz hits RecursionError
|
||||
|
|
@ -988,7 +988,7 @@ jobs:
|
|||
- name: Run pytest (${{ matrix.backend }})
|
||||
shell: bash
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU', 'X86':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
|
||||
python -m pytest -n=auto test/test_tiny.py test/backend/test_ops.py --durations=20
|
||||
|
||||
# ****** Compile-only Tests ******
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ List all codegen steps for a kernel: `--rewrites -s E_3`
|
|||
Get source code: `--rewrites -s E_3 -i "View Source"`
|
||||
Inspect a graph rewrite: `--rewrites -s E_3 -i "initial symbolic"`
|
||||
|
||||
# SQTT tracing
|
||||
## SQTT tracing
|
||||
|
||||
Supported on AMD for RDNA3 and RDNA4 (best) and CDNA (developing).
|
||||
|
||||
|
|
@ -38,8 +38,12 @@ You can select a specific trace with --source, Example workflow:
|
|||
VIZ=-2 python extra/gemm/amd_asm_matmul.py
|
||||
|
||||
# View barriers
|
||||
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | rg BARRIER | head -10
|
||||
extra/viz/cli.py --profile -s "kernel SQTT SE:0 PKTS" | rg BARRIER | head -10
|
||||
|
||||
# Get bank conflicts from performance counters
|
||||
|
||||
python extra/viz/cli.py -p -s "kernel PMC" -i "SQC_LDS_BANK_CONFLICT"
|
||||
|
||||
# Find the EXEC corresponding to a DISPATCH at cycle 410
|
||||
extra/viz/cli.py --profile -s "SQTT kernel PKTS SE:0" | awk '/EXEC/ && $1 - $5 == 410'
|
||||
extra/viz/cli.py --profile -s "kernel SQTT SE:0 PKTS" | awk '/EXEC/ && $1 - $5 == 410'
|
||||
```
|
||||
|
|
|
|||
|
|
@ -47,7 +47,9 @@ def decode_profile(data:bytes) -> dict:
|
|||
def get(data:dict, key:str):
|
||||
for k,v in data.items():
|
||||
if ansistrip(k) == key: return v
|
||||
raise RuntimeError(f'item "{key}" not found in list')
|
||||
import difflib
|
||||
match = difflib.get_close_matches(key, [ansistrip(k) for k in data], n=1, cutoff=0.6)
|
||||
raise RuntimeError(f'item "{key}" not found in list'+(f", did you mean {match[0]!r}?" if match else ''))
|
||||
|
||||
def main(args) -> None:
|
||||
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
|
||||
|
|
@ -59,8 +61,8 @@ def main(args) -> None:
|
|||
events:list = viz.load_pickle(args.profile_path, default=[])
|
||||
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"]} {s["name"]}', s["data"]) for c in viz.ctxs if c["name"].startswith("SQTT") for s in c["steps"]
|
||||
if "PKTS" in s["name"]])
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None:
|
||||
for k in profile["layout"]:
|
||||
print(f" {format_colored(k)}")
|
||||
|
|
@ -99,6 +101,20 @@ def main(args) -> None:
|
|||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
|
||||
return None
|
||||
|
||||
# ** PMC printer
|
||||
if "PMC" in args.src:
|
||||
table = viz.unpack_pmc(data[0])
|
||||
cols = table["cols"]
|
||||
rows:list = []
|
||||
for r in table["rows"]:
|
||||
if args.item is None: rows.append(r[:2])
|
||||
elif args.item == r[0]:
|
||||
rows = r[2]["rows"] if len(r) > 2 else [r[:2]]
|
||||
cols = r[2]["cols"] if len(r) > 2 else cols
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=cols, tablefmt="github"))
|
||||
return None
|
||||
|
||||
# ** Profiler printer
|
||||
agg:dict[str, tuple[float, int]] = {}
|
||||
total = 0
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ class TestArange(unittest.TestCase):
|
|||
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
|
||||
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
|
||||
def test_arange_cumsum(self):
|
||||
np.testing.assert_equal(Tensor.arange(513).cumsum(0).numpy(), np.arange(513).cumsum())
|
||||
|
||||
def test_arange_cat(self):
|
||||
t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
|
||||
self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])
|
||||
|
|
|
|||
150
test/backend/test_encodings.py
Normal file
150
test/backend/test_encodings.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import unittest
|
||||
from typing import cast
|
||||
from tinygrad import Device
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
|
||||
|
||||
def ins(op, dt, src, tag=None): return UOp(Ops.INS, arg=op, dtype=dt, src=src, tag=tag)
|
||||
|
||||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only on x86")
|
||||
class TestEncodingsX86(unittest.TestCase):
|
||||
# NOTE: x86 supports a single displacement as memory address and index without base memory address
|
||||
# these have no use cases so they aren't supported
|
||||
def encode(self, u:UOp): return cast(X86Renderer, Device[Device.DEFAULT].renderer).render([u])
|
||||
|
||||
# displacement of 0 isn't emitted
|
||||
def test_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
|
||||
# mov edi, dword ptr [rdi]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F"))
|
||||
|
||||
# rsp/r12 require a sib byte when used as base memory address
|
||||
def test_rsp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
|
||||
# mov esp, dword ptr [rsp]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24"))
|
||||
|
||||
# rbp/r13 require a displacement when used as base memory address
|
||||
def test_rbp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
|
||||
# mov ebp, dword ptr [rbp + 0]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00"))
|
||||
|
||||
# test [base + index*scale]
|
||||
def test_base_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX)
|
||||
# mov eax, dword ptr [rax + rdx*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90"))
|
||||
|
||||
# rsp as index means no index
|
||||
def test_rsp_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX)
|
||||
# mov eax, dword ptr [rax]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00"))
|
||||
|
||||
# however r12 is a valid index
|
||||
def test_r12_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX)
|
||||
# mov eax, dword ptr [rax + r12*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0"))
|
||||
|
||||
# test [base + index*scale + 8bit disp]
|
||||
def test_complex_address_8bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A"))
|
||||
|
||||
# test [base + index*scale + 32bit disp]
|
||||
def test_complex_address_32bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0x2710]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00"))
|
||||
|
||||
# 8bit variants of legacy instructions subtract 1 from opcode
|
||||
def test_8bit_legacy_encoding(self):
|
||||
cast = ins(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDX),), RAX)
|
||||
# movsx eax, dl
|
||||
self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("0F BE C2"))
|
||||
|
||||
# accessing lower 8 bits of rsp, rbp, rsi, rdi requires rex prefix
|
||||
def test_lower_8bits_reg(self):
|
||||
cast = ins(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDI),), RAX)
|
||||
# movsx eax, dil
|
||||
self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("40 0F BE C7"))
|
||||
|
||||
# test 16 bit variant of legacy instruction
|
||||
def test_16bit_legacy_encoding(self):
|
||||
cast = ins(X86Ops.MOVSX, dtypes.int16, (def_reg(dtypes.int8, RDX),), RAX)
|
||||
# movsx ax, dl
|
||||
self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("66 0F BE C2"))
|
||||
|
||||
# test 64 bit variant of legacy instruction
|
||||
def test_64bit_legacy_encoding(self):
|
||||
cast = ins(X86Ops.MOVSX, dtypes.int64, (def_reg(dtypes.int8, RDX),), RAX)
|
||||
# movsx rax, dl
|
||||
self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("48 0F BE C2"))
|
||||
|
||||
# test compact vex encoding
|
||||
def test_compact_vex_encoding(self):
|
||||
xmm0, xmm1 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1])
|
||||
add = ins(X86Ops.VADDSS, dtypes.float32, (xmm0, xmm1), XMM[0])
|
||||
# vaddss xmm0, xmm0, xmm1
|
||||
self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C5 FA 58 C1"))
|
||||
|
||||
# test long vex encoding
|
||||
def test_long_vex_encoding(self):
|
||||
xmm0, xmm8 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[8])
|
||||
add = ins(X86Ops.VADDSS, dtypes.float32, (xmm0, xmm8), XMM[0])
|
||||
# vaddss xmm0, xmm0, xmm8
|
||||
self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C4 C1 7A 58 C0"))
|
||||
|
||||
# test ymm encoding
|
||||
def test_ymm_encoding(self):
|
||||
xmm0, xmm1 = def_reg(dtypes.float32.vec(8), XMM[0]), def_reg(dtypes.float32.vec(8), XMM[1])
|
||||
add = ins(X86Ops.VADDPS, dtypes.float32.vec(8), (xmm0, xmm1), XMM[0])
|
||||
# vaddps ymm0, ymm0, ymm1
|
||||
self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C5 FC 58 C1"))
|
||||
|
||||
# test encoding where register is in the immediate field
|
||||
def test_reg_in_imm_field(self):
|
||||
xmm0, xmm1, xmm2 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1]), def_reg(dtypes.float32, XMM[2])
|
||||
blend = ins(X86Ops.VBLENDVPS, dtypes.float32, (xmm0, xmm1, xmm2), XMM[0])
|
||||
# vblendvps xmm0, xmm0, xmm1, xmm2
|
||||
self.assertEqual(bytes.fromhex(self.encode(blend)), bytes.fromhex("C4 E3 79 4A C1 20"))
|
||||
|
||||
# when writting to mem the uop takes the store form where dtype is void and there's no definition
|
||||
def test_write_mem(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
xmm0 = def_reg(dtypes.float32, XMM[0])
|
||||
extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
|
||||
# vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
|
||||
self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00"))
|
||||
|
||||
# test two address instruction with fused load works
|
||||
def test_two_address_load(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
|
||||
# cmove eax, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A"))
|
||||
|
||||
# test instruction where displacement and imm have the same value
|
||||
def test_disp_imm_same_value(self):
|
||||
base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
|
||||
mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
|
||||
# mov byte ptr [rdi + rsi + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A"))
|
||||
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
|
||||
imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI)
|
||||
# imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00"))
|
||||
|
||||
# cmoves have the cmp as the last src even though it is not explicitly used, the cmp doesn't define a reg and is ignored in the encoding
|
||||
def test_cmove_ignore_cmp(self):
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, (def_reg(dtypes.int32, RAX), UOp(Ops.INS, arg=X86Ops.CMP)), RDX)
|
||||
# cmove edx, eax
|
||||
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 D0"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
152
test/backend/test_isel.py
Normal file
152
test/backend/test_isel.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
import unittest
|
||||
from typing import cast
|
||||
from tinygrad import Device
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
|
||||
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
|
||||
from tinygrad.renderer.isa import IselContext
|
||||
|
||||
# these tests are to catch changes that don't cause incorrect codegen but cause worse codegen
|
||||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
|
||||
class TestIselX86(unittest.TestCase):
|
||||
def isel_rewrite(self, x:UOp):
|
||||
return graph_rewrite(x, cast(X86Renderer, Device[Device.DEFAULT].renderer).isel_matcher, IselContext(x), bottom_up=True)
|
||||
|
||||
def _check_op(self, dt_op, expr):
|
||||
nargs = expr.__code__.co_argcount
|
||||
for dt,op in dt_op:
|
||||
with self.subTest(dtype=dt):
|
||||
v = [UOp.variable(str(i), 0, 0, dt) for i in range(nargs)]
|
||||
n = self.isel_rewrite(expr(*v))
|
||||
self.assertIs(n.arg, op)
|
||||
|
||||
def test_cmove(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
b = UOp.variable("b", 0, 0, dtypes.int32)
|
||||
c = (a < b).where(a, b)
|
||||
d = (a != b).where(a, b)
|
||||
f = c + d
|
||||
n = self.isel_rewrite(f)
|
||||
self.assertTrue(n.src[0].arg is X86Ops.CMOVL and n.src[1].arg is X86Ops.CMOVNE)
|
||||
# both comparisons become the same instruction
|
||||
self.assertTrue(n.src[0].src[2] == n.src[1].src[2] and n.src[0].src[2].arg is X86Ops.CMP)
|
||||
|
||||
def test_vmax(self):
|
||||
dt_op = [(dtypes.float32, X86Ops.VMAXSS), (dtypes.float64, X86Ops.VMAXSD),
|
||||
(dtypes.float32.vec(4), X86Ops.VMAXPS), (dtypes.float64.vec(4), X86Ops.VMAXPD)]
|
||||
self._check_op(dt_op, lambda a,b: (a < b).where(b, a))
|
||||
|
||||
def test_vmin(self):
|
||||
dt_op = [(dtypes.float32, X86Ops.VMINSS), (dtypes.float64, X86Ops.VMINSD),
|
||||
(dtypes.float32.vec(4), X86Ops.VMINPS), (dtypes.float64.vec(4), X86Ops.VMINPD)]
|
||||
self._check_op(dt_op, lambda a,b: (a < b).where(a, b))
|
||||
|
||||
def test_vfmadd(self):
|
||||
dt_op = [(dtypes.float32, X86Ops.VFMADD213SS), (dtypes.float64, X86Ops.VFMADD213SD),
|
||||
(dtypes.float32.vec(4), X86Ops.VFMADD213PS), (dtypes.float64.vec(4), X86Ops.VFMADD213PD)]
|
||||
self._check_op(dt_op, lambda a,b,c: a * b + c)
|
||||
|
||||
# TODO: shouldn't match fmadd if var is used multiple times
|
||||
@unittest.expectedFailure
|
||||
def test_vfmadd_fail(self):
|
||||
dt_op = [(dtypes.float32, X86Ops.VADDSS), (dtypes.float64, X86Ops.VADDSD),
|
||||
(dtypes.float32.vec(4), X86Ops.VADDPS), (dtypes.float64.vec(4), X86Ops.VADDPD)]
|
||||
self._check_op(dt_op, lambda a,b: a * b + b)
|
||||
|
||||
def test_vpbroadcast(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
n = self.isel_rewrite(a.broadcast(4))
|
||||
# need to move src from gpr to xmm before broadcasting
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
|
||||
# if we can fuse a load we can skip the move and access memory directly
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
n = self.isel_rewrite(load.broadcast(4))
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
|
||||
|
||||
def test_vbroadcastss(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32)
|
||||
valid = [UOp.vectorize(a, a, a, a), UOp.vectorize(a, a, a, a, a, a, a, a)]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VBROADCASTSS)
|
||||
|
||||
def test_vshufps(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32.vec(8))
|
||||
b = UOp.variable("b", 0, 0, dtypes.float32.vec(8))
|
||||
c = UOp.variable("c", 0, 0, dtypes.float32)
|
||||
d = UOp.variable("d", 0, 0, dtypes.float32)
|
||||
|
||||
valid = [UOp.vectorize(c, c, d, d),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
|
||||
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
def test_vshufpd(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float64.vec(4))
|
||||
b = UOp.variable("b", 0, 0, dtypes.float64.vec(4))
|
||||
c = UOp.variable("c", 0, 0, dtypes.float64)
|
||||
d = UOp.variable("d", 0, 0, dtypes.float64)
|
||||
|
||||
valid = [UOp.vectorize(c, d),
|
||||
UOp.vectorize(a.gep(0), c),
|
||||
UOp.vectorize(a.gep(1), b.gep(1)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
invalid = [UOp.vectorize(c, c, c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
# this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
|
||||
def test_vinsertps(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
|
||||
b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
|
||||
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
|
||||
d = UOp.variable("e", 0, 0, dtypes.float32)
|
||||
# pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
|
||||
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
|
||||
self.assertIs(n.arg, X86Ops.VINSERTPS)
|
||||
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
|
||||
|
||||
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)), # TODO: this should be vunpck
|
||||
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
|
||||
|
||||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
|
||||
n = self.isel_rewrite(load)
|
||||
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
|
||||
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
|
||||
|
||||
def test_fold_load(self):
|
||||
load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
|
||||
# don't fold when used multiple times
|
||||
def test_dont_fold_load(self):
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
# used by multiple users
|
||||
n = self.isel_rewrite(load + 1 + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
# used mutiple times by same user
|
||||
n = self.isel_rewrite(load * load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
# TODO: might want to check that load isn't part of another range when fusing
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -510,7 +510,7 @@ class TestSchedule(unittest.TestCase):
|
|||
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_cumsum_parallel_reduce_fused(self):
|
||||
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
|
||||
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
|
||||
step, num_steps = 513, 10
|
||||
t = Tensor.arange(step).float().realize()
|
||||
phase = t.cumsum()
|
||||
|
|
@ -521,6 +521,12 @@ class TestSchedule(unittest.TestCase):
|
|||
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
|
||||
def test_reduce_different_nesting_depth(self):
|
||||
# two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
|
||||
x = Tensor.arange(768).reshape(3, 256).float()
|
||||
np.testing.assert_allclose((x.sum(axis=1) + x.sum(axis=1).sum()).numpy(), x.numpy().sum(axis=1) + x.numpy().sum(axis=1).sum())
|
||||
|
||||
def test_multimatmul_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,51 @@ class TestC(unittest.TestCase):
|
|||
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
|
||||
return DLL("test", f.name)
|
||||
|
||||
def test_struct_array_init(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 12
|
||||
a: Annotated[ctypes.c_int * 3, 0]
|
||||
init_records()
|
||||
|
||||
f = Foo((1,2,3))
|
||||
assert f.a[0] == 1
|
||||
assert f.a[1] == 2
|
||||
assert f.a[2] == 3
|
||||
f = Foo((ctypes.c_int * 3)(1,2,3))
|
||||
assert f.a[0] == 1
|
||||
assert f.a[1] == 2
|
||||
assert f.a[2] == 3
|
||||
|
||||
def test_field_ranges(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 2
|
||||
s: Annotated[ctypes.c_int8, 0]
|
||||
u: Annotated[ctypes.c_uint8, 1]
|
||||
init_records()
|
||||
|
||||
f = Foo()
|
||||
f.s = -1
|
||||
f.u = -1
|
||||
assert f.s == -1
|
||||
assert f.u == 255
|
||||
|
||||
# this syntax is inherited from ctypes, but it seems a bit nonsensical?
|
||||
def test_voidp_none(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 8
|
||||
p: Annotated[ctypes.c_void_p, 0]
|
||||
init_records()
|
||||
|
||||
f = Foo(None)
|
||||
assert f.p is None
|
||||
f.p = ctypes.c_void_p(0xDEADBEEF)
|
||||
assert f.p == 0xDEADBEEF
|
||||
f.p = None
|
||||
assert f.p is None
|
||||
|
||||
def test_packed_struct(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
|
|
|||
|
|
@ -101,6 +101,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext
|
||||
if isinstance(ren, ISARenderer):
|
||||
linear_sink = graph_rewrite(sink, ren.pre_isel_matcher, name="pre instruction selection", bottom_up=True)
|
||||
isel_ctx = IselContext(linear_sink)
|
||||
linear_sink = graph_rewrite(linear_sink, ren.isel_matcher, ctx=isel_ctx, name="instruction selection", bottom_up=True)
|
||||
sink = linear_sink
|
||||
|
||||
# return the rewritten sink
|
||||
return sink
|
||||
|
||||
|
|
@ -114,20 +121,35 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
])
|
||||
|
||||
# requires lst be toposorted. like graph rewrite, but for lines
|
||||
def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
|
||||
def line_rewrite(lst:list[UOp], pm:PatternMatcher, ctx=None) -> list[UOp]:
|
||||
newlst = []
|
||||
replaced: dict[UOp, UOp] = {}
|
||||
for u in lst:
|
||||
nu = u.replace(src=tuple([replaced[x] for x in u.src]))
|
||||
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu])
|
||||
nu = u.replace(src=tuple([replaced.get(x, x) for x in u.src]))
|
||||
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu, ctx)) or (nu, [nu])
|
||||
replaced[u] = ret[0]
|
||||
newlst.extend(ret[1])
|
||||
return newlst
|
||||
|
||||
def do_linearize(prg:UOp, sink:UOp) -> UOp:
|
||||
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
|
||||
def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
|
||||
from tinygrad.renderer.isa import ISARenderer
|
||||
generic_lst = line_rewrite(linearize(sink), pm_linearize_cleanups) if sink.arg.estimates is None and not isinstance(ctx, ISARenderer) else None
|
||||
if isinstance(ctx, ISARenderer):
|
||||
from tinygrad.renderer.isa import PreRegAllocContext
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
lst = linearize(sink)
|
||||
if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
|
||||
regalloc_ctx = LinearScanRegallocContext(lst, ctx)
|
||||
lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
|
||||
if ctx.late_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.late_regalloc_matcher, regalloc_ctx)
|
||||
lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)
|
||||
if DEBUG >= 4: print(ctx.asm(lst, sink.arg.function_name))
|
||||
if SPEC: type_verify(lst, ctx.isa_spec)
|
||||
else:
|
||||
assert generic_lst is not None
|
||||
lst = generic_lst
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
return prg.replace(src=(sink,)+prg.src[1:] + (UOp(Ops.LINEAR, src=tuple(lst)),))
|
||||
|
||||
def do_estimates(prg:UOp, sink:UOp, lin:UOp) -> UOp|None:
|
||||
if sink.arg.estimates is not None: return None
|
||||
|
|
|
|||
|
|
@ -328,11 +328,23 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
return acc.after(end).index(UOp.const(dtypes.int, 0))
|
||||
|
||||
def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
|
||||
# merge ENDs that share the same range (only those created by reduce_to_acc)
|
||||
# merge ENDs that share the same range and nesting context (only those created by reduce_to_acc)
|
||||
# ENDs at different nesting depths get cloned RANGEs so each RANGE maps to one END
|
||||
range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
|
||||
for u in sink.backward_slice:
|
||||
if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
|
||||
subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
|
||||
subs: dict[UOp, UOp] = {}
|
||||
next_axis = max((u.arg[0] for u in sink.backward_slice if u.op is Ops.RANGE), default=-1) + 1
|
||||
for r, ends in range_to_ends.items():
|
||||
if len(ends) <= 1: continue
|
||||
by_ctx: dict[frozenset[UOp], list[UOp]] = {}
|
||||
for e in ends: by_ctx.setdefault(frozenset(e.ranges), []).append(e)
|
||||
for i, group in enumerate(by_ctx.values()):
|
||||
tr = r if i == 0 else tuple(rr.replace(arg=(next_axis + j, *rr.arg[1:])) for j, rr in enumerate(r))
|
||||
if i > 0: next_axis += len(r)
|
||||
mapped = [e.substitute(dict(zip(r, tr))) if i > 0 else e for e in group]
|
||||
merged = mapped[0] if len(mapped) == 1 else UOp.group(*(e.src[0] for e in mapped)).end(*tr)
|
||||
for e in group: subs[e] = merged
|
||||
return sink.substitute(subs) if subs else None
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import heapq
|
|||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER, DEV
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
|
|
@ -31,10 +31,18 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
case Ops.END: priority = -5 # placing END is bad
|
||||
case _: priority = 0 # everything else has priority 0
|
||||
|
||||
# stack pointer needs to be scheduled at the top of the kernel
|
||||
# TODO: remove once there's a proper isa scheduler
|
||||
if u.op is Ops.INS:
|
||||
from tinygrad.renderer.isa.x86 import X86Ops, RSP
|
||||
match u.arg:
|
||||
case X86Ops.DEFINE_REG: priority, extra = (-21 if u.tag[0] == RSP else -20), u.tag[0].index
|
||||
|
||||
priorities[u] = (run_count, priority, extra)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER and DEV.value.renderer != "X86" else ())))}
|
||||
|
||||
# then force them to be toposorted in as close to the ideal order as possible
|
||||
heap = [(-nkey[sink], sink)]
|
||||
|
|
@ -93,4 +101,4 @@ def do_split_ends(e:UOp):
|
|||
pm_split_ends = PatternMatcher([
|
||||
# split the ends
|
||||
(UPat(Ops.END, name="e"), do_split_ends),
|
||||
])
|
||||
])
|
||||
|
|
|
|||
132
tinygrad/codegen/late/regalloc.py
Normal file
132
tinygrad/codegen/late/regalloc.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
import itertools
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.isa import ISARenderer, Register
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
|
||||
PSEUDO_OPS = {Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
|
||||
def _uop_key(u:UOp): return (u.op, u.dtype, u.arg)
|
||||
|
||||
# loosely based on: https://bernsteinbear.com/assets/img/register-spilling-range-splitting-ssa.pdf
|
||||
class LinearScanRegallocContext:
|
||||
def __init__(self, uops:list[UOp], ren:ISARenderer):
|
||||
if saved:=ren.callee_saved():
|
||||
ret_i = next(i for i,u in reversed(tuple(enumerate(uops))) if u.op is Ops.INS and getattr(u.arg, "name", None) == "RET")
|
||||
uops[0:0] = saved
|
||||
uops[ret_i+len(saved)] = uops[ret_i+len(saved)].replace(src=uops[ret_i+len(saved)].src + saved)
|
||||
live_range: dict[Register, list[int]] = {}
|
||||
live: dict[Register, Register] = {}
|
||||
live_ins: list[dict[Register, Register]] = []
|
||||
self.defs: dict[Register, UOp] = {} # mapping from virtual to uop that defines it
|
||||
self.real_defs: dict[Register, Register] = {} # mapping from virtual to real at definition
|
||||
self.spills: dict[Register, UOp] = {} # mapping from virtual to stack slot
|
||||
self.fills: dict[int, dict[int, tuple[Register, Register]]] = {} # mapping from program point to mapping from idx to virtual and real to fill to
|
||||
self.insert_before: dict[int, list[tuple[Register, Register]]] = {} # mapping from program point to fills to be inserted
|
||||
self.idx = itertools.count()
|
||||
self.ren = ren
|
||||
self.stack_size = 0
|
||||
# the label associated with each loop NOTE: this is only used post regalloc and should be removed
|
||||
self.loop_label: dict[UOp, str] = {}
|
||||
arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
|
||||
self.func_arg_idxs = {_uop_key(u): i for i,u in enumerate(sorted({u for u in uops if u.op in arg_order}, key=lambda k: (arg_order[k.op], k.arg)))}
|
||||
self.local_offsets: dict[tuple, int] = {}
|
||||
for u in uops:
|
||||
if u.op not in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG): continue
|
||||
self.local_offsets.setdefault(_uop_key(u), self.stack_size)
|
||||
self.stack_size += u.dtype.nbytes()
|
||||
# compute live ranges
|
||||
lr = live_range
|
||||
ranges: list[Register] = []
|
||||
for i,u in enumerate(reversed(uops)):
|
||||
if u.op in PSEUDO_OPS: continue
|
||||
defs = u.tag if isinstance(u.tag, tuple) else ()
|
||||
for v in defs + tuple(s.reg for s in set(u.src)):
|
||||
if isinstance(v, Register): lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
|
||||
for v in defs:
|
||||
if isinstance(v, Register): self.defs[v] = u
|
||||
if v in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] <= lr[v][-1] < lr[rng][-1]), default=None)): lr[v].append(n)
|
||||
if u.op is Ops.RANGE: ranges.append(u.reg)
|
||||
|
||||
def alloc(cons:tuple[Register, ...], i:int) -> Register:
|
||||
live_inv = {v:k for k,v in live.items()}
|
||||
# allocate the best register. Registers not in live or not used again are free and have priority,
|
||||
# otherwise pick the one with the furthest next use. Regs that appear first in cons have priority in case of a tie
|
||||
reg,vreg = max(((r,live_inv.get(r)) for r in cons),
|
||||
key=lambda rv: next((j-i for j in ([] if rv[1] is None else live_range[rv[1]]) if j >= i), len(uops)))
|
||||
return live.pop(vreg) if vreg is not None else reg
|
||||
|
||||
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
|
||||
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
|
||||
if v not in self.spills:
|
||||
dt = self.defs[v].dtype
|
||||
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
|
||||
assert sz > 0
|
||||
offset = self.stack_size + (sz - self.stack_size % sz) % sz
|
||||
self.spills[v] = UOp.const(dtypes.int32, offset)
|
||||
self.stack_size = offset + sz
|
||||
r = alloc(cons if cons is not None else v.cons, i)
|
||||
self.insert_before.setdefault(i, []).append((v, r))
|
||||
return r
|
||||
|
||||
for i,u in enumerate(uops):
|
||||
if u.op in PSEUDO_OPS: continue
|
||||
# allocate uses
|
||||
for j,s in enumerate(u.src):
|
||||
# HACK: cause of later hacks to lower range
|
||||
if u.op is Ops.END: continue
|
||||
if not isinstance(v:=s.reg, Register): continue
|
||||
if v not in live: live[v] = fill(v, i)
|
||||
if v in self.spills: self.fills.setdefault(i, {})[j] = (v, live[v])
|
||||
|
||||
# allocate defs
|
||||
if isinstance(u.tag, tuple):
|
||||
for j,v in enumerate(u.tag):
|
||||
assert isinstance(v, Register) and v not in live
|
||||
cons = v.cons
|
||||
# two address instructions (src is reused by def) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
|
||||
if ren.is_two_address(u) and j == 0:
|
||||
ins = tuple(live.get(s.reg) for s in u.src)
|
||||
cons = ((ins[0],) if ins[0] in cons else ()) + tuple(r for r in cons if r not in ins)
|
||||
assert cons
|
||||
# HACK: cause the range is missing the comparison
|
||||
self.real_defs[v] = live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
|
||||
|
||||
# loop prologue, avoid loading inside the loop
|
||||
if u.op is Ops.RANGE:
|
||||
# we move to registers vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
|
||||
used_in_loop = [v for v in live.keys() | self.spills.keys() if any(i <= l < live_range[u.reg][-1] for l in live_range[v])]
|
||||
sorted_uses = sorted(used_in_loop, key=lambda k: next(l-i for l in live_range[k] if l >= i))
|
||||
live_in: dict[Register, Register] = {}
|
||||
for v in sorted_uses:
|
||||
# if all the possible registers are already in live_in there's no space for this var
|
||||
if set(v.cons).issubset(live_in.values()): continue
|
||||
if v not in live: live[v] = fill(v, i)
|
||||
assert live[v] not in live_in.values()
|
||||
live_in[v] = live[v]
|
||||
live_ins.append(live_in)
|
||||
|
||||
# loop epilogue, reload registers that were live at loop entry
|
||||
if u.op is Ops.END:
|
||||
# TODO: if a uop is in a different reg in live out vs live in move between registers instead of loading
|
||||
# TODO: don't reload if first use in loop is a load
|
||||
for v,r in live_ins.pop().items():
|
||||
if v not in live or live[v] != r: live[v] = fill(v, i, (r,))
|
||||
|
||||
def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
||||
i = next(ctx.idx)
|
||||
if x.op in PSEUDO_OPS: return None
|
||||
nsrc = []
|
||||
for j,s in enumerate(x.src):
|
||||
if i in ctx.fills and j in ctx.fills[i]:
|
||||
v,r = ctx.fills[i][j]
|
||||
nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.defs[v], r))
|
||||
else: nsrc.append(s)
|
||||
ndefs = tuple(ctx.real_defs[v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
|
||||
nx = x.replace(src=tuple(nsrc), tag=ndefs)
|
||||
fills = [ctx.ren.fill(ctx.spills[v], ctx.defs[v], r) for v,r in ctx.insert_before.get(i, [])]
|
||||
spills = [ctx.ren.spill(ctx.spills[v], nx) for v in x.tag if v in ctx.spills] if isinstance(x.tag, tuple) else []
|
||||
return nx, fills + [nx] + spills
|
||||
|
||||
pm_regalloc_rewrite = PatternMatcher([
|
||||
(UPat({Ops.INS, Ops.CONST, Ops.RANGE, Ops.END, Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER,
|
||||
Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"), regalloc_rewrite),
|
||||
])
|
||||
|
|
@ -319,7 +319,8 @@ def is_dtype_supported(dtype:DType, target:Target|None=None) -> bool:
|
|||
case "METAL": return not CI or BENCHMARKS
|
||||
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
|
||||
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
|
||||
case "CPU": return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and target.renderer != "LVP"
|
||||
case "CPU": return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and \
|
||||
target.renderer not in ("LVP", "X86")
|
||||
case "AMD" | "CL" | "PYTHON" | "NULL": return True
|
||||
case _: return False
|
||||
if dtype in dtypes.fp8_ocp:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import ctypes
|
|||
from tinygrad.helpers import ceildiv, round_up
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
|
||||
from tinygrad.renderer.amd.dsl import Reg, FixedBitField
|
||||
from tinygrad.renderer.amd.dsl import Inst, Reg, FixedBitField
|
||||
from tinygrad.runtime.autogen.amd.common import OpType
|
||||
|
||||
# instructions used for padding
|
||||
|
|
@ -11,8 +11,9 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as
|
|||
from tinygrad.runtime.autogen.amd.cdna.ins import s_nop as s_nop_cdna
|
||||
|
||||
_arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4"}
|
||||
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
|
||||
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp|None:
|
||||
insts = [u.arg for u in lin.src]
|
||||
if not all(isinstance(inst, Inst) for inst in insts): return None
|
||||
|
||||
# ** scan for max vgpr/sgpr/accvgpr
|
||||
max_vgpr, max_sgpr, max_accvgpr = 0, 0, 0
|
||||
|
|
|
|||
44
tinygrad/renderer/isa/__init__.py
Normal file
44
tinygrad/renderer/isa/__init__.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from __future__ import annotations
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Register:
|
||||
name: str
|
||||
index: int
|
||||
_cons: tuple[Register, ...] = field(default_factory=tuple)
|
||||
@property
|
||||
def cons(self): return self._cons or (self,)
|
||||
def __repr__(self): return self.name
|
||||
|
||||
class IselContext:
|
||||
def __init__(self, sink:UOp):
|
||||
self.uses = sink.get_consumer_map()
|
||||
self.reg_n = itertools.count()
|
||||
arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
|
||||
self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg))
|
||||
|
||||
def vreg(self, cons:tuple[Register, ...]|Register):
|
||||
return Register(f"v{next(self.reg_n)}", 0, _cons=cons if isinstance(cons, tuple) else (cons,))
|
||||
|
||||
@dataclass
|
||||
class PreRegAllocContext:
|
||||
lock: UOp|None = None
|
||||
clobbered: set[UOp] = field(default_factory=set)
|
||||
|
||||
class ISARenderer(Renderer):
|
||||
isa_spec: PatternMatcher
|
||||
pre_isel_matcher: PatternMatcher
|
||||
isel_matcher: PatternMatcher
|
||||
pre_regalloc_matcher: PatternMatcher|None = None
|
||||
late_regalloc_matcher: PatternMatcher|None = None
|
||||
post_regalloc_matcher: PatternMatcher
|
||||
|
||||
def callee_saved(self) -> tuple[UOp, ...]: return tuple()
|
||||
def is_two_address(self, x:UOp) -> bool: return False
|
||||
def copy(self, x:UOp, reg:Register) -> UOp: raise NotImplementedError("arch specific")
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp: raise NotImplementedError("arch specific")
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp: raise NotImplementedError("arch specific")
|
||||
def asm(self, uops:list[UOp], function_name:str) -> str: raise NotImplementedError("arch specific")
|
||||
936
tinygrad/renderer/isa/x86.py
Normal file
936
tinygrad/renderer/isa/x86.py
Normal file
|
|
@ -0,0 +1,936 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
import sys, struct, functools
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, DType, truncate
|
||||
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
|
||||
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
|
||||
|
||||
# ***** X86 Ops *****
|
||||
|
||||
class X86Ops(FastEnum):
|
||||
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
|
||||
# these aren't real instructions
|
||||
DEFINE_REG = auto(); FRAME_INDEX = auto(); LABEL = auto()
|
||||
# index
|
||||
LEA = auto()
|
||||
# register / memory / immediate moves
|
||||
MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
|
||||
VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
|
||||
VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
|
||||
# casts
|
||||
MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
|
||||
VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
|
||||
VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
|
||||
VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
|
||||
VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
|
||||
VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
|
||||
VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
|
||||
VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
|
||||
VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
|
||||
# bitcasts
|
||||
VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
|
||||
# comparisons
|
||||
VUCOMISS = auto(); VUCOMISD = auto()
|
||||
VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
|
||||
VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
|
||||
VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
|
||||
SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
|
||||
# where
|
||||
CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
|
||||
VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
|
||||
# jumps
|
||||
JNE = auto(); JE = auto(); JL = auto(); JB = auto(); JGE = auto(); JMP = auto()
|
||||
# vectorize / gep
|
||||
VSHUFPS = auto(); VSHUFPD = auto(); VINSERTPS = auto(); VPSRLDQ = auto()
|
||||
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
|
||||
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
|
||||
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
|
||||
VBROADCASTSS = auto()
|
||||
# int binary
|
||||
IDIV = auto(); DIV = auto()
|
||||
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
|
||||
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
|
||||
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
|
||||
# float unary (sometimes not unary)
|
||||
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
|
||||
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
|
||||
# float scalar / vector binary
|
||||
VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
|
||||
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
|
||||
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
|
||||
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
|
||||
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
|
||||
VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
|
||||
# int vector binary
|
||||
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
|
||||
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
|
||||
VPMULLW = auto(); VPMULLD = auto()
|
||||
# packed bitwise
|
||||
VPAND = auto(); VPOR = auto(); VPXOR = auto()
|
||||
# packed variable shifts
|
||||
VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
|
||||
# fused multiply add TODO: add other variants to fuse more loads
|
||||
VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
|
||||
# return
|
||||
RET = auto()
|
||||
|
||||
# TODO: add commutative groupop to fuse more loads
|
||||
class X86GroupOp:
|
||||
# X86Ops whose first src is also the destination
|
||||
TwoAddress = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
|
||||
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
|
||||
X86Ops.IDIV, X86Ops.DIV, X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
|
||||
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}
|
||||
|
||||
# X86Ops whose first src can read from memory
|
||||
ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
|
||||
X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
|
||||
X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
|
||||
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
|
||||
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
|
||||
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
|
||||
X86Ops.CMPi, X86Ops.IMULi, X86Ops.LEA}
|
||||
|
||||
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress so the second src is actually the first
|
||||
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
|
||||
X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
|
||||
X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
|
||||
X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
|
||||
X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
|
||||
X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
|
||||
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
|
||||
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
|
||||
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
|
||||
X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
|
||||
X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD, X86Ops.IDIV, X86Ops.DIV,
|
||||
X86Ops.VSHUFPD}
|
||||
|
||||
# X86Ops whose third src can read from memory NOTE: these are TwoAddress so the third src is actually the second
|
||||
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
|
||||
|
||||
# X86Ops that can write to memory
|
||||
WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
|
||||
X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
|
||||
X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}
|
||||
|
||||
# X86Ops that read flags
|
||||
ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
|
||||
X86Ops.JE, X86Ops.JNE, X86Ops.JGE}
|
||||
|
||||
# X86Ops that write flags or can modify flags to undefined values
|
||||
WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
|
||||
X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
|
||||
X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
# X86Ops whose first src is the rm field
|
||||
Rm1st = ReadMem1st | (ReadMem2nd & TwoAddress) | {X86Ops.VPSRLDQ}
|
||||
|
||||
# X86Ops whose second src is the rm field
|
||||
Rm2nd = ReadMem2nd | (ReadMem3rd & TwoAddress)
|
||||
|
||||
All = set(X86Ops)
|
||||
|
||||
# ***** X86 legalization *****
|
||||
|
||||
extra_matcher = PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPEQ is XOR+XOR, bool CMPLT is XOR+AND
|
||||
# TODO: how does this work for vector dtypes?
|
||||
(UPat.var('x', dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
|
||||
(UPat.var('x', dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
|
||||
# cast to pointer is a noop
|
||||
(UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
|
||||
# can't cast from float16 to ints/float64 directly and vice versa
|
||||
(UPat.var("y", dtypes.float16).cast((dtypes.float64,)+dtypes.ints, name="x"), lambda y,x: y.cast(dtypes.float32).cast(x.dtype)),
|
||||
(UPat.var("y", (dtypes.float64,)+dtypes.ints).cast(dtypes.float16, name="x"), lambda y,x: y.cast(dtypes.float32).cast(x.dtype)),
|
||||
# can't cast from float to int8/16 directly and vice versa
|
||||
(UPat.var("y", dtypes.floats).cast(dtypes.int8s+dtypes.int16s, name="x"), lambda y,x: y.cast(dtypes.int32).cast(x.dtype)),
|
||||
(UPat.var("y", (dtypes.bool,)+dtypes.int8s+dtypes.int16s).cast(dtypes.floats, name="x"), lambda y,x: y.cast(dtypes.int32).cast(x.dtype)),
|
||||
# int/float casts only for signed int
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.floats, name="x"), lambda y,x: y.cast(dtypes.int64).cast(x.dtype)),
|
||||
# casting uint64 to float requires special handling
|
||||
(UPat.var("y", dtypes.uint64).cast(dtypes.floats, name="x"), lambda y,x:
|
||||
(y >> 1).cast(dtypes.int64).cast(x.dtype) * 2 + (y & 1).cast(dtypes.int64).cast(x.dtype)),
|
||||
# no int8 mul or cmove, cast to int16
|
||||
(UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)),
|
||||
(UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
|
||||
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.dtype.count == 1 else None),
|
||||
# float16 alus are done in float32
|
||||
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.dtype.count),
|
||||
tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"),
|
||||
lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)),
|
||||
# no cmpne for packed ints, y != x => !(y==x)
|
||||
(UPat(Ops.CMPNE, src=(UPat.var("y", dtypes.ints), UPat.var("x")), name="cmp"),
|
||||
lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.dtype.count > 1 else None),
|
||||
# float where expects a mask TODO: handle float64 cmp to float32 where
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a", dtypes.floats), UPat.var("b")),
|
||||
lambda m,a,b: m.cast(a.dtype).ne(0).where(a, b) if m.src[0].dtype not in dtypes.floats else None),
|
||||
# TODO: do we want this? If yes make it general
|
||||
#(UPat(Ops.VECTORIZE, dtypes.float16, name="x"), lambda x: x.replace(dtype=dtypes.float32.vec(x.dtype.count),
|
||||
# src=tuple(s.src[0] for s in x.src)).cast(x.dtype) if all(s.op is Ops.CAST for s in x.src) else None),
|
||||
# rewrite -x -> 0 - x
|
||||
(UPat(Ops.NEG, name="x"), lambda x: UOp(Ops.SUB, x.dtype, (x.const_like(0),) + x.src)),
|
||||
])
|
||||
|
||||
# ***** X86 pre instruction selection *****
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
pre_isel_matcher = PatternMatcher([
|
||||
# zero extending scalar 32bit int is a noop
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
|
||||
# cast between signed and unsigned int is a noop
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).cast(dtypes.ints, name="x"),
|
||||
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize == y.dtype.itemsize else None),
|
||||
# cast to < scalar int is a noop
|
||||
(UPat.var("y", dtypes.ints).cast(dtypes.ints, name="x"),
|
||||
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.dtype.count == 1 else None),
|
||||
# bitcasts between scalar floats and ints are real, rest are noops
|
||||
(UPat.var("y").bitcast().named("x"), lambda y,x: None if y.dtype in dtypes.floats and x.dtype in dtypes.ints or \
|
||||
y.dtype in dtypes.ints and x.dtype in dtypes.floats else x.replace(op=Ops.NOOP)),
|
||||
# noop of a noop is removed
|
||||
(UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
|
||||
# moving elements of a single register to another without shuffling is a noop
|
||||
(UPat(Ops.VECTORIZE, src=(UPat.var("y"),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
|
||||
# gated index becomes a conditional move on the index, the load/store are unconditional
|
||||
(UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).load(UPat.var("alt"), name="x"), lambda base,idx,gate,alt,x:
|
||||
gate.where(base.index(idx, ptr=True), (l:=UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count), arg=0)
|
||||
.index(UOp.const(dtypes.int32, 0), ptr=True)).after(l.store(alt))).load(dtype=x.dtype)),
|
||||
(UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).store(UPat.var("val")), lambda base,idx,gate,val:
|
||||
gate.where(base.index(idx, ptr=True), UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count), arg=0)
|
||||
.index(UOp.const(dtypes.int32, 0), ptr=True)).store(val)),
|
||||
# TODO: remove this once we allow all flag producing ops in cmove
|
||||
# if gate in scalar int cmove is not a comparison need to add one to set the flag
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
|
||||
lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.dtype.count == 1 else None),
|
||||
])
|
||||
|
||||
# ***** X86 registers *****
|
||||
|
||||
RAX = Register("rax", 0)
|
||||
RCX = Register("rcx", 1)
|
||||
RDX = Register("rdx", 2)
|
||||
RBX = Register("rbx", 3)
|
||||
RSP = Register("rsp", 4)
|
||||
RBP = Register("rbp", 5)
|
||||
RSI = Register("rsi", 6)
|
||||
RDI = Register("rdi", 7)
|
||||
GPR = (RAX, RCX, RDX, RBX, RSP, RBP, RSI, RDI) + tuple(Register(f"r{i}", i) for i in range(8, 16))
|
||||
XMM = tuple(Register(f"xmm{i}", i) for i in range(16))
|
||||
# gprs you can write to
|
||||
WGPR = tuple(r for r in GPR if r != RSP)
|
||||
|
||||
CALLEE_SAVED = (RBX, RSP, RBP, GPR[12], GPR[13], GPR[14], GPR[15]) + ((RSI, RDI) + XMM[6:16] if sys.platform == "win32" else ())
|
||||
|
||||
reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"}, "rdx": {4:"edx", 2:"dx", 1:"dl"}, "rbx": {4:"ebx", 2:"bx", 1:"bl"},
|
||||
"rsp": {4:"esp", 2:"sp", 1:"spl"}, "rbp": {4:"ebp", 2:"bp", 1:"bpl"}, "rsi": {4:"esi", 2:"si", 1:"sil"}, "rdi": {4:"edi", 2:"di", 1:"dil"},
|
||||
**{f"r{i}": {4:f"r{i}d", 2:f"r{i}w", 1:f"r{i}b"} for i in range(8, 16)}, **{f"xmm{i}": {64:f"zmm{i}", 32:f"ymm{i}"} for i in range(16)}}
|
||||
|
||||
# ***** X86 instruction selection *****
|
||||
# if the load is used multiple times we don't fold
|
||||
def is_foldable_load(ctx:IselContext, x:UOp, s:UOp) -> bool: return s.op is Ops.LOAD and len(ctx.uses[s]) == x.src.count(s) == 1
|
||||
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
|
||||
def lane(x:UOp, i:int) -> int: return s.arg[0] if (s:=x.src[i]).op is Ops.GEP else 0
|
||||
def to_int(dt:DType): return {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}[dt]
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dt, tag=None if reg is None else (reg,))
|
||||
def imm(dt:DType, v:int) -> UOp: return UOp(Ops.CONST, dt, arg=truncate[dt](v), tag="__x86_imm__")
|
||||
def _uop_key(u:UOp): return (u.op, u.dtype, u.arg)
|
||||
def to_imm(c:UOp) -> UOp|None:
|
||||
if c.op is not Ops.CONST: return None
|
||||
if c.dtype is dtypes.int64: return imm(dtypes.int32, c.arg) if not c.overflows(dtypes.int32) else None
|
||||
if c.dtype is dtypes.uint64: return imm(dtypes.uint32, c.arg) if not c.overflows(dtypes.uint32) else None
|
||||
if c.dtype in dtypes.ints+(dtypes.bool,): return imm(c.dtype, c.arg)
|
||||
return None
|
||||
def cmp(x:UOp) -> UOp:
|
||||
if x.src[0].dtype is dtypes.float32: return x.ins(X86Ops.VUCOMISS, dtype=dtypes.void)
|
||||
if x.src[0].dtype is dtypes.float64: return x.ins(X86Ops.VUCOMISD, dtype=dtypes.void)
|
||||
return x.ins(X86Ops.CMP, dtype=dtypes.void) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], i))
|
||||
def vcmp(x:UOp) -> UOp:
|
||||
v = imm(dtypes.uint8, {Ops.CMPLT: 1, Ops.CMPNE: 4, Ops.CMPEQ: 0}[x.op])
|
||||
if x.dtype.scalar() is dtypes.float32: return x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (v,))
|
||||
return x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (v,))
|
||||
|
||||
# vshufps xmm2, xmm0, xmm1, imm
|
||||
# for 128 bit xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1 according to imm
|
||||
# for 256 bit ymm2 repeats the shuffle for its upper 128 bits selecting from the upper 128 bits of ymm0 and ymm1
|
||||
def vshufps(x:UOp) -> UOp|None:
|
||||
a, b = base(x, 0), base(x, 2)
|
||||
if not (a is base(x, 1) and b is base(x, 3)) or any(lane(x, i) > 3 for i in range(4)): return None
|
||||
if len(x.src) == 8:
|
||||
if not (a is base(x, 4) is base(x, 5) and b is base(x, 6) is base(x, 7)) or any(lane(x, i+4) != lane(x, i)+4 for i in range(4)): return None
|
||||
return x.ins(X86Ops.VSHUFPS, src=(a, b, imm(dtypes.uint8, sum(lane(x, i) << 2*i for i in range(4)))))
|
||||
|
||||
# vshufpd xmm2, xmm0, xmm1, imm
|
||||
# for 128 bit xmm2 selects its lower 64 bits from xmm0 and its upper 64 bits from xmm1 according to imm
|
||||
# for 256 bit ymm2 also selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 following the same constraint
|
||||
def vshufpd(x:UOp) -> UOp|None:
|
||||
a, b = base(x, 0), base(x, 1)
|
||||
if lane(x, 0) > 1 or lane(x, 1) > 1: return None
|
||||
if len(x.src) == 4 and not (a is base(x, 2) and b is base(x, 3) and lane(x, 2) > 1 and lane(x, 3) > 1): return None
|
||||
return x.ins(X86Ops.VSHUFPD, src=(a, b, imm(dtypes.uint8, sum(lane(x, i) << i for i in range(len(x.src))))))
|
||||
|
||||
# vinsertps xmm2, xmm0, xmm1, imm
|
||||
# inserts any 32 bit element in xmm1 into any position in xmm0 according to immm, result is written to xmm2
|
||||
# this is the fallback slow case for when you can't match more a powerful shuffle
|
||||
def vinsertps(x:UOp) -> UOp:
|
||||
def _insert(ret:UOp, i:int) -> UOp:
|
||||
s, v = base(x, i), lane(x, i)
|
||||
# moving the 0th element into the 0th position does nothing
|
||||
return s if i == v == 0 else x.ins(X86Ops.VINSERTPS, src=(ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
|
||||
return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))
|
||||
|
||||
# vpinsq xmm2, xmm0, rax, imm
|
||||
# inserts element in rax into any position in xmm0, result is written to xmm2 according to imm
|
||||
def vpins(x:UOp) -> UOp:
|
||||
op = {1: X86Ops.VPINSRB, 2: X86Ops.VPINSRW, 4: X86Ops.VPINSRD, 8: X86Ops.VPINSRQ}[x.dtype.scalar().itemsize]
|
||||
return functools.reduce(lambda ret,i: x.ins(op, src=(ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype))
|
||||
|
||||
# vpbroadcastd xmm1, xmm0
|
||||
# inserts scalar int in xmm0 into all lanes of xmm1
|
||||
def vpbroadcast(ctx:IselContext, x:UOp, y:UOp) -> UOp:
|
||||
n = x.ins({1: X86Ops.VPBROADCASTB, 2: X86Ops.VPBROADCASTW, 4: X86Ops.VPBROADCASTD, 8: X86Ops.VPBROADCASTQ}[y.dtype.itemsize], src=(y,))
|
||||
if is_foldable_load(ctx, n, y): return n
|
||||
# if there isn't a load we can fold we need to move y from gpr to xmm
|
||||
# this is hacky but required because int.vec(1) isn't supported
|
||||
y = y if y.dtype.itemsize > 1 else y.cast(dtypes.int16)
|
||||
return n.replace(src=(y.bitcast({2:dtypes.float16, 4:dtypes.float32, 8:dtypes.float64}[y.dtype.itemsize]),))
|
||||
|
||||
# we don't call ctx.vreg on the srcs to avoid duplicates, a rewrite will assign the tuple of valid registers to a vreg
|
||||
def idiv(ctx:IselContext, x:UOp) -> UOp:
|
||||
op = X86Ops.DIV if x.dtype in dtypes.uints else X86Ops.IDIV
|
||||
# for >8bit need to zero/sign extend rax to rdx
|
||||
if x.dtype in dtypes.int8s: ext = []
|
||||
elif x.dtype in dtypes.uints: ext = [x.ins(X86Ops.MOVi, src=(imm(min(dtypes.uint32, x.dtype), 0),), tag=(RDX,))]
|
||||
else: ext = [x.ins(X86Ops.SARi, src=(x.src[0], imm(dtypes.uint8, x.dtype.itemsize * 8 - 1)), tag=(RDX,))]
|
||||
# for 8bit need to zero/sign extend al to ah
|
||||
if x.dtype is dtypes.uint8: dividend = UOp(Ops.INS, arg=X86Ops.MOVZX, dtype=dtypes.int16, src=(x.src[0],), tag=(RAX,))
|
||||
elif x.dtype is dtypes.int8: dividend = UOp(Ops.INS, arg=X86Ops.MOVSX, dtype=dtypes.int16, src=(x.src[0],), tag=(RAX,))
|
||||
else: dividend = x.ins(X86Ops.MOV, src=(x.src[0],), tag=(RAX,))
|
||||
# divisor can't be in rax or rdx
|
||||
divisor = x.ins(X86Ops.MOV, src=(x.src[1],), tag=tuple(r for r in WGPR if r not in (RAX, RDX)))
|
||||
# for >8bit both rax and rdx are written to
|
||||
defs = (ctx.vreg(RAX),) if x.dtype in dtypes.int8s else (ctx.vreg(RAX), ctx.vreg(RDX))
|
||||
idiv = x.ins(op, src=(dividend, divisor) + tuple(ext), tag=defs)
|
||||
# this move "cleanses" the register constraint (rax) of idiv as it only applies on definition and not on the uses of idiv
|
||||
return x.ins(X86Ops.MOV, src=(idiv,))
|
||||
|
||||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
|
||||
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
|
||||
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
|
||||
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
|
||||
base, idx = x.src
|
||||
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
|
||||
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale))
|
||||
return (base, _cast(idx), _disp(0))
|
||||
|
||||
def alloc_defs(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
if x.dtype is dtypes.void or isinstance(x.tag, tuple): return None
|
||||
if x.op in {Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL}:
|
||||
i = ctx.func_args.index(x)
|
||||
regs = (RCX, RDX, GPR[8], GPR[9]) if sys.platform == "win32" else (RDI, RSI, RDX, RCX, GPR[8], GPR[9])
|
||||
if i < len(regs): return x.replace(tag=(ctx.vreg(regs[i]),))
|
||||
defs = [ctx.vreg(WGPR)] if x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType) else [ctx.vreg(XMM)]
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
||||
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
# immediates and real registers
|
||||
if x.op is Ops.CONST: return None
|
||||
if x.arg in (X86Ops.FRAME_INDEX, X86Ops.DEFINE_REG) and x.tag is not None: return None
|
||||
# no register definition
|
||||
if x.dtype is dtypes.void: return None
|
||||
# already allocated vregs
|
||||
if isinstance(x.tag, tuple) and x.tag[0]._cons: return None
|
||||
# allocate vreg definitions
|
||||
defs = []
|
||||
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
|
||||
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
|
||||
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
|
||||
# TODO: add this once the scheduler can track register pressure
|
||||
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
||||
def lower_abi(ctx, x:UOp):
|
||||
i = ctx.func_arg_idxs[_uop_key(x)]
|
||||
if sys.platform == "win32": regs, stack_base = (RCX, RDX, GPR[8], GPR[9]), 32
|
||||
else: regs, stack_base = (RDI, RSI, RDX, RCX, GPR[8], GPR[9]), 0
|
||||
if i < len(regs):
|
||||
src = def_reg(x.dtype, regs[i])
|
||||
if x.reg == src.reg: return src, [src]
|
||||
return (nx:=x.ins(X86Ops.MOV, src=(src,)), [src, nx])
|
||||
fi = UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=(i-len(regs)+1)*8+stack_base)
|
||||
nx = x.ins(X86Ops.MOV, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), fi))
|
||||
return nx, [fi, nx]
|
||||
|
||||
def lower_stack_define(ctx, x:UOp):
|
||||
disp = imm(dtypes.int32, ctx.local_offsets[_uop_key(x)])
|
||||
nx = x.ins(X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), disp))
|
||||
return nx, [disp, nx]
|
||||
|
||||
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
dt_16bit = tuple(dt.vec(l) for dt in dts for l in [2,1] if l*dt.itemsize == 2 and dt not in dtypes.int16s)
|
||||
dt_32bit = tuple(dt.vec(l) for dt in dts for l in [4,2,1] if l*dt.itemsize == 4 and dt not in dtypes.int32s)
|
||||
dt_64bit = tuple(dt.vec(l) for dt in dts for l in [8,4,2,1] if l*dt.itemsize == 8 and dt not in dtypes.int64s)
|
||||
dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize == 16)
|
||||
|
||||
isel_matcher = PatternMatcher([
|
||||
# **** Op -> Op ****
|
||||
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
|
||||
# this is done here to not interfere with shuffles
|
||||
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
|
||||
# range is lowered to acc, cmp, jmp after regalloc
|
||||
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
|
||||
# **** Op -> X86Op ****
|
||||
# append return, callee-saved live ranges are inserted in regalloc
|
||||
(UPat(Ops.SINK, name="x"), lambda x:
|
||||
x.replace(src=(x.ins(X86Ops.RET, src=x.src),))
|
||||
if not (len(x.src) == 1 and x.src[0].op is Ops.INS and x.src[0].arg is X86Ops.RET) else None),
|
||||
# late lowered function args and stack backed locals still need virtual registers
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_defs),
|
||||
# constants that can't be immediates, move them to registers
|
||||
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if x.tag is None else None),
|
||||
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if x.tag is None else None),
|
||||
(UPat.cvar("x", dtypes.floats), lambda x:
|
||||
UOp.const(dt:=to_int(x.dtype), struct.unpack(dt.fmt, struct.pack(x.dtype.fmt, x.arg))[0]).bitcast(x.dtype) if x.tag is None else None),
|
||||
# TODO: these should use a.maximum(b) / a.minimum(b)
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float32), UPat.var("a")), lambda a,b:
|
||||
a.ins(X86Ops.VMAXSS if a.dtype.count == 1 else X86Ops.VMAXPS, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float64), UPat.var("a")), lambda a,b:
|
||||
a.ins(X86Ops.VMAXSD if a.dtype.count == 1 else X86Ops.VMAXPD, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda a,b:
|
||||
a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda a,b:
|
||||
a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))),
|
||||
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None),
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
|
||||
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
|
||||
(UPat(GroupOp.Comparison, dtypes.bool, (UPat.var("y", (dtypes.float32, dtypes.float64)), UPat()), name="x"), lambda y,x:
|
||||
x.replace(dtype=y.dtype).bitcast(to_int(y.dtype)).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)),
|
||||
# conditional moves that use flags
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVB, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVE, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVNE, src=(b, a, cmp(m)))),
|
||||
# jumps, use flags
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JB, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JL, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JE, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JNE, src=(cmp(y),))),
|
||||
# comparisons whose user doesn't use the flag, move flag result to register
|
||||
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: x.ins(X86Ops.SETB, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETL, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETE, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETNE, src=(cmp(x),))),
|
||||
# comparisons that produce masks (these aren't bool dtype)
|
||||
(UPat(GroupOp.Comparison, src=(UPat(dtype=(dtypes.float32, dtypes.float64)), UPat()), name="x"), vcmp),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQB)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQW)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQD)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQQ)),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTB, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTW, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
|
||||
# float unary
|
||||
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)),
|
||||
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)),
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x:
|
||||
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x:
|
||||
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
|
||||
# shufles
|
||||
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
|
||||
# for float16 we route the srcs through gprs unless we can fold them, this is suboptimal for values in xmms, in that case we want vpunpcklwd
|
||||
(UPat(Ops.VECTORIZE, dtypes.float16, name="x"), lambda ctx,x:
|
||||
vpins(x.replace(src=tuple(s if is_foldable_load(ctx, x, s) else s.bitcast(dtypes.int16) for s in x.src)))),
|
||||
(UPat(Ops.VECTORIZE, (dtypes.float32.vec(4), dtypes.float32.vec(8)), name="x"), vshufps),
|
||||
(UPat(Ops.VECTORIZE, (dtypes.float64.vec(2), dtypes.float64.vec(4)), name="x"), vshufpd),
|
||||
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vinsertps),
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
|
||||
(UPat(Ops.VECTORIZE, dtypes.ints+(dtypes.bool,), name="x"), vpins),
|
||||
# gep
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
|
||||
# fused multiply add TODO: don't fuse if mul used several times
|
||||
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(a, b, c))),
|
||||
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(a, b, c))),
|
||||
# packed bitwise
|
||||
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
|
||||
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
|
||||
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
|
||||
# packed int binary
|
||||
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
|
||||
# scalar int binary
|
||||
((UPat(dtype=dtypes.ints) // UPat()).named("x"), idiv),
|
||||
# scalar int binary with immediate
|
||||
(UPat.var("a", dtypes.ints) << UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SHLi, src=(a, imm(dtypes.uint8, c.arg)))),
|
||||
(UPat.var("a", dtypes.uints) >> UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SHRi, src=(a, imm(dtypes.uint8, c.arg)))),
|
||||
(UPat.var("a", dtypes.sints) >> UPat.cvar("c"), lambda a,c: a.ins(X86Ops.SARi, src=(a, imm(dtypes.uint8, c.arg)))),
|
||||
(UPat.var("a", dtypes.ints) + UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ADDi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
(UPat.var("a", dtypes.ints) * UPat.cvar("c"), lambda a,c: a.ins(X86Ops.IMULi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ANDi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.cvar("c"), lambda a,c: a.ins(X86Ops.ORi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.cvar("c"), lambda a,c: a.ins(X86Ops.XORi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.cvar("c"))), lambda a,c: a.ins(X86Ops.SUBi, src=(a, i)) if (i:=to_imm(c)) is not None else None),
|
||||
# scalar int binary with register
|
||||
(UPat.var("a", dtypes.ints) << UPat.var("b"), lambda a,b: a.ins(X86Ops.SHL, src=(a, b))),
|
||||
(UPat.var("a", dtypes.uints) >> UPat.var("b"), lambda a,b: a.ins(X86Ops.SHR, src=(a, b))),
|
||||
(UPat.var("a", dtypes.sints) >> UPat.var("b"), lambda a,b: a.ins(X86Ops.SAR, src=(a, b))),
|
||||
(UPat.var("a", dtypes.ints) + UPat.var("b"), lambda a,b: a.ins(X86Ops.ADD, src=(a, b))),
|
||||
(UPat.var("a", dtypes.ints) * UPat.var("b"), lambda a,b: a.ins(X86Ops.IMUL, src=(a, b))),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b"), lambda a,b: a.ins(X86Ops.AND, src=(a, b))),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b"), lambda a,b: a.ins(X86Ops.OR, src=(a, b))),
|
||||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b"), lambda a,b: a.ins(X86Ops.XOR, src=(a, b))),
|
||||
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b"))), lambda a,b: a.ins(X86Ops.SUB, src=(a, b))),
|
||||
# float binary
|
||||
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
|
||||
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
|
||||
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
|
||||
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
|
||||
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
|
||||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
|
||||
# casts
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
|
||||
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSD2SI)),
|
||||
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSS2SD, src=(y, y))),
|
||||
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
|
||||
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW)),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD)),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ)),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWD)),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWQ)),
|
||||
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXDQ)),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBW)),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBD)),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBQ)),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWD)),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWQ)),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXDQ)),
|
||||
# bitcasts
|
||||
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, 0)))),
|
||||
(UPat(dtype=dtypes.int16s).bitcast(dtypes.float16).named("x"), vpins),
|
||||
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.ins(X86Ops.VMOVD)),
|
||||
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
|
||||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
|
||||
# index
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
|
||||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# copy, load, store
|
||||
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
|
||||
(UPat(Ops.COPY, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
|
||||
(UPat(Ops.COPY, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
|
||||
(UPat(Ops.COPY, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
|
||||
(UPat(Ops.COPY, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
|
||||
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS, src=fold_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD, src=fold_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS, src=fold_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x:
|
||||
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(x.src[0]) + (imm(dtypes.uint8, 0),))),
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV, src=fold_address(x.src[0]))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_128bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_64bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_32bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_16bit), name="x"), lambda a,b,x: x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))),
|
||||
(UPat.var("a").store(UPat.var("b", dtypes.ints+(dtypes.bool,)), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.MOVm, src=fold_address(a) + (b,)) if (i:=to_imm(b)) is None else x.ins(X86Ops.MOVi, src=fold_address(a) + (i,))),
|
||||
# **** X86Op -> X86Op ****
|
||||
# fold loads into X86Ops that allow it, if beneficial
|
||||
(UPat(Ops.INS, src=(UPat(Ops.LOAD, name="y"),), allow_any_len=True, name="x"), lambda ctx,y,x:
|
||||
x.replace(src=fold_address(y.src[0]) + x.src[1:]) if x.arg in X86GroupOp.ReadMem1st and is_foldable_load(ctx, x, y) else None),
|
||||
(UPat(Ops.INS, src=(UPat(), UPat(Ops.LOAD, name="y")), allow_any_len=True, name="x"), lambda ctx,y,x:
|
||||
x.replace(src=x.src[:1] + fold_address(y.src[0]) + x.src[2:]) if x.arg in X86GroupOp.ReadMem2nd and is_foldable_load(ctx, x, y) else None),
|
||||
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, name="y")), allow_any_len=True, name="x"), lambda ctx,y,x:
|
||||
x.replace(src=x.src[:2] + fold_address(y.src[0]) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable_load(ctx, x, y) else None),
|
||||
# allocate virtual registers
|
||||
(UPat(Ops.INS, name="x"), alloc_vregs),
|
||||
])
|
||||
|
||||
# ***** pre register allocation *****
|
||||
# this handles flag clobbers. Unfortunately x86 doesn't have a good way to store/restore the flag register (then regalloc would handle it)
|
||||
# so we rematerialize. This is different from rematerialization you might want to do in regalloc because it is not optional,
|
||||
# regalloc shouldn't rematerialize if a src of the instruction is dead, but here you need to as there's no fallback load from stack
|
||||
def flag_rematerialize(ctx:PreRegAllocContext, x:UOp):
|
||||
flag_def = x if x.arg in X86GroupOp.WriteFlags or x.op is Ops.RANGE else x.src[-1] if x.arg in X86GroupOp.ReadFlags else None
|
||||
if flag_def is None: return None
|
||||
if ctx.lock is not None and ctx.lock is not flag_def: ctx.clobbered.add(ctx.lock)
|
||||
ctx.lock = flag_def
|
||||
if flag_def not in ctx.clobbered: return None
|
||||
ctx.clobbered.remove(flag_def)
|
||||
return (x, [flag_def, x])
|
||||
|
||||
pre_regalloc_matcher = PatternMatcher([
|
||||
(UPat((Ops.INS, Ops.RANGE), name="x"), flag_rematerialize),
|
||||
])
|
||||
|
||||
late_regalloc_matcher = PatternMatcher([
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lower_abi),
|
||||
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), lower_stack_define),
|
||||
])
|
||||
|
||||
# ***** post register allocation *****
|
||||
# TODO: control flow should be overhauled so that this isn't necessary
|
||||
def lower_range(ctx, x:UOp) -> tuple[UOp, list[UOp]]:
|
||||
loop_label = "_".join(str(i) for i in x.arg[:-1])
|
||||
acc = x.ins(X86Ops.MOVi, src=(imm(x.dtype, 0),) + x.src[1:])
|
||||
label = UOp(Ops.INS, arg=X86Ops.LABEL, tag=f".LOOP_{loop_label}")
|
||||
cmp = UOp(Ops.INS, arg=X86Ops.CMPi if x.src[0].op is Ops.CONST else X86Ops.CMP, src=(acc, x.src[0]))
|
||||
jump_out = UOp(Ops.INS, arg=X86Ops.JGE, src=(cmp,), tag=f".LOOP_OUT_{loop_label}")
|
||||
ctx.loop_label[acc] = loop_label
|
||||
return (acc, [acc, label, cmp, jump_out])
|
||||
|
||||
# final rewrite to match the isa spec
|
||||
post_regalloc_matcher = PatternMatcher([
|
||||
# alloc stack space
|
||||
(UPat(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dtypes.uint64, name="x"), lambda ctx,x:
|
||||
(x, [x, x.ins(X86Ops.SUBi, src=(imm(dtypes.uint32, ctx.stack_size),), tag=(RSP,))]) if ctx.stack_size > 0 and x.reg is RSP else None),
|
||||
# dealloc stack space
|
||||
(UPat(Ops.INS, arg=X86Ops.RET, name="x"), lambda ctx,x: (x, [UOp(Ops.INS, arg=X86Ops.ADDi, dtype=dtypes.uint64,
|
||||
src=(imm(dtypes.uint32, ctx.stack_size),), tag=(RSP,)), x]) if ctx.stack_size > 0 else None),
|
||||
# rewrite FRAME_INDEX to CONST now that the stack size is known
|
||||
(UPat(Ops.INS, arg=X86Ops.FRAME_INDEX, name="x"), lambda ctx,x: (nx:=UOp.const(x.dtype, ctx.stack_size + x.tag), [nx])),
|
||||
# rewrite RANGE to ACC = 0 -> LABEL -> JUMP if ACC >= loop bound
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: lower_range(ctx, x)),
|
||||
# rewrite END to ACC + 1 -> JUMP -> LABEL, also add the out of loop JUMP to the src so this becomes the jump target
|
||||
(UPat(Ops.END, name="x"), lambda ctx,x: (jmp:=UOp(Ops.INS, arg=X86Ops.JMP, tag=f".LOOP_{ctx.loop_label[x.src[1]]}"),
|
||||
[x.src[1].ins(X86Ops.ADDi, src=(imm(x.src[1].dtype, 1),)), jmp, UOp(Ops.INS, arg=X86Ops.LABEL, tag=f".LOOP_OUT_{ctx.loop_label[x.src[1]]}")])),
|
||||
# rewrite two address instructions to two address form, if reused src wasn't coalesced insert a move
|
||||
(UPat(Ops.INS, name="x"), lambda ctx,x: (nx:=x.replace(src=x.src[1:]),
|
||||
[ctx.ren.copy(x.src[0], x.reg), nx] if x.reg != x.src[0].reg else [nx]) if x.arg in X86GroupOp.TwoAddress else None),
|
||||
])
|
||||
|
||||
# ***** X86 spec *****
|
||||
# TODO: do we even want this?
|
||||
isa_spec = PatternMatcher([
|
||||
# these are the only non X86Ops allowed
|
||||
(UPat(Ops.CONST), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER, Ops.SINK)), lambda: True),
|
||||
(UPat(Ops.INS, name="x"), lambda x: x.arg in X86GroupOp.All),
|
||||
])
|
||||
|
||||
# ***** X86 instruction encoding *****
|
||||
|
||||
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
|
||||
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
|
||||
nonlocal reg, opc
|
||||
# get the encoding values of the different fields
|
||||
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
|
||||
rm = cast(Register, rm_uop.reg).index
|
||||
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
|
||||
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
|
||||
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
|
||||
sz = reg_sz or rm_sz
|
||||
|
||||
# encode instruction
|
||||
inst = bytes([])
|
||||
assert 0 <= reg <= 15 and 0 <= idx <= 15 and 0 <= rm <= 15
|
||||
# r extends reg field, x extends index field, b extends rm or base field
|
||||
r, _x, b = reg >> 3, idx >> 3, rm >> 3
|
||||
if sel: # VEX bytes
|
||||
vvvv = cast(Register, vvvv_uop.reg).index if vvvv_uop is not None else 0
|
||||
l = (max(reg_sz, rm_sz) > 16) & 0b1
|
||||
if sel == 1 and _x == b == we == 0: inst += bytes([0xC5, (~r & 0b1) << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp])
|
||||
else: inst += bytes([0xC4, (~r & 0b1) << 7 | (~_x & 0b1) << 6 | (~b & 0b1) << 5 | sel, we << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp])
|
||||
else: # optional PREFIX and REX bytes
|
||||
# PREFIX byte signaling 16 bit variant of instruction
|
||||
if sz == 2: inst += bytes([0x66])
|
||||
# bit signaling 64 bit variant of instruction
|
||||
w = sz == 8
|
||||
# REX byte is required when 64 bit or an extended reg is used (index 8 - 15) or lower 8 bits of (rsp, rbp, rsi, rdi) are accessed
|
||||
if w | r | _x | b | (reg_sz == 1 & reg >> 2) | (rm_sz == 1 & rm >> 2): inst += bytes([0b0100 << 4 | w << 3 | r << 2 | _x << 1 | b])
|
||||
# legacy 8bit opcode is 1 less than 16-64bit variants
|
||||
if (rm_sz == 1 or reg_sz == 1) and x.arg not in X86GroupOp.ReadFlags | {X86Ops.LEA}: opc -= 1
|
||||
# OPCODE byte
|
||||
inst += opc.to_bytes((opc.bit_length() + 7) // 8, 'big')
|
||||
# MODRM byte
|
||||
# now we only care about the lower 3 bits
|
||||
idx, rm, reg = idx & 0b111, rm & 0b111, reg & 0b111
|
||||
# 0b00 -- signals memory access with no displacement
|
||||
# 0b01 -- signals memory access with 8bit displacement
|
||||
# 0b10 -- signals memory access with 32bit displacement
|
||||
# 0b11 -- signals no memory access
|
||||
if disp_uop is not None:
|
||||
assert disp_uop.dtype in (dtypes.int8, dtypes.int32), "displacement can only be 1 or 4 byte signed int"
|
||||
# rbp/r13 always require a displacement
|
||||
if disp_uop.arg != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
|
||||
else: mod = 0b00
|
||||
else: mod = 0b11
|
||||
# x 0b0 and idx 0b100 means rsp which means no index exists
|
||||
# rm 0b100 (rsp/r12) signals a sib byte is required, rm then is encoded in the base field of SIB
|
||||
_rm = rm if idx == 0b100 and _x == 0b0 else 0b100
|
||||
inst += bytes([mod << 6 | reg << 3 | _rm])
|
||||
# SIB byte
|
||||
if _rm == 0b100 and mod != 0b11:
|
||||
scale = {1: 0b00, 2: 0b01, 4: 0b10, 8: 0b11}[1 if idx == 0b100 and _x == 0b0 else rm_sz]
|
||||
inst += bytes([scale << 6 | idx << 3 | rm])
|
||||
# DISP byte
|
||||
if mod == 0b01 or mod == 0b10:
|
||||
assert disp_uop is not None
|
||||
inst += struct.pack(unwrap(disp_uop.dtype.fmt), disp_uop.arg)
|
||||
# IMM byte
|
||||
if imm_uop is not None:
|
||||
if imm_uop.op is Ops.CONST: inst += struct.pack(unwrap(imm_uop.dtype.fmt), imm_uop.arg)
|
||||
elif isinstance(imm_uop.reg, Register): inst += bytes([(imm_uop.reg.index & 0b1111) << 4 | 0b0000])
|
||||
return inst
|
||||
|
||||
# get the encoding structure of the uop
|
||||
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
|
||||
address:tuple[UOp|None, ...]
|
||||
if x.arg in X86GroupOp.WriteMem:
|
||||
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x, None, None), x.src
|
||||
return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
|
||||
|
||||
if x.arg in X86GroupOp.Rm1st:
|
||||
if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x.src[0], None, None), x.src[1:]
|
||||
imm_uop = rest[:1] if rest and (rest[0].op is Ops.CONST or isinstance(rest[0].reg, Register)) else (None,)
|
||||
return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
|
||||
|
||||
if x.arg in X86GroupOp.Rm2nd:
|
||||
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
|
||||
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
|
||||
# cmp/vucomiss reg, rm don't define a new register
|
||||
return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
|
||||
|
||||
return None
|
||||
|
||||
# https://www.felixcloutier.com/x86/
|
||||
# NOTE: LEGACY prefix == VEX prefix
|
||||
# pp field: None == 0, 66 == 1, F3 == 2, F2 == 3
|
||||
# map select: 0F == 1, 0F38 == 2, 0F3A == 3
|
||||
encodings = {
|
||||
# moves
|
||||
X86Ops.MOVABS: lambda x:
|
||||
bytes([0b0100 << 4 | 0b1 << 3 | 0b00 << 2 | x.tag[0].index >> 3, 0xB8 + (x.tag[0].index & 0b111)]) + struct.pack(x.dtype.fmt, x.src[0].arg),
|
||||
X86Ops.MOV: lambda x: encode(x, 0x8B), X86Ops.MOVi: lambda x: encode(x, 0xC7, reg=0),
|
||||
X86Ops.MOVm: lambda x: encode(x, 0x89), X86Ops.LEA: lambda x: encode(x, 0x8D),
|
||||
X86Ops.VMOVSS: lambda x: encode(x, 0x10, pp=2, sel=1), X86Ops.VMOVSSm: lambda x: encode(x, 0x11, pp=2, sel=1),
|
||||
X86Ops.VMOVSD: lambda x: encode(x, 0x10, pp=3, sel=1), X86Ops.VMOVSDm: lambda x: encode(x, 0x11, pp=3, sel=1),
|
||||
X86Ops.VMOVUPS: lambda x: encode(x, 0x10, pp=0, sel=1), X86Ops.VMOVUPSm: lambda x: encode(x, 0x11, pp=0, sel=1),
|
||||
X86Ops.VMOVD: lambda x: encode(x, 0x6E, pp=1, sel=1), X86Ops.VMOVQ: lambda x: encode(x, 0x6E, pp=1, sel=1, we=1),
|
||||
X86Ops.VMOVDm: lambda x: encode(x, 0x7E, pp=1, sel=1), X86Ops.VMOVQm: lambda x: encode(x, 0x7E, pp=1, sel=1, we=1),
|
||||
# casts
|
||||
X86Ops.MOVZX: lambda x: encode(x, 0x0FB7),
|
||||
X86Ops.MOVSX: lambda x: encode(x, 0x0FBF), X86Ops.MOVSXD: lambda x: encode(x, 0x63),
|
||||
X86Ops.VPMOVZXBW: lambda x: encode(x, 0x30, pp=1, sel=2), X86Ops.VPMOVZXBD: lambda x: encode(x, 0x31, pp=1, sel=2),
|
||||
X86Ops.VPMOVZXBQ: lambda x: encode(x, 0x32, pp=1, sel=2), X86Ops.VPMOVZXWD: lambda x: encode(x, 0x33, pp=1, sel=2),
|
||||
X86Ops.VPMOVZXWQ: lambda x: encode(x, 0x34, pp=1, sel=2), X86Ops.VPMOVZXDQ: lambda x: encode(x, 0x35, pp=1, sel=2),
|
||||
X86Ops.VPMOVSXBW: lambda x: encode(x, 0x20, pp=1, sel=2), X86Ops.VPMOVSXBD: lambda x: encode(x, 0x21, pp=1, sel=2),
|
||||
X86Ops.VPMOVSXBQ: lambda x: encode(x, 0x22, pp=1, sel=2), X86Ops.VPMOVSXWD: lambda x: encode(x, 0x23, pp=1, sel=2),
|
||||
X86Ops.VPMOVSXWQ: lambda x: encode(x, 0x24, pp=1, sel=2), X86Ops.VPMOVSXDQ: lambda x: encode(x, 0x25, pp=1, sel=2),
|
||||
X86Ops.VCVTSS2SD: lambda x: encode(x, 0x5A, pp=2, sel=1), X86Ops.VCVTSD2SS: lambda x: encode(x, 0x5A, pp=3, sel=1),
|
||||
X86Ops.VCVTPH2PS: lambda x: encode(x, 0x13, pp=1, sel=2), X86Ops.VCVTPS2PH: lambda x: encode(x, 0x1D, pp=1, sel=3),
|
||||
X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
|
||||
X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
|
||||
X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
|
||||
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
|
||||
X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
|
||||
# int division
|
||||
X86Ops.IDIV: lambda x: encode(x, 0xF7, reg=7), X86Ops.DIV: lambda x: encode(x, 0xF7, reg=6),
|
||||
# scalar int binary
|
||||
X86Ops.SHLi: lambda x: encode(x, 0xC1, reg=4),
|
||||
X86Ops.SHRi: lambda x: encode(x, 0xC1, reg=5), X86Ops.SARi: lambda x: encode(x, 0xC1, reg=7),
|
||||
X86Ops.ADD: lambda x: encode(x, 0x03), X86Ops.ADDi: lambda x: encode(x, 0x81, reg=0),
|
||||
X86Ops.SUB: lambda x: encode(x, 0x2B), X86Ops.SUBi: lambda x: encode(x, 0x81, reg=5),
|
||||
X86Ops.AND: lambda x: encode(x, 0x23), X86Ops.ANDi: lambda x: encode(x, 0x81, reg=4),
|
||||
X86Ops.XOR: lambda x: encode(x, 0x33), X86Ops.XORi: lambda x: encode(x, 0x81, reg=6),
|
||||
X86Ops.OR: lambda x: encode(x, 0x0B), X86Ops.ORi: lambda x: encode(x, 0x81, reg=1),
|
||||
X86Ops.CMP: lambda x: encode(x, 0x3B), X86Ops.CMPi: lambda x: encode(x, 0x81, reg=7),
|
||||
X86Ops.IMUL: lambda x: encode(x, 0x0FAF), X86Ops.IMULi: lambda x: encode(x, 0x69),
|
||||
X86Ops.SETB: lambda x: encode(x, 0x0F92, reg=0), X86Ops.SETL: lambda x: encode(x, 0x0F9C, reg=0),
|
||||
X86Ops.SETE: lambda x: encode(x, 0x0F94, reg=0), X86Ops.SETNE: lambda x: encode(x, 0x0F95, reg=0),
|
||||
# packed bitwise NOTE: only bitwise and packed
|
||||
X86Ops.VPAND: lambda x: encode(x, 0xDB, pp=1, sel=1), X86Ops.VPXOR: lambda x: encode(x, 0xEF, pp=1, sel=1),
|
||||
X86Ops.VPOR: lambda x: encode(x, 0xEB, pp=1, sel=1),
|
||||
# unary
|
||||
X86Ops.VSQRTSS: lambda x: encode(x, 0x51, pp=2, sel=1), X86Ops.VSQRTPS: lambda x: encode(x, 0x51, pp=0, sel=1),
|
||||
X86Ops.VSQRTSD: lambda x: encode(x, 0x51, pp=3, sel=1), X86Ops.VSQRTPD: lambda x: encode(x, 0x51, pp=1, sel=1),
|
||||
X86Ops.VROUNDSS: lambda x: encode(x, 0x0A, pp=1, sel=3), X86Ops.VROUNDPS: lambda x: encode(x, 0x08, pp=1, sel=3),
|
||||
X86Ops.VROUNDSD: lambda x: encode(x, 0x0B, pp=1, sel=3), X86Ops.VROUNDPD: lambda x: encode(x, 0x09, pp=1, sel=3),
|
||||
# packed int binary
|
||||
X86Ops.VPSLLVD: lambda x: encode(x, 0x47, pp=1, sel=2), X86Ops.VPSLLVQ: lambda x: encode(x, 0x47, pp=1, sel=2, we=1),
|
||||
X86Ops.VPSRLVD: lambda x: encode(x, 0x45, pp=1, sel=2), X86Ops.VPSRLVQ: lambda x: encode(x, 0x45, pp=1, sel=2, we=1),
|
||||
X86Ops.VPCMPGTB: lambda x: encode(x, 0x64, pp=1, sel=1), X86Ops.VPCMPGTW: lambda x: encode(x, 0x65, pp=1, sel=1),
|
||||
X86Ops.VPCMPGTD: lambda x: encode(x, 0x66, pp=1, sel=1), X86Ops.VPCMPGTQ: lambda x: encode(x, 0x37, pp=1, sel=2),
|
||||
X86Ops.VPCMPEQB: lambda x: encode(x, 0x74, pp=1, sel=1), X86Ops.VPCMPEQW: lambda x: encode(x, 0x75, pp=1, sel=1),
|
||||
X86Ops.VPCMPEQD: lambda x: encode(x, 0x76, pp=1, sel=1), X86Ops.VPCMPEQQ: lambda x: encode(x, 0x29, pp=1, sel=2),
|
||||
X86Ops.VPMULLW: lambda x: encode(x, 0xD5, pp=1, sel=1), X86Ops.VPMULLD: lambda x: encode(x, 0x40, pp=1, sel=2),
|
||||
X86Ops.VPADDB: lambda x: encode(x, 0xFC, pp=1, sel=1), X86Ops.VPADDW: lambda x: encode(x, 0xFD, pp=1, sel=1),
|
||||
X86Ops.VPADDD: lambda x: encode(x, 0xFE, pp=1, sel=1), X86Ops.VPADDQ: lambda x: encode(x, 0xD4, pp=1, sel=1),
|
||||
X86Ops.VPSUBB: lambda x: encode(x, 0xF8, pp=1, sel=1), X86Ops.VPSUBW: lambda x: encode(x, 0xF9, pp=1, sel=1),
|
||||
X86Ops.VPSUBD: lambda x: encode(x, 0xFA, pp=1, sel=1), X86Ops.VPSUBQ: lambda x: encode(x, 0xFB, pp=1, sel=1),
|
||||
X86Ops.VPSRAVD: lambda x: encode(x, 0x46, pp=1, sel=2),
|
||||
# float cmp
|
||||
X86Ops.VUCOMISS: lambda x: encode(x, 0x2E, pp=0, sel=1), X86Ops.VUCOMISD: lambda x: encode(x, 0x2E, pp=1, sel=1),
|
||||
# scalar / packed float binary
|
||||
X86Ops.VADDSS: lambda x: encode(x, 0x58, pp=2, sel=1), X86Ops.VADDPS: lambda x: encode(x, 0x58, pp=0, sel=1),
|
||||
X86Ops.VADDSD: lambda x: encode(x, 0x58, pp=3, sel=1), X86Ops.VADDPD: lambda x: encode(x, 0x58, pp=1, sel=1),
|
||||
X86Ops.VSUBSS: lambda x: encode(x, 0x5C, pp=2, sel=1), X86Ops.VSUBPS: lambda x: encode(x, 0x5C, pp=0, sel=1),
|
||||
X86Ops.VSUBSD: lambda x: encode(x, 0x5C, pp=3, sel=1), X86Ops.VSUBPD: lambda x: encode(x, 0x5C, pp=1, sel=1),
|
||||
X86Ops.VMULSS: lambda x: encode(x, 0x59, pp=2, sel=1), X86Ops.VMULPS: lambda x: encode(x, 0x59, pp=0, sel=1),
|
||||
X86Ops.VMULSD: lambda x: encode(x, 0x59, pp=3, sel=1), X86Ops.VMULPD: lambda x: encode(x, 0x59, pp=1, sel=1),
|
||||
X86Ops.VDIVSS: lambda x: encode(x, 0x5E, pp=2, sel=1), X86Ops.VDIVPS: lambda x: encode(x, 0x5E, pp=0, sel=1),
|
||||
X86Ops.VDIVSD: lambda x: encode(x, 0x5E, pp=3, sel=1), X86Ops.VDIVPD: lambda x: encode(x, 0x5E, pp=1, sel=1),
|
||||
X86Ops.VCMPSS: lambda x: encode(x, 0xC2, pp=2, sel=1), X86Ops.VCMPPS: lambda x: encode(x, 0xC2, pp=0, sel=1),
|
||||
X86Ops.VCMPSD: lambda x: encode(x, 0xC2, pp=3, sel=1), X86Ops.VCMPPD: lambda x: encode(x, 0xC2, pp=1, sel=1),
|
||||
X86Ops.VMAXSS: lambda x: encode(x, 0x5F, pp=2, sel=1), X86Ops.VMAXPS: lambda x: encode(x, 0x5F, pp=0, sel=1),
|
||||
X86Ops.VMAXSD: lambda x: encode(x, 0x5F, pp=3, sel=1), X86Ops.VMAXPD: lambda x: encode(x, 0x5F, pp=1, sel=1),
|
||||
X86Ops.VMINSS: lambda x: encode(x, 0x5D, pp=2, sel=1), X86Ops.VMINPS: lambda x: encode(x, 0x5D, pp=0, sel=1),
|
||||
X86Ops.VMINSD: lambda x: encode(x, 0x5D, pp=3, sel=1), X86Ops.VMINPD: lambda x: encode(x, 0x5D, pp=1, sel=1),
|
||||
# ternary
|
||||
X86Ops.CMOVB: lambda x: encode(x, 0x0F42), X86Ops.CMOVL: lambda x: encode(x, 0x0F4C),
|
||||
X86Ops.CMOVE: lambda x: encode(x, 0x0F44), X86Ops.CMOVNE: lambda x: encode(x, 0x0F45),
|
||||
X86Ops.VFMADD213SS: lambda x: encode(x, 0xA9, pp=1, sel=2), X86Ops.VFMADD213SD: lambda x: encode(x, 0xA9, pp=1, sel=2, we=1),
|
||||
X86Ops.VFMADD213PS: lambda x: encode(x, 0xA8, pp=1, sel=2), X86Ops.VFMADD213PD: lambda x: encode(x, 0xA8, pp=1, sel=2, we=1),
|
||||
X86Ops.VBLENDVPS: lambda x: encode(x, 0x4A, pp=1, sel=3), X86Ops.VBLENDVPD: lambda x: encode(x, 0x4B, pp=1, sel=3),
|
||||
X86Ops.VPBLENDVB: lambda x: encode(x, 0x4C, pp=1, sel=3),
|
||||
# shuffles
|
||||
X86Ops.VPBROADCASTB: lambda x: encode(x, 0x78, pp=1, sel=2), X86Ops.VPBROADCASTW: lambda x: encode(x, 0x79, pp=1, sel=2),
|
||||
X86Ops.VPBROADCASTD: lambda x: encode(x, 0x58, pp=1, sel=2), X86Ops.VPBROADCASTQ: lambda x: encode(x, 0x59, pp=1, sel=2),
|
||||
X86Ops.VBROADCASTSS: lambda x: encode(x, 0x18, pp=1, sel=2), X86Ops.VPSRLDQ: lambda x: encode(x, 0x73, reg=3, pp=1, sel=1),
|
||||
X86Ops.VPINSRB: lambda x: encode(x, 0x20, pp=1, sel=3), X86Ops.VPINSRW: lambda x: encode(x, 0xC4, pp=1, sel=1),
|
||||
X86Ops.VPINSRD: lambda x: encode(x, 0x22, pp=1, sel=3), X86Ops.VPINSRQ: lambda x: encode(x, 0x22, pp=1, sel=3, we=1),
|
||||
X86Ops.VSHUFPS: lambda x: encode(x, 0xC6, pp=0, sel=1), X86Ops.VSHUFPD: lambda x: encode(x, 0xC6, pp=1, sel=1),
|
||||
X86Ops.VINSERTPS: lambda x: encode(x, 0x21, pp=1, sel=3),
|
||||
# extract
|
||||
X86Ops.VPEXTRB: lambda x: encode(x, 0x14, pp=1, sel=3), X86Ops.VPEXTRW: lambda x: encode(x, 0x15, pp=1, sel=3),
|
||||
X86Ops.VPEXTRD: lambda x: encode(x, 0x16, pp=1, sel=3), X86Ops.VPEXTRQ: lambda x: encode(x, 0x16, pp=1, sel=3, we=1),
|
||||
# jumps are encoded with a placeholder which gets patched later once the real offset is known
|
||||
X86Ops.JE: lambda x: bytes([0x0F, 0x84]) + int(0).to_bytes(4),
|
||||
X86Ops.JNE: lambda x: bytes([0x0F, 0x85]) + int(0).to_bytes(4),
|
||||
X86Ops.JL: lambda x: bytes([0x0F, 0x8C]) + int(0).to_bytes(4),
|
||||
X86Ops.JB: lambda x: bytes([0x0F, 0x82]) + int(0).to_bytes(4),
|
||||
X86Ops.JGE: lambda x: bytes([0x0F, 0x8D]) + int(0).to_bytes(4),
|
||||
X86Ops.JMP: lambda x: bytes([0xE9]) + int(0).to_bytes(4),
|
||||
X86Ops.RET: lambda x: bytes([0xC3]),
|
||||
}
|
||||
|
||||
class X86Renderer(ISARenderer):
|
||||
device = "CPU"
|
||||
has_local = False
|
||||
has_threads = bool(getenv("THREADS", 1))
|
||||
global_max = (CPU_COUNT.value, 0, 0)
|
||||
extra_matcher = extra_matcher
|
||||
pre_isel_matcher = pre_isel_matcher
|
||||
isel_matcher = isel_matcher
|
||||
pre_regalloc_matcher = pre_regalloc_matcher
|
||||
late_regalloc_matcher = late_regalloc_matcher
|
||||
post_regalloc_matcher = post_regalloc_matcher
|
||||
isa_spec = isa_spec
|
||||
code_for_op = {x: lambda: None for x in (Ops.SQRT, Ops.AND, Ops.OR, Ops.SHL, Ops.SHR, Ops.NEG, Ops.SUB, Ops.FDIV, Ops.CMPLT, Ops.CMPEQ)}
|
||||
def __init__(self, target:Target):
|
||||
super().__init__(target)
|
||||
from tinygrad.runtime.support.compiler_cpu import X86Compiler
|
||||
self.compiler = X86Compiler()
|
||||
def callee_saved(self):
|
||||
ordered = (RSP,) + tuple(r for r in CALLEE_SAVED if r is not RSP)
|
||||
return tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in ordered)
|
||||
def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress
|
||||
# nasty hacks to deal with pointers TODO: rm pointers
|
||||
def copy(self, x:UOp, reg:Register):
|
||||
dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp:
|
||||
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
|
||||
ret = isel_matcher.rewrite(def_reg(dtypes.uint64, RSP).index(disp).store(nx))
|
||||
assert ret is not None
|
||||
return ret.replace(src=(s if s is not nx else x for s in ret.src))
|
||||
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
|
||||
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(def_reg(dtypes.uint64, RSP).index(disp).load(dtype=ndt, tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
|
||||
def asm(self, uops:list[UOp], function_name:str) -> str:
|
||||
def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
|
||||
def _format_operands(x:UOp) -> str:
|
||||
def _format(src:tuple[UOp, ...]) -> list[str]:
|
||||
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
|
||||
(o:=str(s.reg)) in reg_strs else o for s in src if s.op is Ops.CONST or s.reg is not None]
|
||||
def _mem_adress(base:UOp, idx:UOp, disp:UOp) -> str:
|
||||
return f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"
|
||||
|
||||
if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: return ", ".join([_mem_adress(*x.src[:3])] + _format(x.src[3:]))
|
||||
elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: return ", ".join(_format((x,)) + [_mem_adress(*x.src[:3])] + _format(x.src[3:]))
|
||||
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: return ", ".join(_format((x, x.src[0])) + [_mem_adress(*x.src[1:4])] + _format(x.src[4:]))
|
||||
return ", ".join(_format((x,) + x.src))
|
||||
|
||||
asm = [f".{function_name}:"]
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.arg is X86Ops.DEFINE_REG: continue
|
||||
if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
|
||||
elif u.arg is X86Ops.RET: asm.append(_format_op(u))
|
||||
else: asm.append(_format_op(u) + " " + _format_operands(u))
|
||||
return "\n".join(asm)
|
||||
|
||||
def render(self, uops:list[UOp]) -> str:
|
||||
targets: dict[str, int] = {}
|
||||
jumps: dict[UOp, int] = {}
|
||||
binary = bytearray()
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.arg is X86Ops.DEFINE_REG: continue
|
||||
if u.arg is X86Ops.LABEL:
|
||||
targets[u.tag] = len(binary)
|
||||
continue
|
||||
if u.arg not in encodings or (l:=encodings[u.arg](u)) is None:
|
||||
raise RuntimeError(f"failed to encode {u.arg} with {u.dtype} srcs {[x.dtype for x in u.src]}")
|
||||
binary.extend(l)
|
||||
if u.arg in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE, X86Ops.JGE, X86Ops.JMP): jumps[u] = len(binary)
|
||||
# fixup jump targets now that encoding size is known
|
||||
for u in uops:
|
||||
if (t:=jumps.get(u)) is not None: binary[t-4:t] = (targets[u.tag] - t).to_bytes(4, 'little', signed=True)
|
||||
return binary.hex()
|
||||
|
|
@ -8,6 +8,7 @@ from tinygrad.runtime.support.hcq import CLikeArgsState
|
|||
from tinygrad.renderer.cstyle import ClangJITRenderer
|
||||
from tinygrad.renderer.llvmir import CPULLVMRenderer
|
||||
from tinygrad.renderer.nir import LVPRenderer
|
||||
from tinygrad.renderer.isa.x86 import X86Renderer
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.uop.ops import sint
|
||||
|
||||
|
|
@ -136,5 +137,5 @@ class CPUDevice(HCQCompiled):
|
|||
def __init__(self, device:str=""):
|
||||
self.tasks:queue.Queue = queue.Queue()
|
||||
CPUWorker(self, self.tasks, thread_id=0).start()
|
||||
renderers:list[type[Renderer]] = [ClangJITRenderer, CPULLVMRenderer, LVPRenderer]
|
||||
renderers:list[type[Renderer]] = [ClangJITRenderer, CPULLVMRenderer, LVPRenderer, X86Renderer]
|
||||
super().__init__(device, CPUAllocator(self), renderers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import ctypes, functools, os, pathlib, re, sys, sysconfig
|
||||
from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN
|
||||
from _ctypes import Array as _CArray, _SimpleCData, _Pointer
|
||||
from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar
|
||||
|
||||
def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs):
|
||||
|
|
@ -34,22 +33,22 @@ if TYPE_CHECKING:
|
|||
from _ctypes import _CData
|
||||
class Array(Generic[T, U], _CData):
|
||||
@overload
|
||||
def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ...
|
||||
def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ...
|
||||
@overload
|
||||
def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ...
|
||||
@overload
|
||||
def __getitem__(self: Array[T, Any], key: int) -> T: ...
|
||||
def __getitem__(self, key) -> Any: ...
|
||||
@overload
|
||||
def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ...
|
||||
def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ...
|
||||
@overload
|
||||
def __setitem__(self: Array[T, Any], key: int, val: T): ...
|
||||
@overload
|
||||
def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ...
|
||||
def __setitem__(self, key, val): ...
|
||||
class POINTER(Generic[T], _Pointer): ...
|
||||
class POINTER(Generic[T], ctypes._Pointer): ...
|
||||
class CFUNCTYPE(Generic[T, P], _CFunctionType): ...
|
||||
class Enum(_SimpleCData):
|
||||
class Enum(ctypes._SimpleCData):
|
||||
@classmethod
|
||||
def get(cls, val:int, default="unknown") -> str: ...
|
||||
@classmethod
|
||||
|
|
@ -80,14 +79,9 @@ else:
|
|||
return val
|
||||
def pointer(obj): return ctypes.pointer(obj)
|
||||
|
||||
def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder)
|
||||
def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder)
|
||||
def mv(st) -> memoryview: return memoryview(st).cast('B')
|
||||
|
||||
class Struct(ctypes.Structure):
|
||||
def __init__(self, *args, **kwargs):
|
||||
ctypes.Structure.__init__(self)
|
||||
self._objects_ = {}
|
||||
for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v)
|
||||
|
||||
def record(cls) -> type[Struct]:
|
||||
|
|
@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
|
|||
def init_records() -> None:
|
||||
for cls, struct, ns in _pending_records:
|
||||
setattr(struct, '_real_fields_', [])
|
||||
for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items():
|
||||
if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__)))
|
||||
else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__))))
|
||||
struct._real_fields_.append((nm,) + f) # type: ignore
|
||||
for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()):
|
||||
struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore
|
||||
setattr(struct, nm, Field(nm, i, *f))
|
||||
_pending_records.clear()
|
||||
|
||||
class Field(property):
|
||||
def __init__(self, typ, off:int, bit_width=None, bit_off=0):
|
||||
if bit_width is not None:
|
||||
sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off)
|
||||
class Field:
|
||||
def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0):
|
||||
self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off
|
||||
|
||||
# lazily resolve field descriptors
|
||||
def _resolve(self, cls):
|
||||
if self.bit_width: # handle bitfields ourselves
|
||||
sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off)
|
||||
def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder)
|
||||
def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder)
|
||||
# FIXME: signedness
|
||||
super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask,
|
||||
lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz)))
|
||||
else:
|
||||
sl = slice(off, off + ctypes.sizeof(typ))
|
||||
def set_with_objs(f):
|
||||
def wrapper(self, v):
|
||||
if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})}
|
||||
mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v)))
|
||||
return wrapper
|
||||
if issubclass(typ, _CArray):
|
||||
getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl]))
|
||||
super().__init__(getter, set_with_objs(lambda v: typ(*v)))
|
||||
else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ))
|
||||
self.offset = off
|
||||
cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset)
|
||||
# pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
|
||||
else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] +
|
||||
[("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore
|
||||
setattr(cls, self.nm, cf)
|
||||
return cf
|
||||
|
||||
def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self
|
||||
def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value)
|
||||
|
||||
@functools.cache
|
||||
def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
|
||||
CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []})
|
||||
for nm,ty,*args in fields:
|
||||
setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args))))
|
||||
CStruct._real_fields_.append((nm,) + f) # type: ignore
|
||||
for i,(nm,ty,*args) in enumerate(fields):
|
||||
CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore
|
||||
setattr(CStruct, nm, Field(nm, i, *f))
|
||||
return CStruct
|
||||
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
|
||||
|
||||
|
|
|
|||
|
|
@ -91,3 +91,8 @@ class CPULLVMCompiler(LLVMCompiler):
|
|||
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
|
||||
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
|
||||
super().__init__(cpu.decode(), feats.decode(), cache_key)
|
||||
|
||||
class X86Compiler(Compiler):
|
||||
def __init__(self): super().__init__(None)
|
||||
def compile(self, src:str) -> bytes: return bytes.fromhex(src)
|
||||
def disassemble(self, lib:bytes): return capstone_flatdump(lib)
|
||||
|
|
@ -473,6 +473,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
def ins(self, arg, **kwargs): return UOp(Ops.INS, kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), arg, kwargs.pop("tag", self.tag))
|
||||
def contract(self, *rngs:UOp):
|
||||
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
|
||||
return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
|
||||
|
|
@ -537,6 +538,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
for s in self.src: yield from s.split_uop(sep)
|
||||
else: yield self
|
||||
|
||||
@property
|
||||
def reg(self:UOp):
|
||||
# TODO: add a way to access the nth element in src, sea of nodes call this a projection
|
||||
if self.op in (Ops.NOOP, Ops.AFTER) and self.src: return self.src[0].reg
|
||||
if isinstance(self.tag, tuple): return self.tag[0]
|
||||
return self.tag
|
||||
|
||||
# *** multi-device helpers ***
|
||||
|
||||
def multi(self, axis:int|None):
|
||||
|
|
@ -1010,10 +1018,13 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
|||
# ***** uop helpers *****
|
||||
|
||||
def print_uops(uops:list[UOp]):
|
||||
def format_tag(u:UOp) -> str: return "" if u.tag is None else str(u.tag)
|
||||
uops_index = {u:i for i,u in enumerate(uops)}
|
||||
tag_width = max((len(format_tag(u)) for u in uops), default=0)
|
||||
for i,u in enumerate(uops):
|
||||
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} "
|
||||
f"{str(formatted_srcs):32s} {format_tag(u):{tag_width}s} {u.arg}")
|
||||
|
||||
# ***** pattern matcher *****
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
|||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.RANGE: "#76349c", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
|
|
@ -104,7 +104,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
if u.op in {Ops.VECTORIZE, Ops.NOOP} and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
for u in toposort:
|
||||
|
|
@ -288,18 +288,22 @@ metrics:dict[str, Callable[[dict[str, tuple[int, int, int]]], str]] = {
|
|||
|
||||
def unpack_pmc(e) -> dict:
|
||||
agg_cols = ["Name", "Sum"]
|
||||
sample_cols = ["XCC", "INST", "SE", "SA", "WGP", "Value"]
|
||||
rows:list[list] = []
|
||||
stats:dict[str, tuple[int, int, int]] = {} # name -> (sum, max, count)
|
||||
view, ptr = memoryview(e.blob).cast('Q'), 0
|
||||
for s in e.sched:
|
||||
sample_cols = ["XCC", "INST", "SE", "SA"] + [f"WGP:{i}" for i in range(s.wgp)]
|
||||
row:list = [s.name, 0, {"cols":sample_cols, "rows":[]}]
|
||||
max_val, cnt = 0, 0
|
||||
for sample in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
|
||||
row[1] += (val:=int(view[ptr]))
|
||||
max_val, cnt = max(max_val, val), cnt + 1
|
||||
row[2]["rows"].append(sample+(val,))
|
||||
ptr += 1
|
||||
for sample in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa)):
|
||||
vals:list[int] = []
|
||||
# pack work group processors on the same se
|
||||
for _ in range(s.wgp):
|
||||
row[1] += (val:=int(view[ptr]))
|
||||
max_val, cnt = max(max_val, val), cnt + 1
|
||||
vals.append(val)
|
||||
ptr += 1
|
||||
row[2]["rows"].append(sample+tuple(vals))
|
||||
stats[s.name] = (row[1], max_val, cnt)
|
||||
rows.append(row)
|
||||
for name, fn in metrics.items():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue