mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
x86 cleanups (fable) [pr] (#16591)
* x86 cleanups (fable)
* support shrink
* remove ptr dtype
* move that
* is_lane helper
* Revert "is_lane helper"
This reverts commit ea4571254d.
This commit is contained in:
parent
97c2e7a3d9
commit
b05bea81ce
6 changed files with 170 additions and 166 deletions
|
|
@ -14,49 +14,51 @@ class TestEncodingsX86(unittest.TestCase):
|
|||
|
||||
# displacement of 0 isn't emitted
|
||||
def test_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F"))
|
||||
|
||||
# rsp/r12 require a sib byte when used as base memory address
|
||||
def test_rsp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RSP)
|
||||
# mov esp, dword ptr [rsp]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24"))
|
||||
|
||||
# rbp/r13 require a displacement when used as base memory address
|
||||
def test_rbp_base_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RBP)
|
||||
# mov ebp, dword ptr [rbp + 0]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00"))
|
||||
|
||||
# test [base + index*scale]
|
||||
def test_base_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax + rdx*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90"))
|
||||
|
||||
# rsp as index means no index
|
||||
def test_rsp_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00"))
|
||||
|
||||
# however r12 is a valid index
|
||||
def test_r12_index_address(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX)
|
||||
load = ins(X86Ops.MOV, dtypes.int32,
|
||||
(def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
|
||||
# mov eax, dword ptr [rax + r12*4]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0"))
|
||||
|
||||
# test [base + index*scale + 8bit disp]
|
||||
def test_complex_address_8bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A"))
|
||||
|
||||
# test [base + index*scale + 32bit disp]
|
||||
def test_complex_address_32bit_disp(self):
|
||||
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI)
|
||||
load = ins(X86Ops.MOV, dtypes.int32,
|
||||
(def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000), imm(dtypes.uint8, 4)), RDI)
|
||||
# mov edi, dword ptr [rdi + rsi*4 + 0x2710]
|
||||
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00"))
|
||||
|
||||
|
|
@ -114,28 +116,28 @@ class TestEncodingsX86(unittest.TestCase):
|
|||
|
||||
# when writing to mem the uop takes the store form where dtype is void and there's no definition
|
||||
def test_write_mem(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
|
||||
xmm0 = def_reg(dtypes.float32, XMM[0])
|
||||
extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
|
||||
extr = ins(X86Ops.VPEXTRD, dtypes.void, address + (xmm0, imm(dtypes.uint8, 0)))
|
||||
# vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
|
||||
self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00"))
|
||||
|
||||
# test two address instruction with fused load works
|
||||
def test_two_address_load(self):
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
|
||||
cmove = ins(X86Ops.CMOVE, dtypes.int32, address, RAX)
|
||||
# cmove eax, dword ptr [rdi + rsi*4 + 0xa]
|
||||
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A"))
|
||||
|
||||
# test instruction where displacement and imm have the same value
|
||||
def test_disp_imm_same_value(self):
|
||||
base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
|
||||
mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 1))
|
||||
mov = ins(X86Ops.MOVi, dtypes.void, address + (imm(dtypes.int8, 10),))
|
||||
# mov byte ptr [rdi + rsi + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A"))
|
||||
|
||||
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
|
||||
imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI)
|
||||
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10), imm(dtypes.uint8, 4))
|
||||
imul = ins(X86Ops.IMULi, dtypes.int32, address + (imm(dtypes.int32, 10),), RDI)
|
||||
# imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
|
||||
self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00"))
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
|
|||
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
|
||||
from tinygrad.renderer.isa import IselContext
|
||||
|
||||
# INDEX on a register value with a constant index extracts a single element (the old GEP)
|
||||
def lane(y:UOp, i:int) -> UOp: return y.index(UOp.const(dtypes.int, i), dtype=y.dtype.scalar())
|
||||
|
||||
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
|
||||
class TestIselX86(unittest.TestCase):
|
||||
def isel_rewrite(self, x:UOp):
|
||||
|
|
@ -57,9 +60,9 @@ class TestIselX86(unittest.TestCase):
|
|||
# need to move src from gpr to xmm before broadcasting
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
|
||||
# if we can fuse a load we can skip the move and access memory directly
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
n = self.isel_rewrite(load.broadcast(4))
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 4)
|
||||
|
||||
def test_vbroadcastss(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32)
|
||||
|
|
@ -73,17 +76,17 @@ class TestIselX86(unittest.TestCase):
|
|||
d = UOp.variable("d", 0, 0, dtypes.float32)
|
||||
|
||||
valid = [UOp.vectorize(c, c, d, d),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
|
||||
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), c, c),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 1), lane(a, 2), lane(a, 3), lane(a, 0)),
|
||||
UOp.vectorize(lane(a, 3), lane(a, 2), lane(a, 1), lane(a, 0), lane(a, 7), lane(a, 6), lane(a, 5), lane(a, 4)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 1), lane(b, 1), lane(a, 4), lane(a, 4), lane(b, 5), lane(b, 5))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
|
||||
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
|
||||
invalid = [UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 4), lane(b, 5)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 5), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 4), lane(a, 4), lane(a, 4), lane(a, 5)),
|
||||
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 0), lane(b, 0), lane(a, 4), lane(a, 4), lane(b, 4), lane(a, 4))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
def test_vshufpd(self):
|
||||
|
|
@ -93,16 +96,16 @@ class TestIselX86(unittest.TestCase):
|
|||
d = UOp.variable("d", 0, 0, dtypes.float64)
|
||||
|
||||
valid = [UOp.vectorize(c, d),
|
||||
UOp.vectorize(a.gep(0), c),
|
||||
UOp.vectorize(a.gep(1), b.gep(1)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
|
||||
UOp.vectorize(lane(a, 0), c),
|
||||
UOp.vectorize(lane(a, 1), lane(b, 1)),
|
||||
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 1), lane(a, 1), lane(a, 3), lane(a, 3))]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
invalid = [UOp.vectorize(c, c, c, c),
|
||||
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
|
||||
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 2), lane(b, 3), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 0), lane(b, 1))]
|
||||
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
|
||||
|
||||
def test_vinsertps(self):
|
||||
|
|
@ -111,31 +114,31 @@ class TestIselX86(unittest.TestCase):
|
|||
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
|
||||
d = UOp.variable("e", 0, 0, dtypes.float32)
|
||||
# moving 0th element to position 0 does nothing so only 1 vinsertps is generated
|
||||
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
|
||||
n = self.isel_rewrite(UOp.vectorize(lane(a, 0), d))
|
||||
self.assertIs(n.arg, X86Ops.VINSERTPS)
|
||||
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
|
||||
|
||||
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
|
||||
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
|
||||
valid = [UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
|
||||
UOp.vectorize(lane(a, 3), lane(b, 2), lane(c, 1), d)]
|
||||
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
|
||||
|
||||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(a + 1).load()
|
||||
n = self.isel_rewrite(load)
|
||||
# displacement is the constant added to "a" scaled to the buffer element size, dtype is int8 when the value fits, otherwise int32
|
||||
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
|
||||
|
||||
def test_fold_load(self):
|
||||
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
load1 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
load2 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 1)).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
self.assertTrue(len(n.src) == 5)
|
||||
|
||||
# don't fold when used multiple times
|
||||
def test_dont_fold_load(self):
|
||||
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
|
||||
# used by multiple users
|
||||
n = self.isel_rewrite(load + 1 + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,8 @@ def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
|
|||
# isa renderers need to allocate registers
|
||||
if isinstance(ctx, ISARenderer):
|
||||
if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
|
||||
# register definitions (INS without srcs) move to the top so regalloc sees their live ranges span the whole program (callee saved regs)
|
||||
lst = sorted(lst, key=lambda u: u.op is not Ops.INS or bool(u.src))
|
||||
regalloc_ctx = LinearScanRegallocContext(lst, ctx)
|
||||
lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
|
||||
lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
# the order and placement of these defines is important
|
||||
case Ops.PARAM: priority, extra = -20, u.arg.slot
|
||||
case Ops.BUFFER: priority = -18
|
||||
case Ops.DEFINE_REG: priority = -18
|
||||
case Ops.DEFINE_LOCAL: priority = -17
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.STORE: priority = 1 # place stores late
|
||||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
from tinygrad.helpers import dedup
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.isa import ISARenderer, Register
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
|
||||
|
||||
|
|
@ -49,8 +49,9 @@ class LinearScanRegallocContext:
|
|||
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
|
||||
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
|
||||
if v not in self.spills:
|
||||
# the value of a BUFFER is its 64bit address
|
||||
dt = self.vdef(v).dtype
|
||||
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
|
||||
sz = 8 if self.vdef(v).op is Ops.BUFFER else dt.scalar().itemsize * dt.count
|
||||
offset = self.stack_size + (sz - self.stack_size % sz) % sz
|
||||
self.spills[v] = UOp.const(dtypes.int32, offset)
|
||||
self.stack_size = offset + sz
|
||||
|
|
@ -82,10 +83,10 @@ class LinearScanRegallocContext:
|
|||
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
|
||||
self.reals.setdefault(i, {})[v] = live[v]
|
||||
|
||||
# allocate stack array
|
||||
if u.op is Ops.DEFINE_LOCAL:
|
||||
# allocate stack array, BUFFER size is in src[0]
|
||||
if u.op is Ops.BUFFER:
|
||||
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
|
||||
self.stack_size += u.dtype.nbytes()
|
||||
self.stack_size += u.src[0].arg * u.dtype.itemsize
|
||||
|
||||
# loop prologue, avoid loading inside the loop
|
||||
if u.op is Ops.RANGE:
|
||||
|
|
@ -116,7 +117,7 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|||
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
|
||||
else: nsrc.append(s)
|
||||
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
|
||||
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
|
||||
if x.op is Ops.BUFFER: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], tag=ndefs))
|
||||
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
|
||||
|
||||
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
|
||||
|
|
@ -132,6 +133,5 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|||
return nx, before + [nx] + after
|
||||
|
||||
pm_regalloc_rewrite = PatternMatcher([
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
|
||||
regalloc_rewrite),
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.BUFFER, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"), regalloc_rewrite),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
# allow semicolons to put multiple ops on one line
|
||||
import sys, struct, functools
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, DType, truncate, AddrSpace
|
||||
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
|
||||
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, ParamArg
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
|
||||
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
|
||||
|
||||
|
|
@ -12,8 +12,8 @@ from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
|
|||
|
||||
class X86Ops(FastEnum):
|
||||
# NOTE: X86Ops with an i suffix are variants that take an immediate, those with an m suffix are variants that can write to memory instead of reading from it
|
||||
# these aren't real instructions
|
||||
FRAME_INDEX = auto(); LABEL = auto()
|
||||
# these aren't real instructions, DEFINE is a register placeholder that defines a register without emitting an instruction
|
||||
FRAME_INDEX = auto(); LABEL = auto(); DEFINE = auto()
|
||||
# index
|
||||
LEA = auto()
|
||||
# register / memory / immediate moves
|
||||
|
|
@ -171,50 +171,32 @@ extra_matcher = PatternMatcher([
|
|||
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
|
||||
])
|
||||
|
||||
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
|
||||
|
||||
pm_x86_style = PatternMatcher([
|
||||
# buffers are pointers, scalar PARAMs (variables) keep their shape src
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
|
||||
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
|
||||
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
|
||||
# SHRINK is a vectorized INDEX
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
|
||||
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
|
||||
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
|
||||
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
|
||||
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
|
||||
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
|
||||
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
|
||||
# restore vec dtypes from structure
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
|
||||
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
|
||||
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
|
||||
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
|
||||
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
|
||||
])
|
||||
|
||||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
|
||||
return ptr.cast(cast.dtype).load(dtype=x.dtype)
|
||||
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
|
||||
return UOp(Ops.BUFFER, elem_dt, src=(UOp.const(dtypes.int, count),), arg=ParamArg(slot, addrspace=AddrSpace.LOCAL))
|
||||
|
||||
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
|
||||
return ptr.cast(cast.dtype).store(val)
|
||||
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64)
|
||||
# the selected address is a 64bit value, the AFTER orders the load after the scratch store and carries the element dtype for the encoder
|
||||
sel = gate.where(addr.replace(dtype=dtypes.uint64), local_idx)
|
||||
ptr = UOp(Ops.AFTER, addr.dtype, (sel, (local_idx if x.dtype.count == 1 else local).store(alt)))
|
||||
return ptr.load(dtype=x.dtype)
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
pre_isel_matcher = pm_x86_style + PatternMatcher([
|
||||
def gated_store(addr:UOp, gate:UOp, val:UOp):
|
||||
local = scratch_buffer(addr.src[0].dtype.scalar(), val.dtype.count, -1)
|
||||
sel = gate.where(addr.replace(dtype=dtypes.uint64), local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64))
|
||||
return UOp(Ops.AFTER, addr.dtype, (sel,)).store(val)
|
||||
|
||||
# legalize the new style graph for isel. NOTE: this runs after the spec is verified, some of these rewrites violate it
|
||||
pre_isel_matcher = PatternMatcher([
|
||||
# x86 registers are typed by their width, materialize the structural width of the graph into vec dtypes (this is still valid new style)
|
||||
(UPat(Ops.SHRINK, src=(UPat(), UPat(), UPat.cvar("c"))).load(allow_any_len=True, name="x"), lambda x,c:
|
||||
x.replace(dtype=x.dtype.scalar().vec(c.arg)) if c.arg > x.dtype.count else None),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
|
||||
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
|
||||
if (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
|
||||
# zero extending scalar 32bit int is a noop
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
|
||||
# cast between signed and unsigned int is a noop
|
||||
|
|
@ -229,11 +211,12 @@ pre_isel_matcher = pm_x86_style + PatternMatcher([
|
|||
# noop of a noop is removed
|
||||
(UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
|
||||
# moving elements of a single register to another without shuffling is a noop
|
||||
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
|
||||
# gated load/store become a conditional move on the index, the load/store are unconditional
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
(UPat(Ops.STACK, src=(UPat.var("y").index(UPat()),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, (y,)) if all(s.op is Ops.INDEX and len(s.src) == 2 and s.src[0] is y \
|
||||
and s.src[1].op is Ops.CONST and s.src[1].arg == i for i,s in enumerate(x.src)) else None),
|
||||
# gated load/store become a conditional move on the address, the load/store are unconditional
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").store(UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
# TODO: remove this once we allow all flag producing ops in cmove
|
||||
# if the gate of a scalar int cmove is not a comparison we need to add one to set the flag
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
|
||||
|
|
@ -264,10 +247,10 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
|
|||
# ***** X86 instruction selection *****
|
||||
# if s is used multiple times we don't fold
|
||||
def is_foldable(ctx:IselContext, x:UOp, s:UOp) -> bool: return len(ctx.uses[s]) == x.src.count(s) == 1
|
||||
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
|
||||
def lane(x:UOp, i:int) -> int: return s.arg[0] if (s:=x.src[i]).op is Ops.GEP else 0
|
||||
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.INDEX else s
|
||||
def lane(x:UOp, i:int) -> int: return s.src[1].arg if (s:=x.src[i]).op is Ops.INDEX else 0
|
||||
def to_int(dt:DType): return {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}[dt]
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.DEFINE_REG, dt, tag=None if reg is None else (reg,))
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, dt, arg=X86Ops.DEFINE, tag=None if reg is None else (reg,))
|
||||
def imm(dt:DType, v:int) -> UOp: return UOp.const(dt, truncate[dt](v)).rtag()
|
||||
def to_imm(c:UOp) -> UOp|None:
|
||||
if c.op is not Ops.CONST: return None
|
||||
|
|
@ -348,41 +331,52 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
|
|||
# this move "cleanses" the register constraints (rax/rdx) of idiv as that only applies on definition and not on the uses of idiv
|
||||
return x.ins(X86Ops.MOV, src=(idiv,))
|
||||
|
||||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
|
||||
# a memory address operand is (base, index, displacement, size). size is the element size, it scales the index and is the memory operand width.
|
||||
# it is materialized as an immediate so the address stays correct if the base register is ever spilled and refilled
|
||||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp, UOp]:
|
||||
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
|
||||
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
|
||||
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
|
||||
base, idx = x.src
|
||||
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
|
||||
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale))
|
||||
return (base, _cast(idx), _disp(0))
|
||||
if x.op not in {Ops.INDEX, Ops.SHRINK}: return (x, UOp(Ops.NOOP), _disp(0), imm(dtypes.uint8, x.dtype.itemsize))
|
||||
base, idx = x.src[0], x.src[1]
|
||||
# buffers are indexed by element, everything else (the stack pointer) by byte
|
||||
scale = base.dtype.itemsize if base.op in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else 1
|
||||
sz = imm(dtypes.uint8, base.dtype.itemsize)
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * scale), sz)
|
||||
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * scale), sz)
|
||||
return (base, _cast(idx), _disp(0), sz)
|
||||
|
||||
def abi(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
if isinstance(x.tag, tuple): return None
|
||||
i = ctx.func_args.index(x)
|
||||
def _stack_arg(disp:int): return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp))
|
||||
if sys.platform == "win32": src = (x.replace(tag=((RCX, RDX, GPR[8], GPR[9])[i],)),) if i < 4 else _stack_arg((i-3)*8+32)
|
||||
else: src = (x.replace(tag=((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i],)),) if i < 6 else _stack_arg((i-5)*8)
|
||||
# buffer params hold addresses, their value moves as a 64bit int
|
||||
dt = dtypes.uint64 if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.GLOBAL else x.dtype
|
||||
# the shape srcs of a PARAM are not values, tag them so they aren't materialized into registers
|
||||
def _reg_arg(r:Register) -> tuple[UOp, ...]: return (x.replace(dtype=dt, src=tuple(s.rtag() for s in x.src), tag=(r,)),)
|
||||
def _stack_arg(disp:int):
|
||||
return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp), imm(dtypes.uint8, 8))
|
||||
if sys.platform == "win32": src = _reg_arg((RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
|
||||
else: src = _reg_arg((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
|
||||
# this move "cleanses" the abi register constraint
|
||||
return x.ins(X86Ops.MOV, src=src)
|
||||
return x.ins(X86Ops.MOV, dtype=dt, src=src)
|
||||
|
||||
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
|
||||
# real registers
|
||||
if x.op is Ops.DEFINE_REG and x.tag is not None: return None
|
||||
# register placeholders with real registers
|
||||
if x.arg is X86Ops.DEFINE and x.tag is not None: return None
|
||||
# this is an immediate
|
||||
if x.arg is X86Ops.FRAME_INDEX: return None
|
||||
# no register definition
|
||||
if x.dtype is dtypes.void: return None
|
||||
# already allocated vregs
|
||||
if isinstance(x.tag, tuple) and x.tag[0]._cons: return None
|
||||
# allocate vreg definitions
|
||||
# allocate vreg definitions, the value of a BUFFER is its address so it lives in a gpr
|
||||
defs = []
|
||||
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
|
||||
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
|
||||
elif x.op is Ops.BUFFER or x.dtype in dtypes.ints+(dtypes.bool,): defs = [ctx.vreg(WGPR)]
|
||||
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
|
||||
# TODO: add this once the scheduler can track register pressure
|
||||
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
|
||||
# the size src of a BUFFER is not a value, tag it so it isn't materialized into a register
|
||||
if x.op is Ops.BUFFER: return x.replace(src=tuple(s.rtag() for s in x.src), tag=tuple(defs))
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
||||
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
|
@ -393,11 +387,12 @@ dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize
|
|||
|
||||
isel_matcher = PatternMatcher([
|
||||
# **** Op -> Op ****
|
||||
# cast to pointer is a noop
|
||||
(UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
|
||||
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
|
||||
# cast of void is a noop
|
||||
(UPat.var("y").cast(name="x"), lambda y,x: y if y.dtype == dtypes.void else None),
|
||||
# extracting the 0th float element is a noop as it just moves the 0th element from one xmm register to another
|
||||
# this is done here to not interfere with shuffles
|
||||
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
|
||||
(UPat(dtype=dtypes.floats).index(UPat(Ops.CONST, arg=0), name="x"),
|
||||
lambda x: x.replace(op=Ops.NOOP, src=x.src[:1]) if x.src[0].dtype.count > 1 else None),
|
||||
# range is lowered to acc, cmp, jmp after regalloc
|
||||
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
|
||||
|
|
@ -409,9 +404,6 @@ isel_matcher = PatternMatcher([
|
|||
if not x.src or x.src[0].arg is not X86Ops.RET else None),
|
||||
# function abi constraints
|
||||
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
|
||||
# these are treated the same for now
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
|
||||
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
|
||||
# constants that can't be immediates, move them to registers
|
||||
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
|
||||
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
|
||||
|
|
@ -479,12 +471,17 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
|
||||
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
|
||||
# gep
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
|
||||
# INDEX on a vector register value extracts a single element
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int16s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int32s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.int64s).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
|
||||
(UPat.var("y", dtypes.floats).index(UPat.cvar("c"), name="x"),
|
||||
lambda y,c,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, c.arg * x.dtype.itemsize))) if y.dtype.count > 1 else None),
|
||||
# fused multiply add
|
||||
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
|
|
@ -578,8 +575,9 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
|
||||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
|
||||
# index
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
|
||||
# index on a buffer (or the stack pointer) computes an address, addresses are 64bit values
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="x"),
|
||||
lambda x: x.ins(X86Ops.LEA, dtype=dtypes.uint64, src=fold_address(x)) if x.src[0].dtype.count == 1 else None),
|
||||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# copy, load, store
|
||||
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
|
||||
|
|
@ -608,7 +606,7 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, src=(UPat(name="a"),), name="y")), allow_any_len=True, name="x"), lambda ctx,y,a,x:
|
||||
x.replace(src=x.src[:2] + fold_address(a) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable(ctx, x, y) else None),
|
||||
# allocate virtual registers
|
||||
(UPat((Ops.INS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_vregs),
|
||||
(UPat((Ops.INS, Ops.BUFFER), name="x"), alloc_vregs),
|
||||
])
|
||||
|
||||
# ***** pre register allocation *****
|
||||
|
|
@ -656,14 +654,16 @@ post_regalloc_matcher = PatternMatcher([
|
|||
# ***** X86 instruction encoding *****
|
||||
|
||||
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
|
||||
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
|
||||
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, sz_uop:UOp|None=None,
|
||||
vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
|
||||
nonlocal reg, opc
|
||||
# get the encoding values of the different fields
|
||||
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
|
||||
rm = cast(Register, rm_uop.reg).index
|
||||
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
|
||||
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
|
||||
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
|
||||
# for a memory operand the rm size is the element size from the address, otherwise it's the size of the value in the register
|
||||
rm_sz = sz_uop.arg if sz_uop is not None else rm_uop.dtype.itemsize
|
||||
reg_sz = reg_uop.dtype.itemsize if reg_uop is not None else 0
|
||||
sz = reg_sz or rm_sz
|
||||
|
||||
# encode instruction
|
||||
|
|
@ -723,19 +723,19 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
|
|||
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
|
||||
address:tuple[UOp|None, ...]
|
||||
if x.arg in X86GroupOp.WriteMem:
|
||||
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x, None, None), x.src
|
||||
if len(x.src) > 4: address, rest = x.src[:4], x.src[4:]
|
||||
else: address, rest = (x, None, None, None), x.src
|
||||
return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
|
||||
|
||||
if x.arg in X86GroupOp.Rm1st:
|
||||
if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x.src[0], None, None), x.src[1:]
|
||||
if len(x.src) > 3: address, rest = x.src[:4], x.src[4:]
|
||||
else: address, rest = (x.src[0], None, None, None), x.src[1:]
|
||||
imm_uop = rest[:1] if rest and rest[0].op is Ops.CONST else (None,)
|
||||
return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
|
||||
|
||||
if x.arg in X86GroupOp.Rm2nd:
|
||||
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
|
||||
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
|
||||
if len(x.src) > 4: address, rest = x.src[1:5], x.src[:1] + x.src[5:]
|
||||
else: address, rest = (x.src[1], None, None, None), x.src[:1] + x.src[2:]
|
||||
# cmp/vucomiss reg, rm don't define a new register
|
||||
return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
|
||||
|
||||
|
|
@ -770,8 +770,9 @@ encodings = {
|
|||
X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
|
||||
X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
|
||||
X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
|
||||
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
|
||||
# the int src is the 2nd src (the rm field), if it was folded into a memory operand its width is the element size of the address
|
||||
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
|
||||
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
|
||||
X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
|
||||
X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
|
||||
# int division
|
||||
|
|
@ -871,43 +872,41 @@ class X86Renderer(ISARenderer):
|
|||
self.compiler = X86Compiler()
|
||||
def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress
|
||||
def stack_pointer(self) -> UOp: return def_reg(dtypes.uint64, RSP)
|
||||
# nasty hacks to deal with pointers TODO: rm pointers
|
||||
# the value of a BUFFER is its address, it moves through registers and the stack as a 64bit int
|
||||
def copy(self, x:UOp, reg:Register):
|
||||
dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
|
||||
ret = isel_matcher.rewrite(UOp(Ops.COPY, dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, (x,), tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
return ret
|
||||
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp:
|
||||
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
|
||||
if x.op is Ops.BUFFER: x = x.replace(dtype=dtypes.uint64)
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(x))
|
||||
assert ret is not None
|
||||
return ret.replace(src=(s if s is not nx else x for s in ret.src))
|
||||
return ret
|
||||
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
|
||||
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
return ret
|
||||
|
||||
def asm_str(self, uops:list[UOp], function_name:str) -> str:
|
||||
def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
|
||||
def _format_operands(x:UOp) -> str:
|
||||
def _format(src:tuple[UOp, ...]) -> list[str]:
|
||||
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
|
||||
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize, o) if \
|
||||
(o:=str(s.reg)) in reg_strs else o for s in src if s.reg is not None]
|
||||
def _mem_adress(base:UOp, idx:UOp, disp:UOp) -> list[str]:
|
||||
return [f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
|
||||
def _mem_adress(base:UOp, idx:UOp, disp:UOp, sz:UOp) -> list[str]:
|
||||
return [f"[{base.reg}" + (f" + {idx.reg}*{sz.arg}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
|
||||
|
||||
if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:3]) + _format(x.src[3:])
|
||||
elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:3]) + _format(x.src[3:])
|
||||
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:4]) + _format(x.src[4:])
|
||||
if len(x.src) > 4 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:4]) + _format(x.src[4:])
|
||||
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:4]) + _format(x.src[4:])
|
||||
elif len(x.src) > 4 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:5]) + _format(x.src[5:])
|
||||
else: ret = _format((x,) + x.src)
|
||||
return ", ".join(ret)
|
||||
|
||||
asm = [f".{function_name}:"]
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
|
||||
if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
|
||||
elif u.arg is X86Ops.RET: asm.append(_format_op(u))
|
||||
else: asm.append(_format_op(u) + " " + _format_operands(u))
|
||||
|
|
@ -918,7 +917,7 @@ class X86Renderer(ISARenderer):
|
|||
jumps: dict[UOp, int] = {}
|
||||
binary = bytearray()
|
||||
for u in uops:
|
||||
if u.op is not Ops.INS: continue
|
||||
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
|
||||
if u.arg is X86Ops.LABEL:
|
||||
targets[u.tag] = len(binary)
|
||||
continue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue