x86 cleanups (fable) [pr] (#16591)

* x86 cleanups (fable)

* support shrink

* remove ptr dtype

* move that

* is_lane helper

* Revert "is_lane helper"

This reverts commit ea4571254d.
This commit is contained in:
George Hotz 2026-06-19 09:04:51 -07:00 committed by GitHub
commit b05bea81ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 170 additions and 166 deletions

View file

@ -14,49 +14,51 @@ class TestEncodingsX86(unittest.TestCase):
# displacement of 0 isn't emitted
def test_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F"))
# rsp/r12 require a sib byte when used as base memory address
def test_rsp_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RSP)
# mov esp, dword ptr [rsp]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24"))
# rbp/r13 require a displacement when used as base memory address
def test_rbp_base_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RBP)
# mov ebp, dword ptr [rbp + 0]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00"))
# test [base + index*scale]
def test_base_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax + rdx*4]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90"))
# rsp as index means no index
def test_rsp_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00"))
# however r12 is a valid index
def test_r12_index_address(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX)
load = ins(X86Ops.MOV, dtypes.int32,
(def_reg(dtypes.uint64, RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0), imm(dtypes.uint8, 4)), RAX)
# mov eax, dword ptr [rax + r12*4]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0"))
# test [base + index*scale + 8bit disp]
def test_complex_address_8bit_disp(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI)
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi + rsi*4 + 0xa]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A"))
# test [base + index*scale + 32bit disp]
def test_complex_address_32bit_disp(self):
load = ins(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI)
load = ins(X86Ops.MOV, dtypes.int32,
(def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000), imm(dtypes.uint8, 4)), RDI)
# mov edi, dword ptr [rdi + rsi*4 + 0x2710]
self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00"))
@ -114,28 +116,28 @@ class TestEncodingsX86(unittest.TestCase):
# when writing to mem the uop takes the store form where dtype is void and there's no definition
def test_write_mem(self):
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
xmm0 = def_reg(dtypes.float32, XMM[0])
extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
extr = ins(X86Ops.VPEXTRD, dtypes.void, address + (xmm0, imm(dtypes.uint8, 0)))
# vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00"))
# test two address instruction with fused load works
def test_two_address_load(self):
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 4))
cmove = ins(X86Ops.CMOVE, dtypes.int32, address, RAX)
# cmove eax, dword ptr [rdi + rsi*4 + 0xa]
self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A"))
# test instruction where displacement and imm have the same value
def test_disp_imm_same_value(self):
base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10), imm(dtypes.uint8, 1))
mov = ins(X86Ops.MOVi, dtypes.void, address + (imm(dtypes.int8, 10),))
# mov byte ptr [rdi + rsi + 0xa], 0xa
self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A"))
base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI)
address = (def_reg(dtypes.uint64, RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10), imm(dtypes.uint8, 4))
imul = ins(X86Ops.IMULi, dtypes.int32, address + (imm(dtypes.int32, 10),), RDI)
# imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00"))

View file

@ -6,6 +6,9 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
from tinygrad.renderer.isa import IselContext
# INDEX on a register value with a constant index extracts a single element (the old GEP)
def lane(y:UOp, i:int) -> UOp: return y.index(UOp.const(dtypes.int, i), dtype=y.dtype.scalar())
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only x86")
class TestIselX86(unittest.TestCase):
def isel_rewrite(self, x:UOp):
@ -57,9 +60,9 @@ class TestIselX86(unittest.TestCase):
# need to move src from gpr to xmm before broadcasting
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and n.src[0].arg is X86Ops.VMOVD)
# if we can fuse a load we can skip the move and access memory directly
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
n = self.isel_rewrite(load.broadcast(4))
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 4)
def test_vbroadcastss(self):
a = UOp.variable("a", 0, 0, dtypes.float32)
@ -73,17 +76,17 @@ class TestIselX86(unittest.TestCase):
d = UOp.variable("d", 0, 0, dtypes.float32)
valid = [UOp.vectorize(c, c, d, d),
UOp.vectorize(a.gep(0), a.gep(1), c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
UOp.vectorize(lane(a, 0), lane(a, 1), c, c),
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 1), lane(a, 2), lane(a, 3), lane(a, 0)),
UOp.vectorize(lane(a, 3), lane(a, 2), lane(a, 1), lane(a, 0), lane(a, 7), lane(a, 6), lane(a, 5), lane(a, 4)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 1), lane(b, 1), lane(a, 4), lane(a, 4), lane(b, 5), lane(b, 5))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
invalid = [UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 4), lane(b, 5)),
UOp.vectorize(lane(a, 0), lane(a, 5), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 0), lane(a, 4), lane(a, 4), lane(a, 4), lane(a, 5)),
UOp.vectorize(lane(a, 0), lane(a, 0), lane(b, 0), lane(b, 0), lane(a, 4), lane(a, 4), lane(b, 4), lane(a, 4))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
def test_vshufpd(self):
@ -93,16 +96,16 @@ class TestIselX86(unittest.TestCase):
d = UOp.variable("d", 0, 0, dtypes.float64)
valid = [UOp.vectorize(c, d),
UOp.vectorize(a.gep(0), c),
UOp.vectorize(a.gep(1), b.gep(1)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3))]
UOp.vectorize(lane(a, 0), c),
UOp.vectorize(lane(a, 1), lane(b, 1)),
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 1), lane(a, 1), lane(a, 3), lane(a, 3))]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
invalid = [UOp.vectorize(c, c, c, c),
UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1))]
UOp.vectorize(lane(a, 0), lane(a, 1), lane(b, 2), lane(b, 3)),
UOp.vectorize(lane(a, 2), lane(b, 3), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 0), lane(b, 1))]
for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPD)
def test_vinsertps(self):
@ -111,31 +114,31 @@ class TestIselX86(unittest.TestCase):
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
d = UOp.variable("e", 0, 0, dtypes.float32)
# moving 0th element to position 0 does nothing so only 1 vinsertps is generated
n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
n = self.isel_rewrite(UOp.vectorize(lane(a, 0), d))
self.assertIs(n.arg, X86Ops.VINSERTPS)
self.assertIsNot(n.src[0].arg, X86Ops.VINSERTPS)
valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
valid = [UOp.vectorize(lane(a, 0), lane(b, 1), lane(a, 2), lane(b, 3)),
UOp.vectorize(lane(a, 3), lane(b, 2), lane(c, 1), d)]
for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
# complex address is [base + index*scale + displacement]
def test_complex_address(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
load = UOp.param(0, dtypes.int32.ptr()).index(a + 1, ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(a + 1).load()
n = self.isel_rewrite(load)
# displacement is the constant part of the index ("a + 1") scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
self.assertTrue(n.src[2].op is Ops.CONST and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
def test_fold_load(self):
load1 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load2 = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 1), ptr=True).load()
load1 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
load2 = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 1)).load()
n = self.isel_rewrite(load1 + load2)
self.assertTrue(len(n.src) == 4)
self.assertTrue(len(n.src) == 5)
# don't fold when used multiple times
def test_dont_fold_load(self):
load = UOp.param(0, dtypes.int32.ptr()).index(UOp.const(dtypes.int32, 0), ptr=True).load()
load = UOp.param(0, dtypes.int32, (16,)).index(UOp.const(dtypes.int32, 0)).load()
# used by multiple users
n = self.isel_rewrite(load + 1 + load)
self.assertTrue(len(n.src) == 2)

View file

@ -175,6 +175,8 @@ def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp:
# isa renderers need to allocate registers
if isinstance(ctx, ISARenderer):
if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext())
# register definitions (INS without srcs) move to the top so regalloc sees their live ranges span the whole program (callee saved regs)
lst = sorted(lst, key=lambda u: u.op is not Ops.INS or bool(u.src))
regalloc_ctx = LinearScanRegallocContext(lst, ctx)
lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx)
lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx)

View file

@ -24,8 +24,6 @@ def linearize(sink:UOp) -> list[UOp]:
# the order and placement of these defines is important
case Ops.PARAM: priority, extra = -20, u.arg.slot
case Ops.BUFFER: priority = -18
case Ops.DEFINE_REG: priority = -18
case Ops.DEFINE_LOCAL: priority = -17
case Ops.LOAD: priority = -1 # place loads early
case Ops.STORE: priority = 1 # place stores late
case Ops.RANGE: priority = 5 # placing RANGE is good

View file

@ -2,7 +2,7 @@ import itertools
from tinygrad.helpers import dedup
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.isa import ISARenderer, Register
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.dtype import dtypes
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
@ -49,8 +49,9 @@ class LinearScanRegallocContext:
# assign register to spilled virtual and record load to be emitted before current uop, also assign it a stack slot
def fill(v:Register, i:int, cons:tuple[Register, ...]|None=None) -> Register:
if v not in self.spills:
# the value of a BUFFER is its 64bit address
dt = self.vdef(v).dtype
sz = dt.scalar().itemsize * dt.count if not isinstance(dt, PtrDType) else 8
sz = 8 if self.vdef(v).op is Ops.BUFFER else dt.scalar().itemsize * dt.count
offset = self.stack_size + (sz - self.stack_size % sz) % sz
self.spills[v] = UOp.const(dtypes.int32, offset)
self.stack_size = offset + sz
@ -82,10 +83,10 @@ class LinearScanRegallocContext:
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
self.reals.setdefault(i, {})[v] = live[v]
# allocate stack array
if u.op is Ops.DEFINE_LOCAL:
# allocate stack array, BUFFER size is in src[0]
if u.op is Ops.BUFFER:
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
self.stack_size += u.dtype.nbytes()
self.stack_size += u.src[0].arg * u.dtype.itemsize
# loop prologue, avoid loading inside the loop
if u.op is Ops.RANGE:
@ -116,7 +117,7 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
if i in ctx.reals and (v:=ctx.uops[i].src[j].reg) in ctx.spills: nsrc.append(ctx.ren.fill(ctx.spills[v], ctx.vdef(v), ctx.reals[i][v]))
else: nsrc.append(s)
ndefs = tuple(ctx.reals[i][v] for v in x.tag) if isinstance(x.tag, tuple) else x.tag
if x.op is Ops.DEFINE_LOCAL: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], dtype=x.dtype, tag=ndefs))
if x.op is Ops.BUFFER: nx = ctx.ren.isel_matcher.rewrite(ctx.ren.stack_pointer().index(ctx.locals[x], tag=ndefs))
else: nx = x.replace(src=tuple(nsrc), tag=ndefs)
before = [ctx.ren.fill(ctx.spills[v], ctx.vdef(v), r) for v,r in ctx.insert_before.get(i, [])]
@ -132,6 +133,5 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
return nx, before + [nx] + after
pm_regalloc_rewrite = PatternMatcher([
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
regalloc_rewrite),
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.BUFFER, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"), regalloc_rewrite),
])

View file

@ -2,9 +2,9 @@
# allow semicolons to put multiple ops on one line
import sys, struct, functools
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, DType, truncate, AddrSpace
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, ParamArg
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
@ -12,8 +12,8 @@ from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
class X86Ops(FastEnum):
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
# these aren't real instructions
FRAME_INDEX = auto(); LABEL = auto()
# these aren't real instructions, DEFINE is a register placeholder that defines a register without emitting an instruction
FRAME_INDEX = auto(); LABEL = auto(); DEFINE = auto()
# index
LEA = auto()
# register / memory / immediate moves
@ -171,50 +171,32 @@ extra_matcher = PatternMatcher([
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
])
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
pm_x86_style = PatternMatcher([
# buffers are pointers, scalar PARAMs (variables) keep their shape src
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
# SHRINK is a vectorized INDEX
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
# restore vec dtypes from structure
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
])
# ***** X86 pre instruction selection *****
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
return ptr.cast(cast.dtype).load(dtype=x.dtype)
# build a BUFFER UOp in the LOCAL address space used as scratch memory by the gated load/store rewrites
# elem_dt is the element dtype, count is the number of elements (carried as an int const in src[0]), slot goes into the ParamArg
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
return UOp(Ops.BUFFER, elem_dt, src=(UOp.const(dtypes.int, count),), arg=ParamArg(slot, addrspace=AddrSpace.LOCAL))
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
return ptr.cast(cast.dtype).store(val)
# rewrite a gated load into an unconditional load from a selected address
# addr is the matched INDEX/SHRINK, alt is the value used when the gate is false, gate is the condition, x is the original LOAD
# ctx is advanced with next() to get the slot of the scratch buffer
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
# scratch buffer with the element dtype of the indexed buffer and one element per lane of the loaded value
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64)
# the selected address is a 64bit value, the AFTER orders the load after the scratch store and carries the element dtype for the encoder
sel = gate.where(addr.replace(dtype=dtypes.uint64), local_idx)
# alt is stored to the scratch first (through element 0 for a scalar, through the buffer itself otherwise), so a false gate loads alt
ptr = UOp(Ops.AFTER, addr.dtype, (sel, (local_idx if x.dtype.count == 1 else local).store(alt)))
return ptr.load(dtype=x.dtype)
# these must be done in a separate matcher because they violate the spec
pre_isel_matcher = pm_x86_style + PatternMatcher([
# rewrite a gated store into an unconditional store to a selected address
# addr is the matched INDEX/SHRINK, gate is the condition, val is the value being stored
def gated_store(addr:UOp, gate:UOp, val:UOp):
# scratch buffer (slot -1) with the element dtype of the indexed buffer and one element per lane of val
local = scratch_buffer(addr.src[0].dtype.scalar(), val.dtype.count, -1)
# when the gate is true the store goes to the real address, otherwise to element 0 of the scratch buffer
sel = gate.where(addr.replace(dtype=dtypes.uint64), local.index(UOp.const(dtypes.int32, 0), dtype=dtypes.uint64))
# the AFTER wraps the 64bit selected address and carries addr's dtype into the store
return UOp(Ops.AFTER, addr.dtype, (sel,)).store(val)
# legalize the new style graph for isel. NOTE: this runs after the spec is verified, some of these rewrites violate it
pre_isel_matcher = PatternMatcher([
# x86 registers are typed by their width, materialize the structural width of the graph into vec dtypes (this is still valid new style)
(UPat(Ops.SHRINK, src=(UPat(), UPat(), UPat.cvar("c"))).load(allow_any_len=True, name="x"), lambda x,c:
x.replace(dtype=x.dtype.scalar().vec(c.arg)) if c.arg > x.dtype.count else None),
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
if (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
# zero extending scalar 32bit int is a noop
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
# cast between signed and unsigned int is a noop
@ -229,11 +211,12 @@ pre_isel_matcher = pm_x86_style + PatternMatcher([
# noop of a noop is removed
(UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)),
# moving elements of a single register to another without shuffling is a noop
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
# gated load/store become a conditional move on the index, the load/store are unconditional
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
(UPat(Ops.STACK, src=(UPat.var("y").index(UPat()),), allow_any_len=True, name="x"),
lambda y,x: UOp(Ops.NOOP, x.dtype, (y,)) if all(s.op is Ops.INDEX and len(s.src) == 2 and s.src[0] is y \
and s.src[1].op is Ops.CONST and s.src[1].arg == i for i,s in enumerate(x.src)) else None),
# gated load/store become a conditional move on the address, the load/store are unconditional
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat((Ops.INDEX, Ops.SHRINK), name="addr").store(UPat.var("val"), UPat.var("gate")), gated_store),
# TODO: remove this once we allow all flag producing ops in cmove
# if gate in scalar int cmove is not a comparison need to add one to set the flag
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
@ -264,10 +247,10 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
# ***** X86 instruction selection *****
# if s is used multiple times we don't fold
def is_foldable(ctx:IselContext, x:UOp, s:UOp) -> bool: return len(ctx.uses[s]) == x.src.count(s) == 1
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
def lane(x:UOp, i:int) -> int: return s.arg[0] if (s:=x.src[i]).op is Ops.GEP else 0
def base(x:UOp, i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.INDEX else s
def lane(x:UOp, i:int) -> int: return s.src[1].arg if (s:=x.src[i]).op is Ops.INDEX else 0
def to_int(dt:DType): return {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}[dt]
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.DEFINE_REG, dt, tag=None if reg is None else (reg,))
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, dt, arg=X86Ops.DEFINE, tag=None if reg is None else (reg,))
def imm(dt:DType, v:int) -> UOp: return UOp.const(dt, truncate[dt](v)).rtag()
def to_imm(c:UOp) -> UOp|None:
if c.op is not Ops.CONST: return None
@ -348,41 +331,52 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
# this move "cleanses" the register constraints (rax/rdx) of idiv as that only applies on definition and not on the uses of idiv
return x.ins(X86Ops.MOV, src=(idiv,))
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
# a memory address operand is (base, index, displacement, size). size is the element size, it scales the index and is the memory operand width.
# it is materialized as an immediate so the address stays correct if the base register is ever spilled and refilled
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp, UOp]:
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
base, idx = x.src
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale))
return (base, _cast(idx), _disp(0))
if x.op not in {Ops.INDEX, Ops.SHRINK}: return (x, UOp(Ops.NOOP), _disp(0), imm(dtypes.uint8, x.dtype.itemsize))
base, idx = x.src[0], x.src[1]
# buffers are indexed by element, everything else (the stack pointer) by byte
scale = base.dtype.itemsize if base.op in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else 1
sz = imm(dtypes.uint8, base.dtype.itemsize)
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * scale), sz)
if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * scale), sz)
return (base, _cast(idx), _disp(0), sz)
def abi(ctx:IselContext, x:UOp) -> UOp|None:
if isinstance(x.tag, tuple): return None
i = ctx.func_args.index(x)
def _stack_arg(disp:int): return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp))
if sys.platform == "win32": src = (x.replace(tag=((RCX, RDX, GPR[8], GPR[9])[i],)),) if i < 4 else _stack_arg((i-3)*8+32)
else: src = (x.replace(tag=((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i],)),) if i < 6 else _stack_arg((i-5)*8)
# buffer params hold addresses, their value moves as a 64bit int
dt = dtypes.uint64 if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.GLOBAL else x.dtype
# the shape srcs of a PARAM are not values, tag them so they aren't materialized into registers
def _reg_arg(r:Register) -> tuple[UOp, ...]: return (x.replace(dtype=dt, src=tuple(s.rtag() for s in x.src), tag=(r,)),)
def _stack_arg(disp:int):
return (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(Ops.INS, arg=X86Ops.FRAME_INDEX, dtype=dtypes.int32, tag=disp), imm(dtypes.uint8, 8))
if sys.platform == "win32": src = _reg_arg((RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
else: src = _reg_arg((RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
# this move "cleanses" the abi register constraint
return x.ins(X86Ops.MOV, src=src)
return x.ins(X86Ops.MOV, dtype=dt, src=src)
def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
# real registers
if x.op is Ops.DEFINE_REG and x.tag is not None: return None
# register placeholders with real registers
if x.arg is X86Ops.DEFINE and x.tag is not None: return None
# this is an immediate
if x.arg is X86Ops.FRAME_INDEX: return None
# no register definition
if x.dtype is dtypes.void: return None
# already allocated vregs
if isinstance(x.tag, tuple) and x.tag[0]._cons: return None
# allocate vreg definitions
# allocate vreg definitions, the value of a BUFFER is its address so it lives in a gpr
defs = []
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
elif x.op is Ops.BUFFER or x.dtype in dtypes.ints+(dtypes.bool,): defs = [ctx.vreg(WGPR)]
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
# TODO: add this once the scheduler can track register pressure
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
# the size src of a BUFFER is not a value, tag it so it isn't materialized into a register
if x.op is Ops.BUFFER: return x.replace(src=tuple(s.rtag() for s in x.src), tag=tuple(defs))
return x.replace(tag=tuple(defs))
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
@ -393,11 +387,12 @@ dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize
isel_matcher = PatternMatcher([
# **** Op -> Op ****
# cast to pointer is a noop
(UPat.var("y").cast(name="x"), lambda y,x: y if isinstance(x.dtype, PtrDType) or y.dtype == dtypes.void else None),
# float gep(0) is a noop as it just moves the 0th element from one xmm register to another
# cast of void is a noop
(UPat.var("y").cast(name="x"), lambda y,x: y if y.dtype == dtypes.void else None),
# extracting the 0th float element is a noop as it just moves the 0th element from one xmm register to another
# this is done here to not interfere with shuffles
(UPat(dtype=dtypes.floats).gep(0, name="x"), lambda x: x.replace(op=Ops.NOOP, arg=None)),
(UPat(dtype=dtypes.floats).index(UPat(Ops.CONST, arg=0), name="x"),
lambda x: x.replace(op=Ops.NOOP, src=x.src[:1]) if x.src[0].dtype.count > 1 else None),
# range is lowered to acc, cmp, jmp after regalloc
(UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])),
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(tag=(ctx.vreg(WGPR),)) if not isinstance(x.tag, tuple) else None),
@ -409,9 +404,6 @@ isel_matcher = PatternMatcher([
if not x.src or x.src[0].arg is not X86Ops.RET else None),
# function abi constraints
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
# these are treated the same for now
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
# constants that can't be immediates, move them to registers
(UPat.cvar("x", dtypes.int64s), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
(UPat.cvar("x", dtypes.ints+(dtypes.bool,)), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),)) if not x.tag else None),
@ -479,12 +471,17 @@ isel_matcher = PatternMatcher([
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
# gep
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
# INDEX on a vector register value extracts a single element
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int16s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int32s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.int64s).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, c.arg))) if y.dtype.count > 1 else None),
(UPat.var("y", dtypes.floats).index(UPat.cvar("c"), name="x"),
lambda y,c,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, c.arg * x.dtype.itemsize))) if y.dtype.count > 1 else None),
# fused multiply add
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
@ -578,8 +575,9 @@ isel_matcher = PatternMatcher([
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
# index
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
# index on a buffer (or the stack pointer) computes an address, addresses are 64bit values
(UPat((Ops.INDEX, Ops.SHRINK), name="x"),
lambda x: x.ins(X86Ops.LEA, dtype=dtypes.uint64, src=fold_address(x)) if x.src[0].dtype.count == 1 else None),
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# copy, load, store
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
@ -608,7 +606,7 @@ isel_matcher = PatternMatcher([
(UPat(Ops.INS, src=(UPat(), UPat(), UPat(Ops.LOAD, src=(UPat(name="a"),), name="y")), allow_any_len=True, name="x"), lambda ctx,y,a,x:
x.replace(src=x.src[:2] + fold_address(a) + x.src[3:]) if x.arg in X86GroupOp.ReadMem3rd and is_foldable(ctx, x, y) else None),
# allocate virtual registers
(UPat((Ops.INS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), alloc_vregs),
(UPat((Ops.INS, Ops.BUFFER), name="x"), alloc_vregs),
])
# ***** pre register allocation *****
@ -656,14 +654,16 @@ post_regalloc_matcher = PatternMatcher([
# ***** X86 instruction encoding *****
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) -> bytes|None:
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
def _encode(reg_uop:UOp|None, rm_uop:UOp, idx_uop:UOp|None=None, disp_uop:UOp|None=None, sz_uop:UOp|None=None,
vvvv_uop:UOp|None=None, imm_uop:UOp|None=None) -> bytes:
nonlocal reg, opc
# get the encoding values of the different fields
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
rm = cast(Register, rm_uop.reg).index
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
# for a memory operand the rm size is the element size from the address, otherwise it's the size of the value in the register
rm_sz = sz_uop.arg if sz_uop is not None else rm_uop.dtype.itemsize
reg_sz = reg_uop.dtype.itemsize if reg_uop is not None else 0
sz = reg_sz or rm_sz
# encode instruction
@ -723,19 +723,19 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
address:tuple[UOp|None, ...]
if x.arg in X86GroupOp.WriteMem:
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
else: address, rest = (x, None, None), x.src
if len(x.src) > 4: address, rest = x.src[:4], x.src[4:]
else: address, rest = (x, None, None, None), x.src
return _encode(rest[0], *address, *(None, *rest[1:])) if reg is None else _encode(None, *address, *(None, *rest[:1]))
if x.arg in X86GroupOp.Rm1st:
if len(x.src) > 2: address, rest = x.src[:3], x.src[3:]
else: address, rest = (x.src[0], None, None), x.src[1:]
if len(x.src) > 3: address, rest = x.src[:4], x.src[4:]
else: address, rest = (x.src[0], None, None, None), x.src[1:]
imm_uop = rest[:1] if rest and rest[0].op is Ops.CONST else (None,)
return _encode(x, *address, *(None, *imm_uop)) if reg is None else _encode(None, *address, *(x if sel else None, *imm_uop))
if x.arg in X86GroupOp.Rm2nd:
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
if len(x.src) > 4: address, rest = x.src[1:5], x.src[:1] + x.src[5:]
else: address, rest = (x.src[1], None, None, None), x.src[:1] + x.src[2:]
# cmp/vucomiss reg, rm don't define a new register
return _encode(x, *address, *rest) if x.dtype is not dtypes.void else _encode(rest[0], *address)
@ -770,8 +770,9 @@ encodings = {
X86Ops.VCVTDQ2PS: lambda x: encode(x, 0x5B, pp=0, sel=1), X86Ops.VCVTDQ2PD: lambda x: encode(x, 0xE6, pp=2, sel=1),
X86Ops.VCVTPS2PD: lambda x: encode(x, 0x5A, pp=0, sel=1), X86Ops.VCVTPD2PS: lambda x: encode(x, 0x5A, pp=1, sel=1),
X86Ops.VCVTTPS2DQ: lambda x: encode(x, 0x5B, pp=2, sel=1), X86Ops.VCVTTPD2DQ: lambda x: encode(x, 0xE6, pp=1, sel=1),
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.itemsize == 8),
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.itemsize == 8),
# the int src is the 2nd src (the rm field), if it was folded into a memory operand its width is the element size of the address
X86Ops.VCVTSI2SS: lambda x: encode(x, 0x2A, pp=2, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
X86Ops.VCVTSI2SD: lambda x: encode(x, 0x2A, pp=3, sel=1, we=(x.src[4].arg if len(x.src) > 4 else x.src[1].dtype.itemsize) == 8),
X86Ops.VCVTTSS2SI: lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype.itemsize == 8),
X86Ops.VCVTTSD2SI: lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype.itemsize == 8),
# int division
@ -871,43 +872,41 @@ class X86Renderer(ISARenderer):
self.compiler = X86Compiler()
def is_two_address(self, x:UOp) -> bool: return x.arg in X86GroupOp.TwoAddress
def stack_pointer(self) -> UOp: return def_reg(dtypes.uint64, RSP)
# nasty hacks to deal with pointers TODO: rm pointers
# the value of a BUFFER is its address, it moves through registers and the stack as a 64bit int
def copy(self, x:UOp, reg:Register):
dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
ret = isel_matcher.rewrite(UOp(Ops.COPY, dt, (x,), tag=reg))
ret = isel_matcher.rewrite(UOp(Ops.COPY, dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, (x,), tag=reg))
assert ret is not None
return ret.replace(dtype=x.dtype)
return ret
def spill(self, disp:UOp, x:UOp) -> UOp:
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
if x.op is Ops.BUFFER: x = x.replace(dtype=dtypes.uint64)
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(x))
assert ret is not None
return ret.replace(src=(s if s is not nx else x for s in ret.src))
return ret
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=dtypes.uint64 if x.op is Ops.BUFFER else x.dtype, tag=reg))
assert ret is not None
return ret.replace(dtype=x.dtype)
return ret
def asm_str(self, uops:list[UOp], function_name:str) -> str:
def _format_op(x:UOp) -> str: return f" {(o[7:-1] if (o:=str(x.arg))[-1] in ('i', 'm') else o[7:]).lower():7s}"
def _format_operands(x:UOp) -> str:
def _format(src:tuple[UOp, ...]) -> list[str]:
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize if not isinstance(s.dtype, PtrDType) else 8, o) if \
return [str(s.arg) if s.op is Ops.CONST else reg_strs[o].get(s.dtype.itemsize, o) if \
(o:=str(s.reg)) in reg_strs else o for s in src if s.reg is not None]
def _mem_adress(base:UOp, idx:UOp, disp:UOp) -> list[str]:
return [f"[{base.reg}" + (f" + {idx.reg}*{base.dtype.itemsize}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
def _mem_adress(base:UOp, idx:UOp, disp:UOp, sz:UOp) -> list[str]:
return [f"[{base.reg}" + (f" + {idx.reg}*{sz.arg}" if idx.reg else "") + (f" + {disp.arg}" if disp.arg else "") + "]"]
if len(x.src) > 3 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:3]) + _format(x.src[3:])
elif len(x.src) > 2 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:3]) + _format(x.src[3:])
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:4]) + _format(x.src[4:])
if len(x.src) > 4 and x.arg in X86GroupOp.WriteMem: ret = _mem_adress(*x.src[:4]) + _format(x.src[4:])
elif len(x.src) > 3 and x.arg in X86GroupOp.Rm1st: ret = _format((x,)) + _mem_adress(*x.src[:4]) + _format(x.src[4:])
elif len(x.src) > 4 and x.arg in X86GroupOp.Rm2nd: ret = _format((x, x.src[0])) + _mem_adress(*x.src[1:5]) + _format(x.src[5:])
else: ret = _format((x,) + x.src)
return ", ".join(ret)
asm = [f".{function_name}:"]
for u in uops:
if u.op is not Ops.INS: continue
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
if u.arg is X86Ops.LABEL: asm.append(f"{str(u.tag)}:")
elif u.arg is X86Ops.RET: asm.append(_format_op(u))
else: asm.append(_format_op(u) + " " + _format_operands(u))
@ -918,7 +917,7 @@ class X86Renderer(ISARenderer):
jumps: dict[UOp, int] = {}
binary = bytearray()
for u in uops:
if u.op is not Ops.INS: continue
if u.op is not Ops.INS or u.arg is X86Ops.DEFINE: continue
if u.arg is X86Ops.LABEL:
targets[u.tag] = len(binary)
continue