mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add vbroadcastss instruction
This commit is contained in:
parent
54396f5cb3
commit
8365bc84ee
2 changed files with 5 additions and 5 deletions
|
|
@ -213,7 +213,7 @@ def cmp(x:UOp): return UOp(X86Ops.CMP, src=x.src) if (i:=to_imm(x.src[1])) is No
|
|||
def def_reg(dt:DType): return UOp(X86Ops.DEFINE_REG, dt)
|
||||
|
||||
# vshufps takes 2 registers: it gets its lower 64 bits from the first register and its upper 64 bits from the second
|
||||
# very useful, used for a lot of shuffles including broadcasts and cats
|
||||
# used for all shuffles with 1 or 2 src registers that are not broadcasts
|
||||
def vshufps(x:UOp) -> UOp:
|
||||
def _imm(src:tuple[UOp, ...]) -> UOp: return imm(dtypes.uint8, sum((s.arg[0] if s.op is Ops.GEP else 0) << (2*i) for i,s in enumerate(src)))
|
||||
rsrc = tuple(s.src[0] if s.op is Ops.GEP else s for s in x.src)
|
||||
|
|
@ -248,9 +248,6 @@ def fuse_index(ctx:IselContext, x:UOp) -> tuple[UOp, ...]:
|
|||
return (base, idx.cast(dtypes.int64) if idx.op is not Ops.NOOP and idx.vmin < 0 else idx, disp(x.src[1]))
|
||||
|
||||
def fuse_load(ctx:IselContext, x:UOp, i:int) -> UOp|None:
|
||||
# TODO: the rule is if size of load doesn't match size of x can't fuse, but there's some details to figure out
|
||||
# like how vinsertps dtype is scalar
|
||||
if x.op is X86Ops.VSHUFPS: return None
|
||||
# if the load is used multiple times we don't fuse
|
||||
return x.replace(src=x.src[:i] + fuse_index(ctx, x.src[i]) + x.src[i+1:]) if len(ctx.uses[x.src[i]]) == 1 else None
|
||||
|
||||
|
|
@ -316,6 +313,7 @@ isel_matcher = PatternMatcher([
|
|||
(UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTW, x.dtype, (y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTD, x.dtype, (y.bitcast(dtypes.float32),))),
|
||||
(UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTQ, x.dtype, (y.bitcast(dtypes.float64),))),
|
||||
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: UOp(X86Ops.VBROADCASTSS, x.dtype, (y,))),
|
||||
# shuffles
|
||||
(UPat.var("y", dtypes.int8s).bitcast(dtypes.mask8).named("x"), lambda y,x: UOp(X86Ops.VPINSRB, x.dtype, (def_reg(x.dtype), y, imm(dtypes.uint8, 0)))),
|
||||
(UPat.var("y", dtypes.int16s).bitcast((dtypes.float16, dtypes.mask16)).named("x"), lambda y,x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501
|
||||
|
|
@ -662,6 +660,7 @@ encodings = PatternMatcher([
|
|||
# shuffles
|
||||
(UPat(X86Ops.VPBROADCASTB, name="x"), lambda x: encode(x, 0x78, pp=1, sel=2)), (UPat(X86Ops.VPBROADCASTW, name="x"), lambda x: encode(x, 0x79, pp=1, sel=2)),
|
||||
(UPat(X86Ops.VPBROADCASTD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=2)), (UPat(X86Ops.VPBROADCASTQ, name="x"), lambda x: encode(x, 0x59, pp=1, sel=2)),
|
||||
(UPat(X86Ops.VBROADCASTSS, name="x"), lambda x: encode(x, 0x18, pp=1, sel=2)),
|
||||
(UPat(X86Ops.VPINSRB, name="x"), lambda x: encode(x, 0x20, pp=1, sel=3)), (UPat(X86Ops.VPINSRW, name="x"), lambda x: encode(x, 0xC4, pp=1, sel=1)),
|
||||
(UPat(X86Ops.VPINSRD, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3)), (UPat(X86Ops.VPINSRQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3, we=1)),
|
||||
(UPat(X86Ops.VSHUFPS, name="x"), lambda x: encode(x, 0xC6, pp=0, sel=1)), (UPat(X86Ops.VINSERTPS, name="x"), lambda x: encode(x, 0x21, pp=1, sel=3)),
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ class X86Ops(FastEnum):
|
|||
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto() # noqa: E702
|
||||
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto() # noqa: E702
|
||||
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto() # noqa: E702
|
||||
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
|
||||
# int division
|
||||
IDIV = auto()
|
||||
CBW = auto(); CWD = auto(); CDQ = auto(); CQO = auto() # noqa: E702
|
||||
|
|
@ -216,7 +217,7 @@ class X86GroupOp:
|
|||
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
|
||||
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
|
||||
X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD, X86Ops.CMPi, X86Ops.IMULi, X86Ops.IDIV, X86Ops.LEA,
|
||||
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ}
|
||||
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS}
|
||||
|
||||
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first
|
||||
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue