mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add Ops.INS to x86
This commit is contained in:
parent
1f140d9d53
commit
dd558ecfae
5 changed files with 300 additions and 305 deletions
|
|
@ -30,10 +30,10 @@ class RegallocContext:
|
|||
lr, ranges = self.live_range, []
|
||||
for i,u in enumerate(reversed(uops)):
|
||||
if u.op in (Ops.NOOP, Ops.AFTER): continue
|
||||
for v in {s.arg for s in (u,) + u.src if isinstance(s.arg, Register)}: lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
|
||||
for v in {s.tag for s in (u,) + u.src if isinstance(s.tag, Register)}: lr.setdefault(v, []).insert(0, len(uops) - 1 - i)
|
||||
# a var defined before a range and used inside it is needed for the whole range
|
||||
if u.arg in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] < lr[u.arg][-1] < lr[rng][-1]), default=None)): lr[u.arg].append(n)
|
||||
if u.op is Ops.RANGE: ranges.append(u.arg)
|
||||
if u.tag in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] < lr[u.tag][-1] < lr[rng][-1]), default=None)): lr[u.tag].append(n)
|
||||
if u.op is Ops.RANGE: ranges.append(u.tag)
|
||||
|
||||
# TODO: rm pointers
|
||||
# nasty hacks to deal with pointers
|
||||
|
|
@ -44,7 +44,7 @@ def assign(ctx:RegallocContext, x:UOp, reg:Register):
|
|||
return ret.replace(dtype=x.dtype)
|
||||
def load(ctx:RegallocContext, dt:DType, disp:UOp, reg:Register):
|
||||
ndt = dtypes.uint64 if isinstance(dt, PtrDType) else dt
|
||||
ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).load(dtype=ndt, arg=reg))
|
||||
ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).load(dtype=ndt, tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=dt)
|
||||
def store(ctx:RegallocContext, disp:UOp, x:UOp):
|
||||
|
|
@ -71,7 +71,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
|
|||
nsrc, loads = [], []
|
||||
for s in x.src:
|
||||
# allocate srcs: if a src was spilled it's replaced by a load; if it's live the load was already emitted, otherwise alloc and emit one
|
||||
if isinstance(s.arg, Register) and (v:=ctx.rewrite_to_vreg[s]) in ctx.spills:
|
||||
if isinstance(s.tag, Register) and (v:=ctx.rewrite_to_vreg[s]) in ctx.spills:
|
||||
# TODO: the constraints only apply to the definition, you need to insert moves in the graph to "cleanse" the constraint
|
||||
# then those moves are removed after regalloc if they move to the same register. I think this is the llvm approach
|
||||
# alternatively you could beef up the register class to include constraints on the srcs, then you check those here
|
||||
|
|
@ -82,7 +82,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
|
|||
else: s = load(ctx, s.dtype, ctx.spills[v], ctx.live[v])
|
||||
nsrc.append(s)
|
||||
# allocate destination
|
||||
if isinstance(v:=x.arg, Register) and v not in ctx.live:
|
||||
if isinstance(v:=x.tag, Register) and v not in ctx.live:
|
||||
# if no cons it's a real register, so it can only be assigned to itself
|
||||
cons = v.cons or (v,)
|
||||
# two address instructions (src is used in dest) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak
|
||||
|
|
@ -93,7 +93,7 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
|
|||
assert cons
|
||||
ctx.live[v] = alloc(ctx, cons, i+1)
|
||||
|
||||
nx = x.replace(src=tuple(nsrc), arg=ctx.live.get(v, v))
|
||||
nx = x.replace(src=tuple(nsrc), tag=ctx.live.get(v, v))
|
||||
# TODO: this check exists because of a hack in x86, rm once multiple outputs are supported
|
||||
if nx not in ctx.rewrite_to_vreg: ctx.rewrite_to_vreg[nx] = v
|
||||
if v not in ctx.vreg_to_rewrite: ctx.vreg_to_rewrite[v] = nx
|
||||
|
|
@ -101,10 +101,10 @@ def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]:
|
|||
|
||||
# move uops to registers before the loop to avoid loading inside the loop
|
||||
def loop_prologue(ctx:RegallocContext, x:UOp, i:int):
|
||||
assert isinstance(x.arg, Register)
|
||||
assert isinstance(x.tag, Register)
|
||||
nx, lst = regalloc(ctx, x, i)
|
||||
# we move to register vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue
|
||||
used_in_loop = [v for v in ctx.live.keys() | ctx.spills.keys() if any(i <= l < ctx.live_range[x.arg][-1] for l in ctx.live_range[v])]
|
||||
used_in_loop = [v for v in ctx.live.keys() | ctx.spills.keys() if any(i <= l < ctx.live_range[x.tag][-1] for l in ctx.live_range[v])]
|
||||
sorted_uses = sorted(used_in_loop, key=lambda k: next(l-i for l in ctx.live_range[k] if l >= i))
|
||||
live_in: dict[Register, Register] = {}
|
||||
loads = []
|
||||
|
|
@ -135,12 +135,12 @@ def loop_epilogue(ctx:RegallocContext, x:UOp, i:int):
|
|||
pm_regalloc = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: loop_prologue(ctx, x, next(ctx.idx))),
|
||||
(UPat(Ops.END, name="x"), lambda ctx,x: loop_epilogue(ctx, x, next(ctx.idx))),
|
||||
(UPat(X86GroupOp.All | {Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER}, name="x"), lambda ctx,x: regalloc(ctx, x, next(ctx.idx))),
|
||||
(UPat({Ops.INS, Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER}, name="x"), lambda ctx,x: regalloc(ctx, x, next(ctx.idx))),
|
||||
])
|
||||
|
||||
# annoying that this is another pm
|
||||
pm_insert_spills = PatternMatcher([
|
||||
# insert spill after definition
|
||||
(UPat(X86GroupOp.All | {Ops.RANGE}, name="x"), lambda ctx,x:
|
||||
(UPat({Ops.INS, Ops.RANGE}, name="x"), lambda ctx,x:
|
||||
(x, [x, store(ctx, y, x)]) if (y:=ctx.spills.get(ctx.rewrite_to_vreg.get(x))) is not None else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
from tinygrad.uop.ops import Ops, auto
|
||||
from tinygrad.uop import FastEnum, auto
|
||||
|
||||
# ***** X86 *****
|
||||
|
||||
# NOTE: mypy doesn't allow extending enums even with our wrapper, it also doesn't allow overriding i.e. Ops.ADD to X86Ops.ADD
|
||||
# we ignore it in both cases
|
||||
class X86Ops(Ops): # type: ignore[misc]
|
||||
class X86Ops(FastEnum):
|
||||
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
|
||||
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
|
||||
DEFINE_REG = auto(); FRAME_INDEX = auto() # type: ignore[misc]
|
||||
DEFINE_REG = auto(); FRAME_INDEX = auto()
|
||||
# const
|
||||
IMM = auto()
|
||||
# index
|
||||
|
|
@ -48,10 +46,10 @@ class X86Ops(Ops): # type: ignore[misc]
|
|||
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
|
||||
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
|
||||
# int binary
|
||||
IDIV = auto(); DIV = auto() # type: ignore[misc]
|
||||
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto() # type: ignore[misc]
|
||||
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto() # type: ignore[misc]
|
||||
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto() # type: ignore[misc]
|
||||
IDIV = auto(); DIV = auto()
|
||||
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
|
||||
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
|
||||
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
|
||||
# float unary (sometimes not unary)
|
||||
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
|
||||
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
|
||||
|
|
|
|||
|
|
@ -109,8 +109,8 @@ WGPR = tuple(r for r in GPR if r != RSP)
|
|||
|
||||
# ***** X86 instruction selection *****
|
||||
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(X86Ops.DEFINE_REG, dt, arg=reg)
|
||||
def imm(dt:DType, v:int|float) -> UOp: return UOp(X86Ops.IMM, dt, arg=v)
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(Ops.INS, arg=X86Ops.DEFINE_REG, dtype=dt, tag=reg)
|
||||
def imm(dt:DType, v:int|float) -> UOp: return UOp(Ops.INS, arg=X86Ops.IMM, dtype=dt, tag=v)
|
||||
def to_imm(c:UOp) -> UOp|None:
|
||||
if c.op is not Ops.CONST: return None
|
||||
if c.dtype is dtypes.int64: return imm(dtypes.int32, c.arg) if not c.overflows(dtypes.int32) else None
|
||||
|
|
@ -118,17 +118,16 @@ def to_imm(c:UOp) -> UOp|None:
|
|||
if c.dtype in dtypes.ints+(dtypes.bool,): return imm(c.dtype, c.arg)
|
||||
return None
|
||||
def cmp(x:UOp) -> UOp:
|
||||
if x.src[0].dtype is dtypes.float32: return UOp(X86Ops.VUCOMISS, src=x.src)
|
||||
if x.src[0].dtype is dtypes.float64: return UOp(X86Ops.VUCOMISD, src=x.src)
|
||||
return UOp(X86Ops.CMP, src=x.src) if (i:=to_imm(x.src[1])) is None else UOp(X86Ops.CMPi, src=(x.src[0], i))
|
||||
if x.src[0].dtype is dtypes.float32: return x.ins(X86Ops.VUCOMISS, dtype=dtypes.void)
|
||||
if x.src[0].dtype is dtypes.float64: return x.ins(X86Ops.VUCOMISD, dtype=dtypes.void)
|
||||
return x.ins(X86Ops.CMP, dtype=dtypes.void) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], i))
|
||||
|
||||
# vshufps xmm2, xmm0, xmm1, imm
|
||||
# xmm2 selects its lower two 32-bit elements from xmm0 and its upper two 32-bit elements from xmm1 according to imm
|
||||
def vshufps(x:UOp) -> UOp|None:
|
||||
def _in(i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
|
||||
if len(x.src) != 4 or _in(0) is not _in(1) or _in(2) is not _in(3): return None
|
||||
return UOp(X86Ops.VSHUFPS, x.dtype, (_in(0), _in(2),
|
||||
imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src)))))
|
||||
if len(x.src) != 4 or (a:=_in(0)) is not _in(1) or (b:=_in(2)) is not _in(3): return None
|
||||
return x.ins(X86Ops.VSHUFPS, src=(a, b, imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src)))))
|
||||
|
||||
# vinsertps xmm2, xmm0, xmm1, imm
|
||||
# inserts any 32 bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2
|
||||
|
|
@ -138,14 +137,14 @@ def vinsertps(x:UOp) -> UOp:
|
|||
s, v = x.src[i], 0
|
||||
if s.op is Ops.GEP: s, v = s.src[0], s.arg[0]
|
||||
# moving the 0th element into the 0th position does nothing
|
||||
return s if i == v == 0 else UOp(X86Ops.VINSERTPS, x.dtype, (ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
|
||||
return s if i == v == 0 else x.ins(X86Ops.VINSERTPS, src=(ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
|
||||
return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))
|
||||
|
||||
# vpinsrq xmm2, xmm0, rax, imm
|
||||
# inserts the element in rax into the position of xmm0 selected by imm, result is written to xmm2
|
||||
def vpins(x:UOp) -> UOp:
|
||||
op = {1: X86Ops.VPINSRB, 2: X86Ops.VPINSRW, 4: X86Ops.VPINSRD, 8: X86Ops.VPINSRQ}[x.dtype.scalar().itemsize]
|
||||
return functools.reduce(lambda ret,i: UOp(op, x.dtype, (ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype))
|
||||
return functools.reduce(lambda ret,i: x.ins(op, src=(ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype))
|
||||
|
||||
def div(ctx:IselContext, x:UOp):
|
||||
# zero extend or move src[0] to x
|
||||
|
|
@ -184,7 +183,7 @@ def fuse_load(ctx:IselContext, x:UOp, i:int) -> UOp|None:
|
|||
def abi(ctx:IselContext, x:UOp):
|
||||
i = ctx.func_args.index(x)
|
||||
def _stack_arg(disp:int):
|
||||
return UOp(X86Ops.MOV, x.dtype, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(X86Ops.FRAME_INDEX, dtypes.int32, arg=disp)))
|
||||
return x.ins(X86Ops.MOV, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(X86Ops.FRAME_INDEX, dtypes.int32, arg=disp)))
|
||||
if sys.platform == "win32": return def_reg(x.dtype, (RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32)
|
||||
return def_reg(x.dtype, (RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8)
|
||||
|
||||
|
|
@ -205,195 +204,195 @@ isel_matcher = PatternMatcher([
|
|||
# **** Op -> X86Op ****
|
||||
# add callee saved registers to the RET, these will be scheduled at the top of the kernel and will be saved/restored if they are used in regalloc
|
||||
# so regalloc builds the prologue/epilogue naturally
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(op=X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64, r) for r in [RSP, RBP]))),
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64, r) for r in [RSP, RBP]))),
|
||||
# function abi constraints
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
|
||||
# these are treated the same for now
|
||||
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), lambda ctx,x:
|
||||
x.replace(op=X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)),
|
||||
x.ins(X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)),
|
||||
# constants that can't be immediates, move them to registers
|
||||
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), UOp(X86Ops.MOVi, dtypes.int16, (imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (imm(x.dtype, x.arg),)),))),
|
||||
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (imm(x.dtype, x.arg),)),))),
|
||||
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (imm(x.dtype, x.arg),))),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (imm(x.dtype, x.arg),))),
|
||||
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype), UOp(Ops.INS, arg=X86Ops.MOVi, dtype=dtypes.int16, src=(imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VMOVD, src=(UOp(Ops.INS, arg=X86Ops.MOVi, dtype=dtypes.int32, src=(imm(x.dtype, x.arg),)),))), # noqa: E501
|
||||
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VMOVQ, src=(UOp(Ops.INS, arg=X86Ops.MOVABS, dtype=dtypes.int64, src=(imm(x.dtype, x.arg),)),))), # noqa: E501
|
||||
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVABS, src=(imm(x.dtype, x.arg),))),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOVi, src=(imm(x.dtype, x.arg),))),
|
||||
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VPBLENDVB, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None), # noqa: E501
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPS, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPD, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None), # noqa: E501
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501
|
||||
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
|
||||
(UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int32).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), # noqa: E501
|
||||
(UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int64).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), # noqa: E501
|
||||
# conditional moves that use flags
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVL, a.dtype, src=(b, a, cmp(m)))), # noqa: E501
|
||||
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVB, a.dtype, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVE, a.dtype, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVNE, a.dtype, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))), # noqa: E501
|
||||
(UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVB, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVE, src=(b, a, cmp(m)))),
|
||||
(UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVNE, src=(b, a, cmp(m)))),
|
||||
# jumps, use flags
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: UOp(X86Ops.JB, x.dtype, (cmp(y),))), # noqa: E501
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),)), lambda y: UOp(X86Ops.JL, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),)), lambda y: UOp(X86Ops.JE, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),)), lambda y: UOp(X86Ops.JNE, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JB, src=(cmp(y),))), # noqa: E501
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JL, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JE, src=(cmp(y),))),
|
||||
(UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),), name="x"), lambda y,x: x.ins(X86Ops.JNE, src=(cmp(y),))),
|
||||
# comparisons whose user doesn't use the flag, move flag result to register
|
||||
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: UOp(X86Ops.SETB, x.dtype, (cmp(x),))),
|
||||
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETL, x.dtype, (cmp(x),))),
|
||||
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETE, x.dtype, (cmp(x),))),
|
||||
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETNE, x.dtype, (cmp(x),))),
|
||||
(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: x.ins(X86Ops.SETB, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETL, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETE, src=(cmp(x),))),
|
||||
(UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: x.ins(X86Ops.SETNE, src=(cmp(x),))),
|
||||
# comparisons that produce masks (these aren't bool dtype)
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
|
||||
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
|
||||
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQB)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQW)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQD)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQQ)),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTB, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTW, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTD, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTQ, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501
|
||||
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
|
||||
(UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQB)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQW)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQD)),
|
||||
(UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.ins(X86Ops.VPCMPEQQ)),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTB, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTW, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
|
||||
# float unary
|
||||
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSS, x.dtype, (y, y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPS)), # noqa: E501
|
||||
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSD, x.dtype, (y, y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPD)), # noqa: E501
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSS, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSD, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPS, x.dtype, (y, imm(dtypes.uint8, 3)))),
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPD, x.dtype, (y, imm(dtypes.uint8, 3)))),
|
||||
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)), # noqa: E501
|
||||
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)), # noqa: E501
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
|
||||
# broadcasts TODO: not quite right, what about load fusion? Also, bitcast should be x86op and reg is xmm?
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTB, x.dtype, (y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTW, x.dtype, (y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTD, x.dtype, (y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTQ, x.dtype, (y.bitcast(dtypes.float64),))),
|
||||
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: UOp(X86Ops.VBROADCASTSS, x.dtype, (y,))),
|
||||
(UPat.var("y", dtypes.int8s+(dtypes.bool,)).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTB, src=(y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTW, src=(y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTD, src=(y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VPBROADCASTQ, src=(y.bitcast(dtypes.float64),))),
|
||||
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
|
||||
# shuffles
|
||||
(UPat.var("y", dtypes.int16s).bitcast(dtypes.float16).named("x"), lambda y,x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat.var("y", dtypes.int16s).bitcast(dtypes.float16).named("x"), lambda y,x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat(Ops.VECTORIZE, dtypes.ints+(dtypes.bool,), name="x"), vpins),
|
||||
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vshufps),
|
||||
(UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vinsertps),
|
||||
(UPat.var("y", dtypes.float32).gep(name="x"), lambda y,x: UOp(X86Ops.VINSERTPS, x.dtype, (y, y, imm(dtypes.uint8, x.arg[0] << 6)))),
|
||||
(UPat.var("y", dtypes.float32).gep(name="x"), lambda y,x: x.ins(X86Ops.VINSERTPS, src=(y, y, imm(dtypes.uint8, x.arg[0] << 6)))),
|
||||
# extract
|
||||
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, 0)))),
|
||||
(UPat.var("y", dtypes.int8s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRB, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRD, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRQ, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, 0)))),
|
||||
(UPat.var("y", dtypes.int8s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRB, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRW, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRD, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
(UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: x.ins(X86Ops.VPEXTRQ, src=(y, imm(dtypes.uint8, x.arg[0])))),
|
||||
# fused multiply add TODO: don't fuse if mul used several times
|
||||
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, b, c)), # noqa: E501
|
||||
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, b, c)), # noqa: E501
|
||||
(UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(a, b, c))), # noqa: E501
|
||||
(UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(a, b, c))), # noqa: E501
|
||||
# packed bitwise
|
||||
((UPat() & UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPAND) if x.dtype.count > 1 else None),
|
||||
((UPat() | UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPOR) if x.dtype.count > 1 else None),
|
||||
((UPat() ^ UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPXOR) if x.dtype.count > 1 else None),
|
||||
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
|
||||
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
|
||||
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
|
||||
# packed int binary
|
||||
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDB) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDW) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBB) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBD) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
|
||||
# scalar int binary
|
||||
((UPat(dtype=dtypes.uints) // UPat()).named("x"), div),
|
||||
((UPat(dtype=dtypes.sints) // UPat()).named("x"), idiv),
|
||||
((UPat.var("a", dtypes.ints) << UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHLi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHL)), # noqa: E501
|
||||
((UPat.var("a", dtypes.uints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHRi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501
|
||||
((UPat.var("a", dtypes.sints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SARi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.AND) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ANDi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.OR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ORi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.XOR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.XORi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints) * UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.IMUL) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.IMULi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints) + UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.ADD) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ADDi, src=(a, i))), # noqa: E501
|
||||
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.SUB) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.SUBi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints) << UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SHLi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHL)), # noqa: E501
|
||||
((UPat.var("a", dtypes.uints) >> UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SHRi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHR)), # noqa: E501
|
||||
((UPat.var("a", dtypes.sints) >> UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.SARi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.ins(X86Ops.SHR)), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.AND) if (i:=to_imm(b)) is None else x.ins(X86Ops.ANDi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.OR) if (i:=to_imm(b)) is None else x.ins(X86Ops.ORi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.XOR) if (i:=to_imm(b)) is None else x.ins(X86Ops.XORi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints) * UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.IMUL) if (i:=to_imm(b)) is None else x.ins(X86Ops.IMULi, src=(a, i))), # noqa: E501
|
||||
((UPat.var("a", dtypes.ints) + UPat.var("b")).named("x"), lambda a,b,x: x.ins(X86Ops.ADD) if (i:=to_imm(b)) is None else x.ins(X86Ops.ADDi, src=(a, i))), # noqa: E501
|
||||
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.SUB) if (i:=to_imm(b)) is None else x.ins(X86Ops.SUBi, src=(a, i))), # noqa: E501
|
||||
# float binary
|
||||
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
|
||||
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
|
||||
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
|
||||
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
|
||||
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
|
||||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
|
||||
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
|
||||
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
|
||||
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
|
||||
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
|
||||
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
|
||||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
|
||||
# TODO: these should use a.maximum(b) / a.minimum(b)
|
||||
(UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)),
|
||||
(UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)),
|
||||
((UPat.var("a", dtypes.float32) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, a.dtype, (a, b))), # noqa: E501
|
||||
((UPat.var("a", dtypes.float64) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, a.dtype, (a, b))), # noqa: E501
|
||||
(UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)),
|
||||
(UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)),
|
||||
((UPat.var("a", dtypes.float32) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))), # noqa: E501
|
||||
((UPat.var("a", dtypes.float64) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))), # noqa: E501
|
||||
# casts
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
|
||||
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPH2PS)),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSS2SI)),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSD2SI)),
|
||||
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSS2SD, src=(y, y))),
|
||||
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSD2SS, src=(y, y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))), # noqa: E501
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))), # noqa: E501
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBW) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBW) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVZX)),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.MOVSXD)),
|
||||
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVSX)),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
|
||||
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSD2SI)),
|
||||
(UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSS2SD, src=(y, y))),
|
||||
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXWQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBW) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXWQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVSXDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX)),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD)),
|
||||
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX)),
|
||||
# bitcasts
|
||||
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.replace(op=X86Ops.VMOVD)),
|
||||
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.replace(op=X86Ops.VMOVQ)),
|
||||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.replace(op=X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.replace(op=X86Ops.VMOVQm)),
|
||||
(UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.ins(X86Ops.VMOVD)),
|
||||
(UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.ins(X86Ops.VMOVQ)),
|
||||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
|
||||
# index
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.replace(op=X86Ops.LEA, src=fuse_address(x))),
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fuse_address(x))),
|
||||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# assign, load, store
|
||||
# NOTE: assign here violates the spec, it only happens in register allocation when a reg to reg move needs to be inserted
|
||||
(UPat(Ops.ASSIGN, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS)),
|
||||
(UPat(Ops.ASSIGN, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD)),
|
||||
(UPat(Ops.ASSIGN, dt_32bit+dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS)),
|
||||
(UPat(Ops.ASSIGN, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.replace(op=X86Ops.MOV)),
|
||||
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_address(x.src[0]) + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.replace(op=X86Ops.MOV, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_128bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVUPSm, src=fuse_address(x.src[0]) + (x.src[1],))), # noqa: E501
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_64bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSDm, src=fuse_address(x.src[0]) + (x.src[1],))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_32bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSSm, src=fuse_address(x.src[0]) + (x.src[1],))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_16bit)), name="x"), lambda x: x.replace(op=X86Ops.VPEXTRW, src=fuse_address(x.src[0]) + (x.src[1], imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat(Ops.ASSIGN, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
|
||||
(UPat(Ops.ASSIGN, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
|
||||
(UPat(Ops.ASSIGN, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
|
||||
(UPat(Ops.ASSIGN, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
|
||||
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda x: x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_address(x.src[0]) + (imm(dtypes.uint8, 0),))), # noqa: E501
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV, src=fuse_address(x.src[0]))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_128bit)), name="x"), lambda x: x.ins(X86Ops.VMOVUPSm, src=fuse_address(x.src[0]) + (x.src[1],))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_64bit)), name="x"), lambda x: x.ins(X86Ops.VMOVSDm, src=fuse_address(x.src[0]) + (x.src[1],))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_32bit)), name="x"), lambda x: x.ins(X86Ops.VMOVSSm, src=fuse_address(x.src[0]) + (x.src[1],))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_16bit)), name="x"), lambda x: x.ins(X86Ops.VPEXTRW, src=fuse_address(x.src[0]) + (x.src[1], imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.ints+(dtypes.bool,))), name="x"), lambda x:
|
||||
x.replace(op=X86Ops.MOVm, src=fuse_address(x.src[0]) + (x.src[1],)) if (i:=to_imm(x.src[1])) is None else x.replace(op=X86Ops.MOVi, src=fuse_address(x.src[0]) + (i,))), # noqa: E501
|
||||
x.ins(X86Ops.MOVm, src=fuse_address(x.src[0]) + (x.src[1],)) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.MOVi, src=fuse_address(x.src[0]) + (i,))), # noqa: E501
|
||||
# **** X86Op -> X86Op ****
|
||||
# fuse loads into X86Ops that allow it, if beneficial
|
||||
(UPat(X86GroupOp.ReadMem1st, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)),
|
||||
(UPat(X86GroupOp.ReadMem2nd, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)),
|
||||
(UPat(X86GroupOp.ReadMem3rd, src=(UPat(), UPat(), UPat(Ops.LOAD)), name="x"), lambda ctx,x: fuse_load(ctx, x, 2)),
|
||||
*[(UPat.ins(op, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)) for op in X86GroupOp.ReadMem1st],
|
||||
*[(UPat.ins(op, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)) for op in X86GroupOp.ReadMem2nd],
|
||||
*[(UPat.ins(op, src=(UPat(), UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 2)) for op in X86GroupOp.ReadMem3rd], # noqa: E501
|
||||
# allocate virtual register to X86Op with special constraints
|
||||
(UPat(X86GroupOp.All, dtypes.ints+dtypes.floats+(dtypes.bool,), name="x"), lambda ctx,x:
|
||||
x.replace(arg=ctx.vreg(x.arg)) if isinstance(x.arg, tuple) else None),
|
||||
(UPat(Ops.INS, dtypes.ints+dtypes.floats+(dtypes.bool,), name="x"), lambda ctx,x:
|
||||
x.replace(tag=ctx.vreg(x.tag)) if isinstance(x.tag, tuple) else None),
|
||||
# allocate virtual register to X86Op without special constraints
|
||||
(UPat(X86GroupOp.All, name="x"), lambda ctx,x:
|
||||
x.replace(arg=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.arg is None and x.dtype != dtypes.void else None),
|
||||
(UPat(Ops.INS, name="x"), lambda ctx,x:
|
||||
x.replace(tag=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.tag is None and x.dtype != dtypes.void else None),
|
||||
])
|
||||
|
||||
# ***** post register allocation *****
|
||||
|
|
@ -402,15 +401,15 @@ isel_matcher = PatternMatcher([
|
|||
# final rewrite to match the isa spec
|
||||
post_regalloc_matcher = PatternMatcher([
|
||||
# alloc stack space
|
||||
(UPat(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP, name="x"), lambda ctx,x:
|
||||
(x, [x, UOp(X86Ops.SUBi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP)]) if ctx.stack_size > 0 else None),
|
||||
(UPat.ins(X86Ops.DEFINE_REG, dtype=dtypes.uint64, tag=RSP, name="x"), lambda ctx,x:
|
||||
(x, [x, x.ins(X86Ops.SUBi, src=(imm(dtypes.uint32, ctx.stack_size),), tag=RSP)]) if ctx.stack_size > 0 else None),
|
||||
# dealloc stack space
|
||||
(UPat(X86Ops.RET, name="x"), lambda ctx,x:
|
||||
(UPat.ins(X86Ops.RET, name="x"), lambda ctx,x:
|
||||
(x, [UOp(X86Ops.ADDi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP), x]) if ctx.stack_size > 0 else None),
|
||||
# rewrite FRAME_INDEX to IMM now that the stack size is known
|
||||
(UPat(X86Ops.FRAME_INDEX, name="x"), lambda ctx,x: (nx:=x.replace(op=X86Ops.IMM, arg=ctx.stack_size + x.arg), [nx])),
|
||||
# rewrite RANGE to MOV reg, 0. Terrible HACK to pass the CONST to the END
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: (nx:=x.replace(op=X86Ops.MOVi, src=(imm(x.dtype, 0),), tag=x.src[0].arg), [nx])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: (nx:=x.ins(X86Ops.MOVi, src=(imm(x.dtype, 0),), tag=x.src[0].arg), [nx])),
|
||||
# rewrite END to ADD 1 -> CMPLT -> JUMP
|
||||
(UPat(Ops.END, name="x"), lambda x:
|
||||
(jl:=x.replace(op=X86Ops.JL, src=(x.src[1], cmp:=UOp(X86Ops.CMPi if isinstance(x.src[1].tag, int) else X86Ops.CMP,
|
||||
|
|
@ -419,7 +418,7 @@ post_regalloc_matcher = PatternMatcher([
|
|||
# TODO: need a generic way to model clobbers, idiv and flags should be handled the same way, maybe add clobber field to Register?
|
||||
# fixup div, zero rdx again because scheduling constraint isn't being respected
|
||||
(UPat(X86Ops.DIV, name="x"), lambda x:
|
||||
(nx:=x.replace(src=x.src[:1]), [UOp(X86Ops.MOVi, x.dtype, (imm(min(dtypes.uint32, x.dtype), 0),), RDX), nx])),
|
||||
(nx:=x.replace(src=x.src[:1]), [x.ins(X86Ops.MOVi, src=(imm(min(dtypes.uint32, x.dtype), 0),), tag=RDX), nx])),
|
||||
# rewrite two address instructions to two address form, if reused src wasn't coalesced insert a move
|
||||
(UPat(X86GroupOp.TwoAddress1st, name="x"), lambda ctx,x:
|
||||
(nx:=x.replace(src=x.src[1:]), [assign(ctx, x.src[0], x.arg), nx] if x.arg != x.src[0].arg else [nx])),
|
||||
|
|
@ -431,11 +430,11 @@ isa_spec = PatternMatcher([
|
|||
# these are the only non X86Ops allowed
|
||||
(UPat((Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER)), lambda: True),
|
||||
# vblends take a mask which is float or int dtype
|
||||
(UPat((X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD), src=(UPat.var("a"), UPat.var("b"), UPat.var("m")), name="x"),
|
||||
lambda a,b,m,x: x.dtype == a.dtype == b.dtype and x.dtype.itemsize == m.dtype.itemsize),
|
||||
#(UPat((X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD), src=(UPat.var("a"), UPat.var("b"), UPat.var("m")), name="x"),
|
||||
# lambda a,b,m,x: x.dtype == a.dtype == b.dtype and x.dtype.itemsize == m.dtype.itemsize),
|
||||
# cmoves take a flag producing instruction
|
||||
(UPat((X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE), dtypes.bool, (UPat(), UPat(), UPat(X86GroupOp.WriteFlags))), lambda: True),
|
||||
(UPat(X86GroupOp.All), lambda: True),
|
||||
#(UPat((X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE), dtypes.bool, (UPat(), UPat(), UPat(X86GroupOp.WriteFlags))), lambda: True),
|
||||
(UPat(Ops.INS), lambda: True),
|
||||
])
|
||||
|
||||
# ***** X86 instruction encoding *****
|
||||
|
|
@ -447,15 +446,15 @@ def to_bytes(dt:DType, v:int|float):
|
|||
|
||||
def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
||||
# when a uop writes to memory it takes the form of a store, dtype is void, no definition
|
||||
if x.op in X86GroupOp.WriteMem:
|
||||
if x.arg in X86GroupOp.WriteMem:
|
||||
if len(x.src) > 3: address, rest = x.src[:3], x.src[3:]
|
||||
else: address, rest = (x, None, None), x.src
|
||||
|
||||
elif x.op in X86GroupOp.ReadMem1st or x.op in X86GroupOp.ReadMem2nd and x.op in X86GroupOp.TwoAddress1st:
|
||||
elif x.arg in X86GroupOp.ReadMem1st or x.arg in X86GroupOp.ReadMem2nd and x.arg in X86GroupOp.TwoAddress1st:
|
||||
if len(x.src) > 2: address, rest = x.src[:3], (x,) + x.src[3:]
|
||||
else: address, rest = (x.src[0], None, None), (x,) + x.src[1:]
|
||||
|
||||
elif x.op in X86GroupOp.ReadMem2nd or x.op in X86GroupOp.ReadMem3rd and x.op in X86GroupOp.TwoAddress1st:
|
||||
elif x.arg in X86GroupOp.ReadMem2nd or x.arg in X86GroupOp.ReadMem3rd and x.arg in X86GroupOp.TwoAddress1st:
|
||||
if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:]
|
||||
else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:]
|
||||
if x.dtype is not dtypes.void: rest = (x,) + rest
|
||||
|
|
@ -464,12 +463,12 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
|||
|
||||
# get the encoding values of the different fields
|
||||
reg_sz = (rest[0].dtype.itemsize if not isinstance(rest[0].dtype, PtrDType) else 8) if reg is None else 0
|
||||
reg = cast(Register, rest[0].arg).index if reg is None else reg
|
||||
vvvv = rest[1].arg.index if len(rest) > 1 and isinstance(rest[1].arg, Register) else 0
|
||||
rm = cast(Register, address[0].arg).index
|
||||
idx = cast(Register, address[1].arg).index if address[1] is not None and address[1].arg is not None else 4
|
||||
reg = cast(Register, rest[0].tag).index if reg is None else reg
|
||||
vvvv = rest[1].tag.index if len(rest) > 1 and isinstance(rest[1].tag, Register) else 0
|
||||
rm = cast(Register, address[0].tag).index
|
||||
idx = cast(Register, address[1].tag).index if address[1] is not None and address[1].tag is not None else 4
|
||||
disp_uop = address[2]
|
||||
imm_uop = rest[-1] if rest[-1].op is X86Ops.IMM or len(rest) == 3 else None
|
||||
imm_uop = rest[-1] if rest[-1].arg is X86Ops.IMM or len(rest) == 3 else None
|
||||
# TODO: another reason to get rid of ptrs, if we access memory the size should be in scale uop otherwise size is in rm
|
||||
rm_sz = 8 if isinstance(address[0].dtype, PtrDType) and disp_uop is None else address[0].dtype.itemsize
|
||||
|
||||
|
|
@ -494,7 +493,7 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
|||
if w | r | _x | b | (reg_sz == 1 & reg >> 2) | (rm_sz == 1 & rm >> 2): inst += bytes([0b0100 << 4 | w << 3 | r << 2 | _x << 1 | b])
|
||||
# OPCODE byte
|
||||
# legacy 8bit opcodes are 1 less than 16-64bit versions, with these exceptions
|
||||
real_opc = opc-1 if (rm_sz == 1 or reg_sz == 1) and x.op not in {X86Ops.SETB, X86Ops.SETE, X86Ops.SETL, X86Ops.SETNE, X86Ops.LEA} else opc
|
||||
real_opc = opc-1 if (rm_sz == 1 or reg_sz == 1) and x.arg not in {X86Ops.SETB, X86Ops.SETE, X86Ops.SETL, X86Ops.SETNE, X86Ops.LEA} else opc
|
||||
inst += real_opc.to_bytes((real_opc.bit_length() + 7) // 8, 'big')
|
||||
# MODRM byte
|
||||
# now we only care about the lower 3 bits
|
||||
|
|
@ -506,7 +505,7 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
|||
if disp_uop is not None:
|
||||
assert disp_uop.dtype in (dtypes.int8, dtypes.int32), "displacement can only be 1 or 4 byte signed int"
|
||||
# rbp/r13 always require a displacement
|
||||
if disp_uop.arg != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
|
||||
if disp_uop.tag != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10
|
||||
else: mod = 0b00
|
||||
else: mod = 0b11
|
||||
# x 0b0 and idx 0b100 means rsp which means no index exists
|
||||
|
|
@ -520,11 +519,11 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
|||
# DISP byte
|
||||
if mod == 0b01 or mod == 0b10:
|
||||
assert disp_uop is not None
|
||||
inst += disp_uop.arg.to_bytes(disp_uop.dtype.itemsize, 'little', signed=True)
|
||||
inst += disp_uop.tag.to_bytes(disp_uop.dtype.itemsize, 'little', signed=True)
|
||||
# IMM byte
|
||||
if imm_uop is not None:
|
||||
if isinstance(imm_uop.arg, Register): inst += bytes([(imm_uop.arg.index & 0b1111) << 4 | 0b0000])
|
||||
else: inst += to_bytes(imm_uop.dtype, imm_uop.arg)
|
||||
if isinstance(imm_uop.tag, Register): inst += bytes([(imm_uop.tag.index & 0b1111) << 4 | 0b0000])
|
||||
else: inst += to_bytes(imm_uop.dtype, imm_uop.tag)
|
||||
return inst
|
||||
|
||||
# https://www.felixcloutier.com/x86/
|
||||
|
|
@ -533,109 +532,109 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0):
|
|||
# map select: 0F == 1, 0F38 == 2, 0F3A == 3
|
||||
# Maps every x86 instruction (an Ops.INS uop whose arg is the X86Ops member) to the callable that produces its bytes.
# Most entries defer to encode(x, opcode, ...); reg= supplies a fixed value for the reg field, pp/sel/we are forwarded unchanged.
encodings = PatternMatcher([
  # moves
  # MOVABS is hand-encoded: prefix byte (0b0100, w=1, top register bit in the lowest position), 0xB8 + low 3 register bits, then the immediate
  # with Ops.INS the destination register is stored in tag (arg is X86Ops.MOVABS), so the register must be read from x.tag
  # the immediate value is read from tag when the source is itself an instruction, otherwise from arg
  (UPat.ins(X86Ops.MOVABS, name="x"), lambda x:
   bytes([0b0100 << 4 | 0b1 << 3 | 0b00 << 2 | x.tag.index >> 3, 0xB8 + (x.tag.index & 0b111)]) +
   to_bytes(x.src[0].dtype, x.src[0].tag if x.src[0].op is Ops.INS else x.src[0].arg)),
  (UPat.ins(X86Ops.MOV, name="x"), lambda x: encode(x, 0x8B)), (UPat.ins(X86Ops.MOVi, name="x"), lambda x: encode(x, 0xC7, reg=0)),
  (UPat.ins(X86Ops.MOVm, name="x"), lambda x: encode(x, 0x89)), (UPat.ins(X86Ops.LEA, name="x"), lambda x: encode(x, 0x8D)),
  (UPat.ins(X86Ops.VMOVSS, name="x"), lambda x: encode(x, 0x10, pp=2, sel=1)), (UPat.ins(X86Ops.VMOVSSm, name="x"), lambda x: encode(x, 0x11, pp=2, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVSD, name="x"), lambda x: encode(x, 0x10, pp=3, sel=1)), (UPat.ins(X86Ops.VMOVSDm, name="x"), lambda x: encode(x, 0x11, pp=3, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVUPS, name="x"), lambda x: encode(x, 0x10, pp=0, sel=1)), (UPat.ins(X86Ops.VMOVUPSm, name="x"), lambda x: encode(x, 0x11, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVD, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1)), (UPat.ins(X86Ops.VMOVQ, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VMOVDm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1)), (UPat.ins(X86Ops.VMOVQm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1, we=1)), # noqa: E501
  # casts
  (UPat.ins(X86Ops.MOVZX, name="x"), lambda x: encode(x, 0x0FB7)),
  (UPat.ins(X86Ops.MOVSX, name="x"), lambda x: encode(x, 0x0FBF)), (UPat.ins(X86Ops.MOVSXD, name="x"), lambda x: encode(x, 0x63)),
  (UPat.ins(X86Ops.VPMOVZXBW, name="x"), lambda x: encode(x, 0x30, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXBD, name="x"), lambda x: encode(x, 0x31, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVZXBQ, name="x"), lambda x: encode(x, 0x32, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXWD, name="x"), lambda x: encode(x, 0x33, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVZXWQ, name="x"), lambda x: encode(x, 0x34, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVZXDQ, name="x"), lambda x: encode(x, 0x35, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXBW, name="x"), lambda x: encode(x, 0x20, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXBD, name="x"), lambda x: encode(x, 0x21, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXBQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXWD, name="x"), lambda x: encode(x, 0x23, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMOVSXWQ, name="x"), lambda x: encode(x, 0x24, pp=1, sel=2)), (UPat.ins(X86Ops.VPMOVSXDQ, name="x"), lambda x: encode(x, 0x25, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VCVTSS2SD, name="x"), lambda x: encode(x, 0x5A, pp=2, sel=1)), (UPat.ins(X86Ops.VCVTSD2SS, name="x"), lambda x: encode(x, 0x5A, pp=3, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTPH2PS, name="x"), lambda x: encode(x, 0x13, pp=1, sel=2)), (UPat.ins(X86Ops.VCVTPS2PH, name="x"), lambda x: encode(x, 0x1D, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VCVTDQ2PS, name="x"), lambda x: encode(x, 0x5B, pp=0, sel=1)), (UPat.ins(X86Ops.VCVTDQ2PD, name="x"), lambda x: encode(x, 0xE6, pp=2, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTPS2PD, name="x"), lambda x: encode(x, 0x5A, pp=0, sel=1)), (UPat.ins(X86Ops.VCVTPD2PS, name="x"), lambda x: encode(x, 0x5A, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCVTTPS2DQ, name="x"), lambda x: encode(x, 0x5B, pp=2, sel=1)), (UPat.ins(X86Ops.VCVTTPD2DQ, name="x"), lambda x: encode(x, 0xE6, pp=1, sel=1)), # noqa: E501
  # we is set when the integer operand is 64 bit
  (UPat.ins(X86Ops.VCVTSI2SS, name="x"), lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.base is dtypes.int64)),
  (UPat.ins(X86Ops.VCVTSI2SD, name="x"), lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.base is dtypes.int64)),
  (UPat.ins(X86Ops.VCVTTSS2SI, name="x"), lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype in dtypes.int64s)),
  (UPat.ins(X86Ops.VCVTTSD2SI, name="x"), lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype in dtypes.int64s)),
  # int division
  (UPat.ins(X86Ops.IDIV, name="x"), lambda x: encode(x, 0xF7, reg=7)), (UPat.ins(X86Ops.DIV, name="x"), lambda x: encode(x, 0xF7, reg=6)),
  # scalar int binary
  (UPat.ins(X86Ops.SHLi, name="x"), lambda x: encode(x, 0xC1, reg=4)),
  (UPat.ins(X86Ops.SHRi, name="x"), lambda x: encode(x, 0xC1, reg=5)), (UPat.ins(X86Ops.SARi, name="x"), lambda x: encode(x, 0xC1, reg=7)),
  (UPat.ins(X86Ops.ADD, name="x"), lambda x: encode(x, 0x03)), (UPat.ins(X86Ops.ADDi, name="x"), lambda x: encode(x, 0x81, reg=0)),
  (UPat.ins(X86Ops.SUB, name="x"), lambda x: encode(x, 0x2B)), (UPat.ins(X86Ops.SUBi, name="x"), lambda x: encode(x, 0x81, reg=5)),
  (UPat.ins(X86Ops.AND, name="x"), lambda x: encode(x, 0x23)), (UPat.ins(X86Ops.ANDi, name="x"), lambda x: encode(x, 0x81, reg=4)),
  (UPat.ins(X86Ops.XOR, name="x"), lambda x: encode(x, 0x33)), (UPat.ins(X86Ops.XORi, name="x"), lambda x: encode(x, 0x81, reg=6)),
  (UPat.ins(X86Ops.OR, name="x"), lambda x: encode(x, 0x0B)), (UPat.ins(X86Ops.ORi, name="x"), lambda x: encode(x, 0x81, reg=1)),
  (UPat.ins(X86Ops.CMP, name="x"), lambda x: encode(x, 0x3B)), (UPat.ins(X86Ops.CMPi, name="x"), lambda x: encode(x, 0x81, reg=7)),
  (UPat.ins(X86Ops.IMUL, name="x"), lambda x: encode(x, 0x0FAF)), (UPat.ins(X86Ops.IMULi, name="x"), lambda x: encode(x, 0x69)),
  (UPat.ins(X86Ops.SETB, name="x"), lambda x: encode(x, 0x0F92, reg=0)), (UPat.ins(X86Ops.SETL, name="x"), lambda x: encode(x, 0x0F9C, reg=0)),
  (UPat.ins(X86Ops.SETE, name="x"), lambda x: encode(x, 0x0F94, reg=0)), (UPat.ins(X86Ops.SETNE, name="x"), lambda x: encode(x, 0x0F95, reg=0)),
  # packed bitwise NOTE: only bitwise and packed
  (UPat.ins(X86Ops.VPAND, name="x"), lambda x: encode(x, 0xDB, pp=1, sel=1)), (UPat.ins(X86Ops.VPXOR, name="x"), lambda x: encode(x, 0xEF, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPOR, name="x"), lambda x: encode(x, 0xEB, pp=1, sel=1)),
  # unary
  (UPat.ins(X86Ops.VSQRTSS, name="x"), lambda x: encode(x, 0x51, pp=2, sel=1)), (UPat.ins(X86Ops.VSQRTPS, name="x"), lambda x: encode(x, 0x51, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSQRTSD, name="x"), lambda x: encode(x, 0x51, pp=3, sel=1)), (UPat.ins(X86Ops.VSQRTPD, name="x"), lambda x: encode(x, 0x51, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VROUNDSS, name="x"), lambda x: encode(x, 0x0A, pp=1, sel=3)), (UPat.ins(X86Ops.VROUNDPS, name="x"), lambda x: encode(x, 0x08, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VROUNDSD, name="x"), lambda x: encode(x, 0x0B, pp=1, sel=3)), (UPat.ins(X86Ops.VROUNDPD, name="x"), lambda x: encode(x, 0x09, pp=1, sel=3)), # noqa: E501
  # packed int binary
  (UPat.ins(X86Ops.VPSLLVD, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2)), (UPat.ins(X86Ops.VPSLLVQ, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSRLVD, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2)), (UPat.ins(X86Ops.VPSRLVQ, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPGTB, name="x"), lambda x: encode(x, 0x64, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPGTW, name="x"), lambda x: encode(x, 0x65, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPGTD, name="x"), lambda x: encode(x, 0x66, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPGTQ, name="x"), lambda x: encode(x, 0x37, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPEQB, name="x"), lambda x: encode(x, 0x74, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPEQW, name="x"), lambda x: encode(x, 0x75, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPCMPEQD, name="x"), lambda x: encode(x, 0x76, pp=1, sel=1)), (UPat.ins(X86Ops.VPCMPEQQ, name="x"), lambda x: encode(x, 0x29, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPMULLW, name="x"), lambda x: encode(x, 0xD5, pp=1, sel=1)), (UPat.ins(X86Ops.VPMULLD, name="x"), lambda x: encode(x, 0x40, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPADDB, name="x"), lambda x: encode(x, 0xFC, pp=1, sel=1)), (UPat.ins(X86Ops.VPADDW, name="x"), lambda x: encode(x, 0xFD, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPADDD, name="x"), lambda x: encode(x, 0xFE, pp=1, sel=1)), (UPat.ins(X86Ops.VPADDQ, name="x"), lambda x: encode(x, 0xD4, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSUBB, name="x"), lambda x: encode(x, 0xF8, pp=1, sel=1)), (UPat.ins(X86Ops.VPSUBW, name="x"), lambda x: encode(x, 0xF9, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSUBD, name="x"), lambda x: encode(x, 0xFA, pp=1, sel=1)), (UPat.ins(X86Ops.VPSUBQ, name="x"), lambda x: encode(x, 0xFB, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPSRAVD, name="x"), lambda x: encode(x, 0x46, pp=1, sel=2)),
  # float cmp
  (UPat.ins(X86Ops.VUCOMISS, name="x"), lambda x: encode(x, 0x2E, pp=0, sel=1)), (UPat.ins(X86Ops.VUCOMISD, name="x"), lambda x: encode(x, 0x2E, pp=1, sel=1)), # noqa: E501
  # scalar / packed float binary
  (UPat.ins(X86Ops.VADDSS, name="x"), lambda x: encode(x, 0x58, pp=2, sel=1)), (UPat.ins(X86Ops.VADDPS, name="x"), lambda x: encode(x, 0x58, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VADDSD, name="x"), lambda x: encode(x, 0x58, pp=3, sel=1)), (UPat.ins(X86Ops.VADDPD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSUBSS, name="x"), lambda x: encode(x, 0x5C, pp=2, sel=1)), (UPat.ins(X86Ops.VSUBPS, name="x"), lambda x: encode(x, 0x5C, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VSUBSD, name="x"), lambda x: encode(x, 0x5C, pp=3, sel=1)), (UPat.ins(X86Ops.VSUBPD, name="x"), lambda x: encode(x, 0x5C, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMULSS, name="x"), lambda x: encode(x, 0x59, pp=2, sel=1)), (UPat.ins(X86Ops.VMULPS, name="x"), lambda x: encode(x, 0x59, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMULSD, name="x"), lambda x: encode(x, 0x59, pp=3, sel=1)), (UPat.ins(X86Ops.VMULPD, name="x"), lambda x: encode(x, 0x59, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VDIVSS, name="x"), lambda x: encode(x, 0x5E, pp=2, sel=1)), (UPat.ins(X86Ops.VDIVPS, name="x"), lambda x: encode(x, 0x5E, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VDIVSD, name="x"), lambda x: encode(x, 0x5E, pp=3, sel=1)), (UPat.ins(X86Ops.VDIVPD, name="x"), lambda x: encode(x, 0x5E, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCMPSS, name="x"), lambda x: encode(x, 0xC2, pp=2, sel=1)), (UPat.ins(X86Ops.VCMPPS, name="x"), lambda x: encode(x, 0xC2, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VCMPSD, name="x"), lambda x: encode(x, 0xC2, pp=3, sel=1)), (UPat.ins(X86Ops.VCMPPD, name="x"), lambda x: encode(x, 0xC2, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMAXSS, name="x"), lambda x: encode(x, 0x5F, pp=2, sel=1)), (UPat.ins(X86Ops.VMAXPS, name="x"), lambda x: encode(x, 0x5F, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMAXSD, name="x"), lambda x: encode(x, 0x5F, pp=3, sel=1)), (UPat.ins(X86Ops.VMAXPD, name="x"), lambda x: encode(x, 0x5F, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMINSS, name="x"), lambda x: encode(x, 0x5D, pp=2, sel=1)), (UPat.ins(X86Ops.VMINPS, name="x"), lambda x: encode(x, 0x5D, pp=0, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VMINSD, name="x"), lambda x: encode(x, 0x5D, pp=3, sel=1)), (UPat.ins(X86Ops.VMINPD, name="x"), lambda x: encode(x, 0x5D, pp=1, sel=1)), # noqa: E501
  # ternary
  (UPat.ins(X86Ops.CMOVB, name="x"), lambda x: encode(x, 0x0F42)), (UPat.ins(X86Ops.CMOVL, name="x"), lambda x: encode(x, 0x0F4C)),
  (UPat.ins(X86Ops.CMOVE, name="x"), lambda x: encode(x, 0x0F44)), (UPat.ins(X86Ops.CMOVNE, name="x"), lambda x: encode(x, 0x0F45)),
  (UPat.ins(X86Ops.VFMADD213SS, name="x"), lambda x: encode(x, 0xA9, pp=1, sel=2)), (UPat.ins(X86Ops.VFMADD213SD, name="x"), lambda x: encode(x, 0xA9, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VFMADD213PS, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2)), (UPat.ins(X86Ops.VFMADD213PD, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VBLENDVPS, name="x"), lambda x: encode(x, 0x4A, pp=1, sel=3)), (UPat.ins(X86Ops.VBLENDVPD, name="x"), lambda x: encode(x, 0x4B, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VPBLENDVB, name="x"), lambda x: encode(x, 0x4C, pp=1, sel=3)),
  # shuffles
  (UPat.ins(X86Ops.VPBROADCASTB, name="x"), lambda x: encode(x, 0x78, pp=1, sel=2)), (UPat.ins(X86Ops.VPBROADCASTW, name="x"), lambda x: encode(x, 0x79, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VPBROADCASTD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=2)), (UPat.ins(X86Ops.VPBROADCASTQ, name="x"), lambda x: encode(x, 0x59, pp=1, sel=2)), # noqa: E501
  (UPat.ins(X86Ops.VBROADCASTSS, name="x"), lambda x: encode(x, 0x18, pp=1, sel=2)),
  (UPat.ins(X86Ops.VPINSRB, name="x"), lambda x: encode(x, 0x20, pp=1, sel=3)), (UPat.ins(X86Ops.VPINSRW, name="x"), lambda x: encode(x, 0xC4, pp=1, sel=1)), # noqa: E501
  (UPat.ins(X86Ops.VPINSRD, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3)), (UPat.ins(X86Ops.VPINSRQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3, we=1)), # noqa: E501
  (UPat.ins(X86Ops.VSHUFPS, name="x"), lambda x: encode(x, 0xC6, pp=0, sel=1)), (UPat.ins(X86Ops.VINSERTPS, name="x"), lambda x: encode(x, 0x21, pp=1, sel=3)), # noqa: E501
  # extract
  (UPat.ins(X86Ops.VPEXTRB, name="x"), lambda x: encode(x, 0x14, pp=1, sel=3)), (UPat.ins(X86Ops.VPEXTRW, name="x"), lambda x: encode(x, 0x15, pp=1, sel=3)), # noqa: E501
  (UPat.ins(X86Ops.VPEXTRD, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3)), (UPat.ins(X86Ops.VPEXTRQ, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3, we=1)), # noqa: E501
  # jumps are encoded with a placeholder which gets patched later once the real offset is known
  (UPat.ins(X86Ops.JE), lambda: bytes([0x0F, 0x84]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JNE), lambda: bytes([0x0F, 0x85]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JL), lambda: bytes([0x0F, 0x8C]) + int(0).to_bytes(4, 'little', signed=True)),
  (UPat.ins(X86Ops.JB), lambda: bytes([0x0F, 0x82]) + int(0).to_bytes(4, 'little', signed=True)),
  # return
  (UPat.ins(X86Ops.RET), lambda: bytes([0xC3])),
])
|
||||
|
||||
class X86Renderer(ISARenderer):
|
||||
|
|
@ -652,21 +651,21 @@ class X86Renderer(ISARenderer):
|
|||
def __init__(self):
  # NOTE(review): the compiler module is imported here rather than at file level, presumably to defer the import until a renderer exists; confirm
  from tinygrad.runtime.support import compiler_cpu
  self.compiler = compiler_cpu.X86Compiler()
|
||||
def stack_pointer(self) -> UOp:
  # the stack pointer is a uint64 DEFINE_REG instruction whose register (kept in tag) is RSP
  return UOp(Ops.INS, dtypes.uint64, arg=X86Ops.DEFINE_REG, tag=RSP)
|
||||
def render(self, uops:list[UOp], lower:bool=True) -> str:
  """Encode a list of uops into machine code and return it as a hex string.

  When `lower` is true the list is first replaced by `self.lower(uops[-1])`.
  Raises RuntimeError when no entry of `encodings` produces bytes for a uop.
  """
  if lower: uops = self.lower(uops[-1])
  jumps = (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE)
  # these are real ops, not instruction kinds, so they are identified by u.op; the arg check is kept for compatibility
  skipped = (Ops.GROUP, Ops.NOOP, Ops.AFTER, Ops.BARRIER)
  # every uop that some jump points at (the jump target is the first source of the jump)
  targets: set[UOp] = {u.src[0] for u in uops if u.arg in jumps}
  target_loc: list[int] = []
  binary = bytearray()
  for u in uops:
    if u.op in skipped or u.arg in skipped: continue
    # immediates and register definitions emit no bytes of their own
    if u.arg in (X86Ops.IMM, X86Ops.DEFINE_REG): continue
    if (enc:=cast(bytes|None, encodings.rewrite(u))) is None:
      raise RuntimeError(f"failed to encode {u.arg} with {u.dtype} srcs {[x.dtype for x in u.src]}")
    binary.extend(enc)
    # remember the offset just past a jump target so the matching jump can be patched when it is reached
    if u in targets: target_loc.append(len(binary))
    elif u.arg in jumps:
      # overwrite the 4 byte placeholder with the offset of the most recent target relative to the end of this jump
      binary[-4:] = (target_loc.pop() - len(binary)).to_bytes(4, 'little', signed=True)
  return binary.hex()
|
||||
|
|
@ -1,14 +1,9 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
from enum import auto, IntEnum, Enum, EnumType
|
||||
|
||||
# wrapper around EnumType to allow extending enums with members
|
||||
class ExtensibleEnumType(EnumType):
|
||||
@classmethod
|
||||
def _check_for_existing_members_(mcls, class_name, bases): return
|
||||
from enum import auto, IntEnum, Enum
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
class FastEnum(IntEnum, metaclass=ExtensibleEnumType):
|
||||
class FastEnum(IntEnum):
|
||||
# delegate to Enum.__str__ explicitly so the Enum string form is used for this IntEnum subclass
def __str__(self): return Enum.__str__(self)
|
||||
# repr is the same text as str
def __repr__(x): return str(x)
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -421,6 +421,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def after(self, *src:UOp, **kwargs):
  # with no extra sources there is nothing to wrap: self is returned as is and kwargs are not used
  if not src: return self
  return UOp(Ops.AFTER, self.dtype, (self, *src), **kwargs)
|
||||
def assign(self, x:UOp):
  # ASSIGN uop with self's dtype and (self, x) as sources
  return UOp(Ops.ASSIGN, dtype=self.dtype, src=(self, x))
|
||||
def barrier(self, *src:UOp):
  # BARRIER uop over self followed by any additional sources
  return UOp(Ops.BARRIER, src=(self, *src))
|
||||
def ins(self, arg, **kwargs):
  # build an Ops.INS uop for instruction `arg`; dtype, src and tag default to self's and can be overridden by keyword
  # any other keyword is ignored
  dtype = kwargs.pop("dtype", self.dtype)
  src = kwargs.pop("src", self.src)
  tag = kwargs.pop("tag", self.tag)
  return UOp(Ops.INS, dtype, src, arg, tag)
|
||||
def contract(self, *rngs:UOp):
  # every range passed in must be an upcast axis (last element of its arg)
  assert all(r.arg[-1] == AxisType.UPCAST for r in rngs), "all contract ranges must be upcast"
  # size of each range is vmax+1; the result dtype is vectorized by the product of the sizes
  sizes = [r.vmax+1 for r in rngs]
  axes = tuple((r.arg[0], sz) for r, sz in zip(rngs, sizes))
  return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod(sizes)), src=(self,), arg=axes)
|
||||
|
|
@ -969,6 +970,8 @@ class UPat(OpMixin):
|
|||
def contiguous(self, *args, **kwargs):
  # CONTIGUOUS pattern over self plus any extra source patterns, passing self.match_dtype as the dtype
  srcs = (self, *args)
  return UPat(Ops.CONTIGUOUS, dtype=self.match_dtype, src=srcs, **kwargs)
|
||||
def after(self, *src:UPat, **kwargs):
  # AFTER pattern whose sources are self followed by the given patterns
  srcs = (self, *src)
  return UPat(Ops.AFTER, self.match_dtype, srcs, **kwargs)
|
||||
def end(self, *src:UPat, **kwargs):
  # END pattern whose sources are self followed by the given patterns
  srcs = (self, *src)
  return UPat(Ops.END, self.match_dtype, srcs, **kwargs)
|
||||
@staticmethod
def ins(arg, **kwargs):
  # pattern for an Ops.INS uop with the given arg; remaining keywords go straight to UPat
  return UPat(op=Ops.INS, arg=arg, **kwargs) if False else UPat(Ops.INS, arg=arg, **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike):
  # constant pattern with self's match_dtype; the cast only narrows the static type of b
  value = cast(ConstType, b)
  return UPat.const(self.match_dtype, value)
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue