mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add float max
This commit is contained in:
parent
f9b2f51554
commit
b4f8d64d2b
2 changed files with 14 additions and 2 deletions
|
|
@ -44,6 +44,8 @@ extra_matcher = PatternMatcher([
|
|||
(UPat.var('x')+(UPat.var('y')*-1), lambda x,y: x.alu(Ops.SUB, y)),
|
||||
# mulacc only available for floats
|
||||
(UPat.var('a', dtypes.floats)*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c)),
|
||||
# no max for scalar ints
|
||||
(UPat(Ops.MAX, dtypes.ints, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]) if m.dtype.count == 1 else None),
|
||||
# no int8 mul or cmove, cast to int16
|
||||
(UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)),
|
||||
(UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
|
||||
|
|
@ -328,6 +330,8 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
|
||||
(UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)),
|
||||
(UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)),
|
||||
# casts
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
|
||||
|
|
@ -385,6 +389,7 @@ isel_matcher = PatternMatcher([
|
|||
# fuse loads into X86Ops that allow it, if beneficial
|
||||
(UPat(X86GroupOp.ReadMem1st, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)),
|
||||
(UPat(X86GroupOp.ReadMem2nd, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)),
|
||||
#(UPat(X86GroupOp.Associative, src=(UPat(Ops.LOAD), UPat()), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x.replace(src=(x.src[1], x.src[0])), 1)),
|
||||
(UPat(X86GroupOp.ReadMem3rd, src=(UPat(), UPat(), UPat(Ops.LOAD)), name="x"), lambda ctx,x: fuse_load(ctx, x, 2)),
|
||||
# allocate virtual register to X86Op, ones with specific constraints have already been allocated
|
||||
(UPat(X86GroupOp.All, name="x"), lambda ctx,x: x.replace(arg=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.arg is None and x.dtype != dtypes.void else None), # noqa: E501
|
||||
|
|
@ -612,6 +617,8 @@ encodings = PatternMatcher([
|
|||
(UPat(X86Ops.VDIVSD, name="x"), lambda x: encode(x, 0x5E, pp=3, sel=1)), (UPat(X86Ops.VDIVPD, name="x"), lambda x: encode(x, 0x5E, pp=1, sel=1)),
|
||||
(UPat(X86Ops.VCMPSS, name="x"), lambda x: encode(x, 0xC2, pp=2, sel=1)), (UPat(X86Ops.VCMPPS, name="x"), lambda x: encode(x, 0xC2, pp=0, sel=1)),
|
||||
(UPat(X86Ops.VCMPSD, name="x"), lambda x: encode(x, 0xC2, pp=3, sel=1)), (UPat(X86Ops.VCMPPD, name="x"), lambda x: encode(x, 0xC2, pp=1, sel=1)),
|
||||
(UPat(X86Ops.VMAXSS, name="x"), lambda x: encode(x, 0x5F, pp=2, sel=1)), (UPat(X86Ops.VMAXPS, name="x"), lambda x: encode(x, 0x5F, pp=0, sel=1)),
|
||||
(UPat(X86Ops.VMAXSD, name="x"), lambda x: encode(x, 0x5F, pp=3, sel=1)), (UPat(X86Ops.VMAXPD, name="x"), lambda x: encode(x, 0x5F, pp=1, sel=1)),
|
||||
# ternary
|
||||
(UPat(X86Ops.CMOVB, name="x"), lambda x: encode(x, 0x0F42)), (UPat(X86Ops.CMOVL, name="x"), lambda x: encode(x, 0x0F4C)),
|
||||
(UPat(X86Ops.CMOVE, name="x"), lambda x: encode(x, 0x0F44)), (UPat(X86Ops.CMOVNE, name="x"), lambda x: encode(x, 0x0F45)),
|
||||
|
|
@ -648,7 +655,7 @@ class X86Renderer(ISARenderer):
|
|||
isel_matcher = isel_matcher
|
||||
post_regalloc_matcher = post_regalloc_matcher
|
||||
isa_spec = isa_spec
|
||||
code_for_op = {x: lambda: None for x in (Ops.SQRT, Ops.AND, Ops.OR, Ops.SHL, Ops.SHR, Ops.FDIV, Ops.CMPLT, Ops.CMPEQ)}
|
||||
code_for_op = {x: lambda: None for x in (Ops.SQRT, Ops.AND, Ops.OR, Ops.SHL, Ops.SHR, Ops.FDIV, Ops.CMPLT, Ops.CMPEQ, Ops.MAX)}
|
||||
|
||||
def two_address(self, x:UOp) -> int|None: return 0 if x.op in X86GroupOp.TwoAddress1st else None
|
||||
def stack_pointer(self) -> UOp: return UOp(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP)
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ class X86Ops(FastEnum):
|
|||
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto() # noqa: E702
|
||||
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto() # noqa: E702
|
||||
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto() # noqa: E702
|
||||
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto() # noqa: E702
|
||||
# int vector binary
|
||||
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto() # noqa: E702
|
||||
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto() # noqa: E702
|
||||
|
|
@ -207,6 +208,9 @@ class X86Ops(FastEnum):
|
|||
|
||||
# TODO: add associative groupop to fuse more loads
|
||||
class X86GroupOp:
|
||||
# variants with immediates are not associative
|
||||
Associative = {X86Ops.VADDSS, X86Ops.VADDSD}
|
||||
|
||||
# X86Ops whose first src is also the destination
|
||||
TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
|
||||
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
|
||||
|
|
@ -232,7 +236,8 @@ class X86GroupOp:
|
|||
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
|
||||
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
|
||||
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS,
|
||||
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB, X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD,
|
||||
X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
# X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second
|
||||
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue