add Ops.INS to x86

This commit is contained in:
ttomsa 2026-02-19 00:20:21 +00:00
commit dd558ecfae
5 changed files with 300 additions and 305 deletions

View file

@ -30,10 +30,10 @@ class RegallocContext:
lr, ranges = self.live_range, []
for i,u in enumerate(reversed(uops)):
if u.op in (Ops.NOOP, Ops.AFTER): continue
for v in {s.arg for s in (u,) + u.src if isinstance(s.arg, Register)}: lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
for v in {s.tag for s in (u,) + u.src if isinstance(s.tag, Register)}: lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
# a var defined before a range and used inside it is needed for the whole range
if u.arg in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] < lr[u.arg][-1] < lr[rng][-1]), default=None)): lr[u.arg].append(n)
if u.op is Ops.RANGE: ranges.append(u.arg)
if u.tag in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] < lr[u.tag][-1] < lr[rng][-1]), default=None)): lr[u.tag].append(n)
if u.op is Ops.RANGE: ranges.append(u.tag)
# TODO: rm pointers
# nasty hacks to deal with pointers
@ -44,7 +44,7 @@ def assign(ctx:RegallocContext, x:UOp, reg:Register):
return ret.replace(dtype=x.dtype)
def load(ctx:RegallocContext, dt:DType, disp:UOp, reg:Register):
ndt = dtypes.uint64 if isinstance(dt, PtrDType) else dt
ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).load(dtype=ndt, arg=reg))
ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).load(dtype=ndt, tag=reg))
assert ret is not None
return ret.replace(dtype=dt)
def store(ctx:RegallocContext, disp:UOp, x:UOp):
@ -71,7 +71,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
nsrc, loads = [], []
for s in x.src:
# allocate srcs: if src was spilled it's replaced by a load; if it's live the load was already emitted, otherwise alloc and emit one
if isinstance(s.arg, Register) and (v:=ctx.rewrite_to_vreg[s]) in ctx.spills:
if isinstance(s.tag, Register) and (v:=ctx.rewrite_to_vreg[s]) in ctx.spills:
# TODO: the constraints only apply to the definition, you need to insert moves in the graph to "cleanse" the constraint
# then those moves are removed after regalloc if they move to the same register. I think this is the llvm approach
# alternatively you could beef up the register class to include constraints on the srcs, then you check those here
@ -82,7 +82,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
else: s = load(ctx, s.dtype, ctx.spills[v], ctx.live[v])
nsrc.append(s)
# allocate destination
if isinstance(v:=x.arg, Register) and v not in ctx.live:
if isinstance(v:=x.tag, Register) and v not in ctx.live:
# if no cons it's a real register, so it can only be assigned to itself
cons = v.cons or (v,)
# two address instructions (src is used in dest) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
@ -93,7 +93,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
assert cons
ctx.live[v] = alloc(ctx, cons, i+1)
nx = x.replace(src=tuple(nsrc), arg=ctx.live.get(v, v))
nx = x.replace(src=tuple(nsrc), tag=ctx.live.get(v, v))
# TODO: this check exists because of a hack in x86, rm once multiple outputs are supported
if nx not in ctx.rewrite_to_vreg: ctx.rewrite_to_vreg[nx] = v
if v not in ctx.vreg_to_rewrite: ctx.vreg_to_rewrite[v] = nx
@ -101,10 +101,10 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
# move uops to registers before the loop to avoid loading inside the loop
def loop_prologue(ctx:RegallocContext, x:UOp, i:int):
assert isinstance(x.arg, Register)
assert isinstance(x.tag, Register)
nx, lst = regalloc(ctx, x, i)
# we move to register vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
used_in_loop = [v for v in ctx.live.keys() | ctx.spills.keys() if any(i <= l < ctx.live_range[x.arg][-1] for l in ctx.live_range[v])]
used_in_loop = [v for v in ctx.live.keys() | ctx.spills.keys() if any(i <= l < ctx.live_range[x.tag][-1] for l in ctx.live_range[v])]
sorted_uses = sorted(used_in_loop, key=lambda k: next(l-i for l in ctx.live_range[k] if l >= i))
live_in: dict[Register, Register] = {}
loads = []
@ -135,12 +135,12 @@ def loop_epilogue(ctx:RegallocContext, x:UOp, i:int):
pm_regalloc = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda ctx,x: loop_prologue(ctx, x, next(ctx.idx))),
(UPat(Ops.END, name="x"), lambda ctx,x: loop_epilogue(ctx, x, next(ctx.idx))),
(UPat(X86GroupOp.All | {Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER}, name="x"), lambda ctx,x: regalloc(ctx, x, next(ctx.idx))),
(UPat({Ops.INS, Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER}, name="x"), lambda ctx,x: regalloc(ctx, x, next(ctx.idx))),
])
# annoying that this is another pm
pm_insert_spills = PatternMatcher([
# insert spill after definition
(UPat(X86GroupOp.All | {Ops.RANGE}, name="x"), lambda ctx,x:
(UPat({Ops.INS, Ops.RANGE}, name="x"), lambda ctx,x:
(x, [x, store(ctx, y, x)]) if (y:=ctx.spills.get(ctx.rewrite_to_vreg.get(x))) is not None else None),
])

View file

@ -1,15 +1,13 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
from tinygrad.uop.ops import Ops, auto
from tinygrad.uop import FastEnum, auto
# ***** X86 *****
# NOTE: mypy doesn't allow extending enums even with our wrapper, it also doesn't allow overriding i.e. Ops.ADD to X86Ops.ADD
# we ignore it in both cases
class X86Ops(Ops): # type: ignore[misc]
class X86Ops(FastEnum):
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
DEFINE_REG = auto(); FRAME_INDEX = auto() # type: ignore[misc]
DEFINE_REG = auto(); FRAME_INDEX = auto()
# const
IMM = auto()
# index
@ -48,10 +46,10 @@ class X86Ops(Ops): # type: ignore[misc]
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
# int binary
IDIV = auto(); DIV = auto() # type: ignore[misc]
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto() # type: ignore[misc]
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto() # type: ignore[misc]
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto() # type: ignore[misc]
IDIV = auto(); DIV = auto()
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
# float unary (sometimes not unary)
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()

View file

@ -109,8 +109,8 @@ WGPR = tuple(r for r in GPR if r != RSP)
# ***** X86 instruction selection *****
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(X86Ops.DEFINE_REG, dt, arg=reg)
def imm(dt:DType, v:int|float) -> UOp: return UOp(X86Ops.IMM, dt, arg=v)
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dt, tag=reg)
def imm(dt:DType, v:int|float) -> UOp: return UOp(Ops.INS, arg=X86Ops.IMM, dtype=dt, tag=v)
def to_imm(c:UOp) -> UOp|None:
if c.op is not Ops.CONST: return None
if c.dtype is dtypes.int64: return imm(dtypes.int32, c.arg) if not c.overflows(dtypes.int32) else None
@ -118,17 +118,16 @@ def to_imm(c:UOp) -> UOp|None:
if c.dtype in dtypes.ints+(dtypes.bool,): return imm(c.dtype, c.arg)
return None
def cmp(x:UOp) -> UOp:
if x.src[0].dtype is dtypes.float32: return UOp(X86Ops.VUCOMISS, src=x.src)
if x.src[0].dtype is dtypes.float64: return UOp(X86Ops.VUCOMISD, src=x.src)
return UOp(X86Ops.CMP, src=x.src) if (i:=to_imm(x.src[1])) is None else UOp(X86Ops.CMPi, src=(x.src[0], i))
if x.src[0].dtype is dtypes.float32: return x.ins(X86Ops.VUCOMISS, dtype=dtypes.void)
if x.src[0].dtype is dtypes.float64: return x.ins(X86Ops.VUCOMISD, dtype=dtypes.void)
return x.ins(X86Ops.CMP, dtype=dtypes.void) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], i))
# vshufps xmm2, xmm0, xmm1, imm
# xmm2 selects its lower two 32-bit elements from xmm0 and its upper two 32-bit elements from xmm1 according to imm
def vshufps(x:UOp) -> UOp|None:
def _in(i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
if len(x.src) != 4 or _in(0) is not _in(1) or _in(2) is not _in(3): return None
return UOp(X86Ops.VSHUFPS, x.dtype, (_in(0), _in(2),
imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src)))))
if len(x.src) != 4 or (a:=_in(0)) is not _in(1) or (b:=_in(2)) is not _in(3): return None
return x.ins(X86Ops.VSHUFPS, src=(a, b, imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src)))))
# vinsertps xmm2, xmm0, xmm1, imm
# inserts any 32-bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2
@ -138,14 +137,14 @@ def vinsertps(x:UOp) -> UOp:
s, v = x.src[i], 0
if s.op is Ops.GEP: s, v = s.src[0], s.arg[0]
# moving the 0th element into the 0th position does nothing
return s if i == v == 0 else UOp(X86Ops.VINSERTPS, x.dtype, (ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
return s if i == v == 0 else x.ins(X86Ops.VINSERTPS, src=(ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))
# vpinsrq xmm2, xmm0, rax, imm
# inserts element in rax into any position in xmm0 according to imm, result is written to xmm2
def vpins(x:UOp) -> UOp:
op = {1: X86Ops.VPINSRB, 2: X86Ops.VPINSRW, 4: X86Ops.VPINSRD, 8: X86Ops.VPINSRQ}[x.dtype.scalar().itemsize]
return functools.reduce(lambda ret,i: UOp(op, x.dtype, (ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype))
return functools.reduce(lambda ret,i: x.ins(op, src=(ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype))
def div(ctx:IselContext, x:UOp):
# zero extend or move src[0] to x
@ -184,7 +183,7 @@ def fuse_load(ctx:IselContext, x:UOp, i:int) -> UOp|None:
def abi(ctx:IselContext, x:UOp):
i = ctx.func_args.index(x)
def _stack_arg(disp:int):
return UOp(X86Ops.MOV, x.dtype, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(X86Ops.FRAME_INDEX, dtypes.int32, arg=disp)))
return x.ins(X86Ops.MOV, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(X86Ops.FRAME_INDEX, dtypes.int32, arg=disp)))
if sys.platform == "win32": return def_reg(x.dtype, (RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
return def_reg(x.dtype, (RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
@ -205,195 +204,195 @@ isel_matcher = PatternMatcher([
# **** Op -> X86Op ****
# add callee saved registers to the RET, these will be scheduled at the top of the kernel and will be saved/restored if they are used in regalloc
# so regalloc builds the prologue/epilogue naturally
(UPat(Ops.SINK, name="x"), lambda x: x.replace(op=X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64, r) for r in [RSP, RBP]))),
(UPat(Ops.SINK, name="x"), lambda x: x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64, r) for r in [RSP, RBP]))),
# function abi constraints
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
# these are treated the same for now
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), lambda ctx,x:
x.replace(op=X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)),
x.ins(X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)),
# constants that can't be immediates, move them to registers
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), UOp(X86Ops.MOVi, dtypes.int16, (imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))), # noqa: E501
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (imm(x.dtype, x.arg),)),))),
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (imm(x.dtype, x.arg),)),))),
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (imm(x.dtype, x.arg),))),
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (imm(x.dtype, x.arg),))),
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype), UOp(Ops.INS, arg=X86Ops.MOVi, dtype=dtypes.int16, src=(imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))), # noqa: E501
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VMOVD, src=(UOp(Ops.INS, arg=X86Ops.MOVi, dtype=dtypes.int32, src=(imm(x.dtype, x.arg),)),))), # noqa: E501
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VMOVQ, src=(UOp(Ops.INS, arg=X86Ops.MOVABS, dtype=dtypes.int64, src=(imm(x.dtype, x.arg),)),))), # noqa: E501
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),))),
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),))),
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VPBLENDVB, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None), # noqa: E501
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPS, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPD, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None), # noqa: E501
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
(UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int32).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), # noqa: E501
(UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int64).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), # noqa: E501
# conditional moves that use flags
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVL, a.dtype, src=(b, a, cmp(m)))), # noqa: E501
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVB, a.dtype, src=(b, a, cmp(m)))),
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVE, a.dtype, src=(b, a, cmp(m)))),
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVNE, a.dtype, src=(b, a, cmp(m)))),
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))), # noqa: E501
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVB, src=(b, a, cmp(m)))),
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVE, src=(b, a, cmp(m)))),
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVNE, src=(b, a, cmp(m)))),
# jumps, use flags
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: UOp(X86Ops.JB, x.dtype, (cmp(y),))), # noqa: E501
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),)), lambda y: UOp(X86Ops.JL, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),)), lambda y: UOp(X86Ops.JE, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),)), lambda y: UOp(X86Ops.JNE, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JB, src=(cmp(y),))), # noqa: E501
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JL, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JE, src=(cmp(y),))),
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JNE, src=(cmp(y),))),
# comparisons whose user doesn't use the flag, move flag result to register
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: UOp(X86Ops.SETB, x.dtype, (cmp(x),))),
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETL, x.dtype, (cmp(x),))),
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETE, x.dtype, (cmp(x),))),
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETNE, x.dtype, (cmp(x),))),
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: x.ins(X86Ops.SETB, src=(cmp(x),))),
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETL, src=(cmp(x),))),
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETE, src=(cmp(x),))),
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETNE, src=(cmp(x),))),
# comparisons that produce masks (these aren't bool dtype)
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQB)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQW)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQD)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQQ)),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTB, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTW, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTD, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTQ, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQB)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQW)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQD)),
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQQ)),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTB, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTW, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
# float unary
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSS, x.dtype, (y, y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPS)), # noqa: E501
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSD, x.dtype, (y, y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPD)), # noqa: E501
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSS, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSD, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPS, x.dtype, (y, imm(dtypes.uint8, 3)))),
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPD, x.dtype, (y, imm(dtypes.uint8, 3)))),
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)), # noqa: E501
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)), # noqa: E501
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
# broadcasts TODO: not quite right, what about load fusion? Also, bitcast should be x86op and reg is xmm?
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTB, x.dtype, (y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTW, x.dtype, (y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTD, x.dtype, (y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTQ, x.dtype, (y.bitcast(dtypes.float64),))),
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: UOp(X86Ops.VBROADCASTSS, x.dtype, (y,))),
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTB, src=(y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTW, src=(y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTD, src=(y.bitcast(dtypes.float32),))),
(UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTQ, src=(y.bitcast(dtypes.float64),))),
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
# shuffles
(UPat.var("y", dtypes.int16s).bitcast(dtypes.float16).named("x"), lambda y,x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501
(UPat.var("y", dtypes.int16s).bitcast(dtypes.float16).named("x"), lambda y,x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501
(UPat(Ops.VECTORIZE, dtypes.ints+(dtypes.bool,), name="x"), vpins),
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vshufps),
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vinsertps),
(UPat.var("y", dtypes.float32).gep(name="x"), lambda y,x: UOp(X86Ops.VINSERTPS, x.dtype, (y, y, imm(dtypes.uint8, x.arg[0] << 6)))),
(UPat.var("y", dtypes.float32).gep(name="x"), lambda y,x: x.ins(X86Ops.VINSERTPS, src=(y, y, imm(dtypes.uint8, x.arg[0] << 6)))),
# extract
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, 0)))),
(UPat.var("y", dtypes.int8s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRB, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRD, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRQ, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, 0)))),
(UPat.var("y", dtypes.int8s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
# fused multiply add TODO: don't fuse if mul used several times
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, b, c)), # noqa: E501
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, b, c)), # noqa: E501
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(a, b, c))), # noqa: E501
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(a, b, c))), # noqa: E501
# packed bitwise
((UPat() & UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPAND) if x.dtype.count > 1 else None),
((UPat() | UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPOR) if x.dtype.count > 1 else None),
((UPat() ^ UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPXOR) if x.dtype.count > 1 else None),
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
# packed int binary
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDB) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDW) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDQ) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBB) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBW) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBD) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLW) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
# scalar int binary
((UPat(dtype=dtypes.uints) // UPat()).named("x"), div),
((UPat(dtype=dtypes.sints) // UPat()).named("x"), idiv),
((UPat.var("a", dtypes.ints) << UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHLi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHL)), # noqa: E501
((UPat.var("a", dtypes.uints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHRi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501
((UPat.var("a", dtypes.sints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SARi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.AND) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ANDi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.OR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ORi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.XOR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.XORi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints) * UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.IMUL) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.IMULi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints) + UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.ADD) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ADDi, src=(a, i))), # noqa: E501
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.SUB) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.SUBi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints) << UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SHLi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHL)), # noqa: E501
((UPat.var("a", dtypes.uints) >> UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SHRi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHR)), # noqa: E501
((UPat.var("a", dtypes.sints) >> UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SARi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHR)), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.AND) if (i:=to_imm(b)) is None else x.ins(X86Ops.ANDi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.OR) if (i:=to_imm(b)) is None else x.ins(X86Ops.ORi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.XOR) if (i:=to_imm(b)) is None else x.ins(X86Ops.XORi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints) * UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.IMUL) if (i:=to_imm(b)) is None else x.ins(X86Ops.IMULi, src=(a, i))), # noqa: E501
((UPat.var("a", dtypes.ints) + UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.ADD) if (i:=to_imm(b)) is None else x.ins(X86Ops.ADDi, src=(a, i))), # noqa: E501
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.SUB) if (i:=to_imm(b)) is None else x.ins(X86Ops.SUBi, src=(a, i))), # noqa: E501
# float binary
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
# TODO: these should use a.maximum(b) / a.minimum(b)
(UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)),
(UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)),
((UPat.var("a", dtypes.float32) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, a.dtype, (a, b))), # noqa: E501
((UPat.var("a", dtypes.float64) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, a.dtype, (a, b))), # noqa: E501
(UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)),
(UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)),
((UPat.var("a", dtypes.float32) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))), # noqa: E501
((UPat.var("a", dtypes.float64) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))), # noqa: E501
# casts
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPH2PS)),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSS2SI)),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSD2SI)),
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSS2SD, src=(y, y))),
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSD2SS, src=(y, y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))), # noqa: E501
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))), # noqa: E501
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBW) if x.dtype.count > 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBD) if x.dtype.count > 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXDQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBW) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXDQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVZX)),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.MOVSXD)),
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVSX)),
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSD2SI)),
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSS2SD, src=(y, y))),
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW) if x.dtype.count > 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD) if x.dtype.count > 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXDQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBW) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXDQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX)),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD)),
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX)),
# bitcasts
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.replace(op=X86Ops.VMOVD)),
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.replace(op=X86Ops.VMOVQ)),
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.replace(op=X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.replace(op=X86Ops.VMOVQm)),
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.ins(X86Ops.VMOVD)),
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
# index
(UPat(Ops.INDEX, name="x"), lambda x: x.replace(op=X86Ops.LEA, src=fuse_address(x))),
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fuse_address(x))),
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# assign, load, store
# NOTE: assign here violates the spec, it only happens in register allocation when a reg to reg move needs to be inserted
(UPat(Ops.ASSIGN, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS)),
(UPat(Ops.ASSIGN, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD)),
(UPat(Ops.ASSIGN, dt_32bit+dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS)),
(UPat(Ops.ASSIGN, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.replace(op=X86Ops.MOV)),
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_address(x.src[0]) + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.replace(op=X86Ops.MOV, src=fuse_address(x.src[0]))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_128bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVUPSm, src=fuse_address(x.src[0]) + (x.src[1],))), # noqa: E501
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_64bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSDm, src=fuse_address(x.src[0]) + (x.src[1],))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_32bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSSm, src=fuse_address(x.src[0]) + (x.src[1],))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_16bit)), name="x"), lambda x: x.replace(op=X86Ops.VPEXTRW, src=fuse_address(x.src[0]) + (x.src[1], imm(dtypes.uint8, 0)))), # noqa: E501
(UPat(Ops.ASSIGN, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
(UPat(Ops.ASSIGN, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
(UPat(Ops.ASSIGN, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
(UPat(Ops.ASSIGN, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS, src=fuse_address(x.src[0]))),
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_address(x.src[0]) + (imm(dtypes.uint8, 0),))), # noqa: E501
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV, src=fuse_address(x.src[0]))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_128bit)), name="x"), lambda x: x.ins(X86Ops.VMOVUPSm, src=fuse_address(x.src[0]) + (x.src[1],))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_64bit)), name="x"), lambda x: x.ins(X86Ops.VMOVSDm, src=fuse_address(x.src[0]) + (x.src[1],))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_32bit)), name="x"), lambda x: x.ins(X86Ops.VMOVSSm, src=fuse_address(x.src[0]) + (x.src[1],))),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_16bit)), name="x"), lambda x: x.ins(X86Ops.VPEXTRW, src=fuse_address(x.src[0]) + (x.src[1], imm(dtypes.uint8, 0)))), # noqa: E501
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.ints+(dtypes.bool,))), name="x"), lambda x:
x.replace(op=X86Ops.MOVm, src=fuse_address(x.src[0]) + (x.src[1],)) if (i:=to_imm(x.src[1])) is None else x.replace(op=X86Ops.MOVi, src=fuse_address(x.src[0]) + (i,))), # noqa: E501
x.ins(X86Ops.MOVm, src=fuse_address(x.src[0]) + (x.src[1],)) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.MOVi, src=fuse_address(x.src[0]) + (i,))), # noqa: E501
# **** X86Op -> X86Op ****
# fuse loads into X86Ops that allow it, if beneficial
(UPat(X86GroupOp.ReadMem1st, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)),
(UPat(X86GroupOp.ReadMem2nd, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)),
(UPat(X86GroupOp.ReadMem3rd, src=(UPat(), UPat(), UPat(Ops.LOAD)), name="x"), lambda ctx,x: fuse_load(ctx, x, 2)),
*[(UPat.ins(op, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)) for op in X86GroupOp.ReadMem1st],
*[(UPat.ins(op, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)) for op in X86GroupOp.ReadMem2nd],
*[(UPat.ins(op, src=(UPat(), UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 2)) for op in X86GroupOp.ReadMem3rd], # noqa: E501
# allocate virtual register to X86Op with special constraints
(UPat(X86GroupOp.All, dtypes.ints+dtypes.floats+(dtypes.bool,), name="x"), lambda ctx,x:
x.replace(arg=ctx.vreg(x.arg)) if isinstance(x.arg, tuple) else None),
(UPat(Ops.INS, dtypes.ints+dtypes.floats+(dtypes.bool,), name="x"), lambda ctx,x:
x.replace(tag=ctx.vreg(x.tag)) if isinstance(x.tag, tuple) else None),
# allocate virtual register to X86Op without special constraints
(UPat(X86GroupOp.All, name="x"), lambda ctx,x:
x.replace(arg=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.arg is None and x.dtype != dtypes.void else None),
(UPat(Ops.INS, name="x"), lambda ctx,x:
x.replace(tag=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.tag is None and x.dtype != dtypes.void else None),
])
# ***** post register allocation *****
@ -402,15 +401,15 @@ isel_matcher = PatternMatcher([
# final rewrite to match the isa spec
post_regalloc_matcher = PatternMatcher([
# alloc stack space
(UPat(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP, name="x"), lambda ctx,x:
(x, [x, UOp(X86Ops.SUBi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP)]) if ctx.stack_size > 0 else None),
(UPat.ins(X86Ops.DEFINE_REG, dtype=dtypes.uint64, tag=RSP, name="x"), lambda ctx,x:
(x, [x, x.ins(X86Ops.SUBi, src=(imm(dtypes.uint32, ctx.stack_size),), tag=RSP)]) if ctx.stack_size > 0 else None),
# dealloc stack space
(UPat(X86Ops.RET, name="x"), lambda ctx,x:
(UPat.ins(X86Ops.RET, name="x"), lambda ctx,x:
(x, [UOp(X86Ops.ADDi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP), x]) if ctx.stack_size > 0 else None),
# rewrite FRAME_INDEX to IMM now that the stack size is known
(UPat(X86Ops.FRAME_INDEX, name="x"), lambda ctx,x: (nx:=x.replace(op=X86Ops.IMM, arg=ctx.stack_size + x.arg), [nx])),
# rewrite RANGE to MOV reg, 0. Terrible HACK to pass the CONST to the END
(UPat(Ops.RANGE, name="x"), lambda x: (nx:=x.replace(op=X86Ops.MOVi, src=(imm(x.dtype, 0),), tag=x.src[0].arg), [nx])),
(UPat(Ops.RANGE, name="x"), lambda x: (nx:=x.ins(X86Ops.MOVi, src=(imm(x.dtype, 0),), tag=x.src[0].arg), [nx])),
# rewrite END to ADD 1 -> CMPLT -> JUMP
(UPat(Ops.END, name="x"), lambda x:
(jl:=x.replace(op=X86Ops.JL, src=(x.src[1], cmp:=UOp(X86Ops.CMPi if isinstance(x.src[1].tag, int) else X86Ops.CMP,
@ -419,7 +418,7 @@ post_regalloc_matcher = PatternMatcher([
# TODO: need a generic way to model clobbers, idiv and flags should be handled the same way, maybe add clobber field to Register?
# fixup div, zero rdx again because scheduling constraint isn't being respected
(UPat(X86Ops.DIV, name="x"), lambda x:
(nx:=x.replace(src=x.src[:1]), [UOp(X86Ops.MOVi, x.dtype, (imm(min(dtypes.uint32, x.dtype), 0),), RDX), nx])),
(nx:=x.replace(src=x.src[:1]), [x.ins(X86Ops.MOVi, src=(imm(min(dtypes.uint32, x.dtype), 0),), tag=RDX), nx])),
# rewrite two address instructions to two address form, if reused src wasn't coalesced insert a move
(UPat(X86GroupOp.TwoAddress1st, name="x"), lambda ctx,x:
(nx:=x.replace(src=x.src[1:]), [assign(ctx, x.src[0], x.arg), nx] if x.arg != x.src[0].arg else [nx])),
@ -431,11 +430,11 @@ isa_spec = PatternMatcher([
# these are the only non X86Ops allowed
(UPat((Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER)), lambda: True),
# vblends take a mask which is float or int dtype
(UPat((X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD), src=(UPat.var("a"), UPat.var("b"), UPat.var("m")), name="x"),
lambda a,b,m,x: x.dtype == a.dtype == b.dtype and x.dtype.itemsize == m.dtype.itemsize),
#(UPat((X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD), src=(UPat.var("a"), UPat.var("b"), UPat.var("m")), name="x"),
# lambda a,b,m,x: x.dtype == a.dtype == b.dtype and x.dtype.itemsize == m.dtype.itemsize),
# cmoves take a flag producing instruction
(UPat((X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE), dtypes.bool, (UPat(), UPat(), UPat(X86GroupOp.WriteFlags))), lambda: True),
(UPat(X86GroupOp.All), lambda: True),
#(UPat((X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE), dtypes.bool, (UPat(), UPat(), UPat(X86GroupOp.WriteFlags))), lambda: True),
(UPat(Ops.INS), lambda: True),
])
# ***** X86 instruction encoding *****
@ -447,15 +446,15 @@ def to_bytes(dt:DType, v:int|float):
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
if x.op in X86GroupOp.WriteMem:
if x.arg in X86GroupOp.WriteMem:
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
else: address, rest = (x, None, None), x.src
elif x.op in X86GroupOp.ReadMem1st or x.op in X86GroupOp.ReadMem2nd and x.op in X86GroupOp.TwoAddress1st:
elif x.arg in X86GroupOp.ReadMem1st or x.arg in X86GroupOp.ReadMem2nd and x.arg in X86GroupOp.TwoAddress1st:
if len(x.src) > 2: address, rest = x.src[:3], (x,) + x.src[3:]
else: address, rest = (x.src[0], None, None), (x,) + x.src[1:]
elif x.op in X86GroupOp.ReadMem2nd or x.op in X86GroupOp.ReadMem3rd and x.op in X86GroupOp.TwoAddress1st:
elif x.arg in X86GroupOp.ReadMem2nd or x.arg in X86GroupOp.ReadMem3rd and x.arg in X86GroupOp.TwoAddress1st:
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
if x.dtype is not dtypes.void: rest = (x,) + rest
@ -464,12 +463,12 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
# get the encoding values of the different fields
reg_sz = (rest[0].dtype.itemsize if not isinstance(rest[0].dtype, PtrDType) else 8) if reg is None else 0
reg = cast(Register, rest[0].arg).index if reg is None else reg
vvvv = rest[1].arg.index if len(rest) > 1 and isinstance(rest[1].arg, Register) else 0
rm = cast(Register, address[0].arg).index
idx = cast(Register, address[1].arg).index if address[1] is not None and address[1].arg is not None else 4
reg = cast(Register, rest[0].tag).index if reg is None else reg
vvvv = rest[1].tag.index if len(rest) > 1 and isinstance(rest[1].tag, Register) else 0
rm = cast(Register, address[0].tag).index
idx = cast(Register, address[1].tag).index if address[1] is not None and address[1].tag is not None else 4
disp_uop = address[2]
imm_uop = rest[-1] if rest[-1].op is X86Ops.IMM or len(rest) == 3 else None
imm_uop = rest[-1] if rest[-1].arg is X86Ops.IMM or len(rest) == 3 else None
# TODO: another reason to get rid of ptrs, if we access memory the size should be in scale uop otherwise size is in rm
rm_sz = 8 if isinstance(address[0].dtype, PtrDType) and disp_uop is None else address[0].dtype.itemsize
@ -494,7 +493,7 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
if w | r | _x | b | (reg_sz == 1 & reg >> 2) | (rm_sz == 1 & rm >> 2): inst += bytes([0b0100 << 4 | w << 3 | r << 2 | _x << 1 | b])
# OPCODE byte
# legacy 8bit opcodes are 1 less than 16-64bit versions, with these exceptions
real_opc = opc-1 if (rm_sz == 1 or reg_sz == 1) and x.op not in {X86Ops.SETB, X86Ops.SETE, X86Ops.SETL, X86Ops.SETNE, X86Ops.LEA} else opc
real_opc = opc-1 if (rm_sz == 1 or reg_sz == 1) and x.arg not in {X86Ops.SETB, X86Ops.SETE, X86Ops.SETL, X86Ops.SETNE, X86Ops.LEA} else opc
inst += real_opc.to_bytes((real_opc.bit_length() + 7) // 8, 'big')
# MODRM byte
# now we only care about the lower 3 bits
@ -506,7 +505,7 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
if disp_uop is not None:
assert disp_uop.dtype in (dtypes.int8, dtypes.int32), "displacement can only be 1 or 4 byte signed int"
# rbp/r13 always require a displacement
if disp_uop.arg != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
if disp_uop.tag != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
else: mod = 0b00
else: mod = 0b11
# x 0b0 and idx 0b100 means rsp which means no index exists
@ -520,11 +519,11 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
# DISP byte
if mod == 0b01 or mod == 0b10:
assert disp_uop is not None
inst += disp_uop.arg.to_bytes(disp_uop.dtype.itemsize, 'little', signed=True)
inst += disp_uop.tag.to_bytes(disp_uop.dtype.itemsize, 'little', signed=True)
# IMM byte
if imm_uop is not None:
if isinstance(imm_uop.arg, Register): inst += bytes([(imm_uop.arg.index & 0b1111) << 4 | 0b0000])
else: inst += to_bytes(imm_uop.dtype, imm_uop.arg)
if isinstance(imm_uop.tag, Register): inst += bytes([(imm_uop.tag.index & 0b1111) << 4 | 0b0000])
else: inst += to_bytes(imm_uop.dtype, imm_uop.tag)
return inst
# https://www.felixcloutier.com/x86/
@ -533,109 +532,109 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
# map select: 0F == 1, 0F38 == 2, 0F3A == 3
encodings = PatternMatcher([
  # Instruction -> machine code table. Every X86 instruction is an Ops.INS uop whose arg is the X86Ops id, so each
  # entry is a UPat.ins(<X86Ops member>) pattern paired with a lambda returning the encoded bytes.
  # encode(x, opc, reg=, pp=, sel=, we=): opc is the opcode, sel is the opcode map (0F == 1, 0F38 == 2, 0F3A == 3).
  # NOTE(review): reg= looks like the ModRM.reg opcode extension (/digit) and pp like the implied-prefix field; confirm in encode.
  # moves
  # MOVABS is hand encoded as REX.W + (B8 + low 3 register bits) followed by the immediate. The destination register
  # and the immediate value are stored in .tag (arg holds the X86Ops id), matching how encode reads registers/immediates.
  (UPat.ins(X86Ops.MOVABS, name="x"), lambda x:
   bytes([0b0100 << 4 | 0b1 << 3 | 0b00 << 2 | x.tag.index >> 3, 0xB8 + (x.tag.index & 0b111)]) + to_bytes(x.src[0].dtype, x.src[0].tag)),
  (UPat.ins(X86Ops.MOV, name="x"), lambda x: encode(x, 0x8B)), (UPat.ins(X86Ops.MOVi, name="x"), lambda x: encode(x, 0xC7, reg=0)),
  (UPat.ins(X86Ops.MOVm, name="x"), lambda x: encode(x, 0x89)), (UPat.ins(X86Ops.LEA, name="x"), lambda x: encode(x, 0x8D)),
  (UPat.ins(X86Ops.VMOVSS, name="x"), lambda x: encode(x, 0x10, pp=2, sel=1)), (UPat.ins(X86Ops.VMOVSSm, name="x"), lambda x: encode(x, 0x11, pp=2, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVSD, name="x"), lambda x: encode(x, 0x10, pp=3, sel=1)), (UPat.ins(X86Ops.VMOVSDm, name="x"), lambda x: encode(x, 0x11, pp=3, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVUPS, name="x"), lambda x: encode(x, 0x10, pp=0, sel=1)), (UPat.ins(X86Ops.VMOVUPSm, name="x"), lambda x: encode(x, 0x11, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVD, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1)), (UPat.ins(X86Ops.VMOVQ, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVDm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1)), (UPat.ins(X86Ops.VMOVQm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1, we=1)), # noqa: E501
  # casts
  (UPat.ins(X86Ops.MOVZX, name="x"), lambda x: encode(x, 0x0FB7)),
  (UPat.ins(X86Ops.MOVSX, name="x"), lambda x: encode(x, 0x0FBF)), (UPat.ins(X86Ops.MOVSXD, name="x"), lambda x: encode(x, 0x63)),
  (UPat.ins(X86Ops.VPMOVZXBW, name="x"), lambda x: encode(x, 0x30, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXBD, name="x"), lambda x: encode(x, 0x31, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVZXBQ, name="x"), lambda x: encode(x, 0x32, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXWD, name="x"), lambda x: encode(x, 0x33, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVZXWQ, name="x"), lambda x: encode(x, 0x34, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXDQ, name="x"), lambda x: encode(x, 0x35, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXBW, name="x"), lambda x: encode(x, 0x20, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXBD, name="x"), lambda x: encode(x, 0x21, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXBQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXWD, name="x"), lambda x: encode(x, 0x23, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXWQ, name="x"), lambda x: encode(x, 0x24, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXDQ, name="x"), lambda x: encode(x, 0x25, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VCVTSS2SD, name="x"), lambda x: encode(x, 0x5A, pp=2, sel=1)), (UPat.ins(X86Ops.VCVTSD2SS, name="x"), lambda x: encode(x, 0x5A, pp=3, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTPH2PS, name="x"), lambda x: encode(x, 0x13, pp=1, sel=2)), (UPat.ins(X86Ops.VCVTPS2PH, name="x"), lambda x: encode(x, 0x1D, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VCVTDQ2PS, name="x"), lambda x: encode(x, 0x5B, pp=0, sel=1)), (UPat.ins(X86Ops.VCVTDQ2PD, name="x"), lambda x: encode(x, 0xE6, pp=2, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTPS2PD, name="x"), lambda x: encode(x, 0x5A, pp=0, sel=1)), (UPat.ins(X86Ops.VCVTPD2PS, name="x"), lambda x: encode(x, 0x5A, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTTPS2DQ, name="x"), lambda x: encode(x, 0x5B, pp=2, sel=1)), (UPat.ins(X86Ops.VCVTTPD2DQ, name="x"), lambda x: encode(x, 0xE6, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTSI2SS, name="x"), lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.base is dtypes.int64)),
  (UPat.ins(X86Ops.VCVTSI2SD, name="x"), lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.base is dtypes.int64)),
  (UPat.ins(X86Ops.VCVTTSS2SI, name="x"), lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype in dtypes.int64s)),
  (UPat.ins(X86Ops.VCVTTSD2SI, name="x"), lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype in dtypes.int64s)),
  # int division
  (UPat.ins(X86Ops.IDIV, name="x"), lambda x: encode(x, 0xF7, reg=7)), (UPat.ins(X86Ops.DIV, name="x"), lambda x: encode(x, 0xF7, reg=6)),
  # scalar int binary
  (UPat.ins(X86Ops.SHLi, name="x"), lambda x: encode(x, 0xC1, reg=4)),
  (UPat.ins(X86Ops.SHRi, name="x"), lambda x: encode(x, 0xC1, reg=5)), (UPat.ins(X86Ops.SARi, name="x"), lambda x: encode(x, 0xC1, reg=7)),
  (UPat.ins(X86Ops.ADD, name="x"), lambda x: encode(x, 0x03)), (UPat.ins(X86Ops.ADDi, name="x"), lambda x: encode(x, 0x81, reg=0)),
  (UPat.ins(X86Ops.SUB, name="x"), lambda x: encode(x, 0x2B)), (UPat.ins(X86Ops.SUBi, name="x"), lambda x: encode(x, 0x81, reg=5)),
  (UPat.ins(X86Ops.AND, name="x"), lambda x: encode(x, 0x23)), (UPat.ins(X86Ops.ANDi, name="x"), lambda x: encode(x, 0x81, reg=4)),
  (UPat.ins(X86Ops.XOR, name="x"), lambda x: encode(x, 0x33)), (UPat.ins(X86Ops.XORi, name="x"), lambda x: encode(x, 0x81, reg=6)),
  (UPat.ins(X86Ops.OR, name="x"), lambda x: encode(x, 0x0B)), (UPat.ins(X86Ops.ORi, name="x"), lambda x: encode(x, 0x81, reg=1)),
  (UPat.ins(X86Ops.CMP, name="x"), lambda x: encode(x, 0x3B)), (UPat.ins(X86Ops.CMPi, name="x"), lambda x: encode(x, 0x81, reg=7)),
  (UPat.ins(X86Ops.IMUL, name="x"), lambda x: encode(x, 0x0FAF)), (UPat.ins(X86Ops.IMULi, name="x"), lambda x: encode(x, 0x69)),
  (UPat.ins(X86Ops.SETB, name="x"), lambda x: encode(x, 0x0F92, reg=0)), (UPat.ins(X86Ops.SETL, name="x"), lambda x: encode(x, 0x0F9C, reg=0)),
  (UPat.ins(X86Ops.SETE, name="x"), lambda x: encode(x, 0x0F94, reg=0)), (UPat.ins(X86Ops.SETNE, name="x"), lambda x: encode(x, 0x0F95, reg=0)),
  # packed bitwise NOTE: only bitwise and packed
  (UPat.ins(X86Ops.VPAND, name="x"), lambda x: encode(x, 0xDB, pp=1, sel=1)), (UPat.ins(X86Ops.VPXOR, name="x"), lambda x: encode(x, 0xEF, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPOR, name="x"), lambda x: encode(x, 0xEB, pp=1, sel=1)),
  # unary
  (UPat.ins(X86Ops.VSQRTSS, name="x"), lambda x: encode(x, 0x51, pp=2, sel=1)), (UPat.ins(X86Ops.VSQRTPS, name="x"), lambda x: encode(x, 0x51, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSQRTSD, name="x"), lambda x: encode(x, 0x51, pp=3, sel=1)), (UPat.ins(X86Ops.VSQRTPD, name="x"), lambda x: encode(x, 0x51, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VROUNDSS, name="x"), lambda x: encode(x, 0x0A, pp=1, sel=3)), (UPat.ins(X86Ops.VROUNDPS, name="x"), lambda x: encode(x, 0x08, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VROUNDSD, name="x"), lambda x: encode(x, 0x0B, pp=1, sel=3)), (UPat.ins(X86Ops.VROUNDPD, name="x"), lambda x: encode(x, 0x09, pp=1, sel=3)), # noqa: E501
  # packed int binary
  (UPat.ins(X86Ops.VPSLLVD, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2)), (UPat.ins(X86Ops.VPSLLVQ, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSRLVD, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2)), (UPat.ins(X86Ops.VPSRLVQ, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPGTB, name="x"), lambda x: encode(x, 0x64, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPGTW, name="x"), lambda x: encode(x, 0x65, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPGTD, name="x"), lambda x: encode(x, 0x66, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPGTQ, name="x"), lambda x: encode(x, 0x37, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPEQB, name="x"), lambda x: encode(x, 0x74, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPEQW, name="x"), lambda x: encode(x, 0x75, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPEQD, name="x"), lambda x: encode(x, 0x76, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPEQQ, name="x"), lambda x: encode(x, 0x29, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMULLW, name="x"), lambda x: encode(x, 0xD5, pp=1, sel=1)), (UPat.ins(X86Ops.VPMULLD, name="x"), lambda x: encode(x, 0x40, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPADDB, name="x"), lambda x: encode(x, 0xFC, pp=1, sel=1)), (UPat.ins(X86Ops.VPADDW, name="x"), lambda x: encode(x, 0xFD, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPADDD, name="x"), lambda x: encode(x, 0xFE, pp=1, sel=1)), (UPat.ins(X86Ops.VPADDQ, name="x"), lambda x: encode(x, 0xD4, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSUBB, name="x"), lambda x: encode(x, 0xF8, pp=1, sel=1)), (UPat.ins(X86Ops.VPSUBW, name="x"), lambda x: encode(x, 0xF9, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSUBD, name="x"), lambda x: encode(x, 0xFA, pp=1, sel=1)), (UPat.ins(X86Ops.VPSUBQ, name="x"), lambda x: encode(x, 0xFB, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSRAVD, name="x"), lambda x: encode(x, 0x46, pp=1, sel=2)),
  # float cmp
  (UPat.ins(X86Ops.VUCOMISS, name="x"), lambda x: encode(x, 0x2E, pp=0, sel=1)), (UPat.ins(X86Ops.VUCOMISD, name="x"), lambda x: encode(x, 0x2E, pp=1, sel=1)), # noqa: E501
  # scalar / packed float binary
  (UPat.ins(X86Ops.VADDSS, name="x"), lambda x: encode(x, 0x58, pp=2, sel=1)), (UPat.ins(X86Ops.VADDPS, name="x"), lambda x: encode(x, 0x58, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VADDSD, name="x"), lambda x: encode(x, 0x58, pp=3, sel=1)), (UPat.ins(X86Ops.VADDPD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSUBSS, name="x"), lambda x: encode(x, 0x5C, pp=2, sel=1)), (UPat.ins(X86Ops.VSUBPS, name="x"), lambda x: encode(x, 0x5C, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSUBSD, name="x"), lambda x: encode(x, 0x5C, pp=3, sel=1)), (UPat.ins(X86Ops.VSUBPD, name="x"), lambda x: encode(x, 0x5C, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMULSS, name="x"), lambda x: encode(x, 0x59, pp=2, sel=1)), (UPat.ins(X86Ops.VMULPS, name="x"), lambda x: encode(x, 0x59, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMULSD, name="x"), lambda x: encode(x, 0x59, pp=3, sel=1)), (UPat.ins(X86Ops.VMULPD, name="x"), lambda x: encode(x, 0x59, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VDIVSS, name="x"), lambda x: encode(x, 0x5E, pp=2, sel=1)), (UPat.ins(X86Ops.VDIVPS, name="x"), lambda x: encode(x, 0x5E, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VDIVSD, name="x"), lambda x: encode(x, 0x5E, pp=3, sel=1)), (UPat.ins(X86Ops.VDIVPD, name="x"), lambda x: encode(x, 0x5E, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCMPSS, name="x"), lambda x: encode(x, 0xC2, pp=2, sel=1)), (UPat.ins(X86Ops.VCMPPS, name="x"), lambda x: encode(x, 0xC2, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCMPSD, name="x"), lambda x: encode(x, 0xC2, pp=3, sel=1)), (UPat.ins(X86Ops.VCMPPD, name="x"), lambda x: encode(x, 0xC2, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMAXSS, name="x"), lambda x: encode(x, 0x5F, pp=2, sel=1)), (UPat.ins(X86Ops.VMAXPS, name="x"), lambda x: encode(x, 0x5F, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMAXSD, name="x"), lambda x: encode(x, 0x5F, pp=3, sel=1)), (UPat.ins(X86Ops.VMAXPD, name="x"), lambda x: encode(x, 0x5F, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMINSS, name="x"), lambda x: encode(x, 0x5D, pp=2, sel=1)), (UPat.ins(X86Ops.VMINPS, name="x"), lambda x: encode(x, 0x5D, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMINSD, name="x"), lambda x: encode(x, 0x5D, pp=3, sel=1)), (UPat.ins(X86Ops.VMINPD, name="x"), lambda x: encode(x, 0x5D, pp=1, sel=1)), # noqa: E501
  # ternary
  (UPat.ins(X86Ops.CMOVB, name="x"), lambda x: encode(x, 0x0F42)), (UPat.ins(X86Ops.CMOVL, name="x"), lambda x: encode(x, 0x0F4C)),
  (UPat.ins(X86Ops.CMOVE, name="x"), lambda x: encode(x, 0x0F44)), (UPat.ins(X86Ops.CMOVNE, name="x"), lambda x: encode(x, 0x0F45)),
  (UPat.ins(X86Ops.VFMADD213SS, name="x"), lambda x: encode(x, 0xA9, pp=1, sel=2)), (UPat.ins(X86Ops.VFMADD213SD, name="x"), lambda x: encode(x, 0xA9, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VFMADD213PS, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2)), (UPat.ins(X86Ops.VFMADD213PD, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VBLENDVPS, name="x"), lambda x: encode(x, 0x4A, pp=1, sel=3)), (UPat.ins(X86Ops.VBLENDVPD, name="x"), lambda x: encode(x, 0x4B, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VPBLENDVB, name="x"), lambda x: encode(x, 0x4C, pp=1, sel=3)),
  # shuffles
  (UPat.ins(X86Ops.VPBROADCASTB, name="x"), lambda x: encode(x, 0x78, pp=1, sel=2)), (UPat.ins(X86Ops.VPBROADCASTW, name="x"), lambda x: encode(x, 0x79, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPBROADCASTD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=2)), (UPat.ins(X86Ops.VPBROADCASTQ, name="x"), lambda x: encode(x, 0x59, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VBROADCASTSS, name="x"), lambda x: encode(x, 0x18, pp=1, sel=2)),
  (UPat.ins(X86Ops.VPINSRB, name="x"), lambda x: encode(x, 0x20, pp=1, sel=3)), (UPat.ins(X86Ops.VPINSRW, name="x"), lambda x: encode(x, 0xC4, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPINSRD, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3)), (UPat.ins(X86Ops.VPINSRQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VSHUFPS, name="x"), lambda x: encode(x, 0xC6, pp=0, sel=1)), (UPat.ins(X86Ops.VINSERTPS, name="x"), lambda x: encode(x, 0x21, pp=1, sel=3)), # noqa: E501
  # extract
  (UPat.ins(X86Ops.VPEXTRB, name="x"), lambda x: encode(x, 0x14, pp=1, sel=3)), (UPat.ins(X86Ops.VPEXTRW, name="x"), lambda x: encode(x, 0x15, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VPEXTRD, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3)), (UPat.ins(X86Ops.VPEXTRQ, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3, we=1)), # noqa: E501
  # jumps are encoded with a placeholder which gets patched later once the real offset is known
  (UPat.ins(X86Ops.JE), lambda: bytes([0x0F, 0x84]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JNE), lambda: bytes([0x0F, 0x85]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JL), lambda: bytes([0x0F, 0x8C]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JB), lambda: bytes([0x0F, 0x82]) + int(0).to_bytes(4, 'little', signed=True)),
  # return
  (UPat.ins(X86Ops.RET), lambda: bytes([0xC3])),
])
class X86Renderer(ISARenderer):
@ -652,21 +651,21 @@ class X86Renderer(ISARenderer):
def __init__(self):
# Create the renderer's compiler: an X86Compiler instance stored on self.compiler.
# The import is done inside the method rather than at module top.
# NOTE(review): presumably to avoid an import cycle or import-time cost — confirm.
from tinygrad.runtime.support.compiler_cpu import X86Compiler
self.compiler = X86Compiler()
# stack_pointer returns the uop that stands for the stack pointer register (RSP), typed uint64.
# Two renderings of the same one-line method appear back to back here:
#  - the first builds it directly as a X86Ops.DEFINE_REG uop with the register in arg
#  - the second builds an Ops.INS uop with arg=X86Ops.DEFINE_REG and the register moved to tag
# NOTE(review): presumably the first is the pre-commit line and the second its replacement — confirm against the real file.
def stack_pointer(self) -> UOp: return UOp(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP)
def stack_pointer(self) -> UOp: return UOp(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dtypes.uint64, tag=RSP)
def render(self, uops:list[UOp], lower:bool=True) -> str:
  """Encode a list of uops into x86 machine code and return it as a hex string.

  When `lower` is true the list is first replaced by `self.lower(uops[-1])`.
  Raises RuntimeError if an instruction has no matching entry in `encodings`.
  """
  if lower: uops = self.lower(uops[-1])
  targets: set[UOp] = set()
  target_loc: list[int] = []
  binary = bytearray()
  # first pass: remember which uops are jump destinations (src[0] of every conditional jump)
  for u in uops:
    if u.arg in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE): targets.add(u.src[0])
  # second pass: emit bytes
  for u in uops:
    # GROUP/NOOP/AFTER/BARRIER are plain ops (values of u.op, not instruction ids stored in u.arg) and emit no bytes
    if u.op in (Ops.GROUP, Ops.NOOP, Ops.AFTER, Ops.BARRIER): continue
    # immediates and register definitions are operands of other instructions, not instructions themselves
    if u.arg in (X86Ops.IMM, X86Ops.DEFINE_REG): continue
    if (l:=cast(bytes|None, encodings.rewrite(u))) is None:
      raise RuntimeError(f"failed to encode {u.arg} with {u.dtype} srcs {[x.dtype for x in u.src]}")
    binary.extend(l)
    # a jump target records the offset right after its own bytes
    if u in targets: target_loc.append(len(binary))
    elif u.arg in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE):
      # patch the 4 placeholder bytes with the displacement from the end of the jump to the most recent target
      binary[-4:] = (target_loc.pop() - len(binary)).to_bytes(4, 'little', signed=True)
  return binary.hex()

View file

@ -1,14 +1,9 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
from enum import auto, IntEnum, Enum, EnumType
# wrapper around EnumType to allow extending enums with members
class ExtensibleEnumType(EnumType):
# Enum metaclass whose _check_for_existing_members_ hook does nothing and returns None.
# NOTE(review): in CPython's enum module this hook appears to be what raises TypeError when subclassing an enum that
# already has members; it is private API, so confirm for the targeted Python version.
@classmethod
def _check_for_existing_members_(mcls, class_name, bases): return
from enum import auto, IntEnum, Enum
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum, metaclass=ExtensibleEnumType):
class FastEnum(IntEnum):
def __str__(self): return Enum.__str__(self)
def __repr__(x): return str(x)
@staticmethod

View file

@ -421,6 +421,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def after(self, *src:UOp, **kwargs):
  """Return an AFTER uop over self followed by `src` (dtype taken from self); with no `src`, return self unchanged."""
  if not src: return self
  return UOp(Ops.AFTER, self.dtype, (self, *src), **kwargs)
def assign(self, x:UOp):
  """Return an ASSIGN uop with self's dtype whose sources are (self, x)."""
  operands = (self, x)
  return UOp(Ops.ASSIGN, self.dtype, operands)
def barrier(self, *src:UOp):
  """Return a BARRIER uop whose sources are self followed by `src`."""
  return UOp(Ops.BARRIER, src=(self, *src))
def ins(self, arg, **kwargs):
  """Return an Ops.INS uop derived from self.

  `arg` becomes the new uop's arg. The keywords `dtype`, `src` and `tag` override the values otherwise
  copied from self; any other keyword is ignored.
  """
  dtype = kwargs.pop("dtype", self.dtype)
  src = kwargs.pop("src", self.src)
  tag = kwargs.pop("tag", self.tag)
  return UOp(Ops.INS, dtype, src, arg, tag)
def contract(self, *rngs:UOp):
  """Return a CONTRACT uop over self for the given ranges.

  Each range must have AxisType.UPCAST as the last element of its arg. The result dtype is self.dtype
  vectorized by the product of the range sizes (vmax+1), and arg is a tuple of (range arg[0], size) pairs.
  """
  assert all(r.arg[-1] == AxisType.UPCAST for r in rngs), "all contract ranges must be upcast"
  sizes = [r.vmax+1 for r in rngs]
  axes = tuple((r.arg[0], sz) for r, sz in zip(rngs, sizes))
  return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod(sizes)), src=(self,), arg=axes)
@ -969,6 +970,8 @@ class UPat(OpMixin):
def contiguous(self, *args, **kwargs):
  """Return a CONTIGUOUS pattern whose sources are self followed by `args`, matching self's match_dtype."""
  return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=(self, *args), **kwargs)
def after(self, *src:UPat, **kwargs):
  """Return an AFTER pattern whose sources are self followed by `src`, matching self's match_dtype."""
  return UPat(Ops.AFTER, self.match_dtype, (self, *src), **kwargs)
def end(self, *src:UPat, **kwargs):
  """Return an END pattern whose sources are self followed by `src`, matching self's match_dtype."""
  return UPat(Ops.END, self.match_dtype, (self, *src), **kwargs)
@staticmethod
def ins(arg, **kwargs):
  """Return a pattern matching an Ops.INS uop with the given `arg`; extra keywords are passed through to UPat."""
  pattern = UPat(Ops.INS, arg=arg, **kwargs)
  return pattern
def const_like(self, b:ConstLike):
  """Return a constant pattern for `b` (cast to ConstType for the type checker) using self's match_dtype."""
  value = cast(ConstType, b)
  return UPat.const(self.match_dtype, value)
def alu(self, op:Ops, *src:UPat):