mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
allow for extending enums and move X86Ops out of uop
This commit is contained in:
parent
733789e294
commit
5c2b0b2363
11 changed files with 160 additions and 152 deletions
|
|
@ -1,9 +1,8 @@
|
|||
import unittest
|
||||
from tinygrad.renderer.x86 import X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
|
||||
from tinygrad.uop import X86Ops, Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import SPEC
|
||||
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
|
||||
|
||||
@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec")
|
||||
class TestEncodingsX86(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
from tinygrad.uop import X86Ops, Ops
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
|
||||
from tinygrad.renderer.x86 import X86Renderer
|
||||
from tinygrad.renderer.isa import IselContext, Register
|
||||
from tinygrad.renderer.isa import X86Ops
|
||||
from tinygrad.renderer.isa.x86 import X86Renderer
|
||||
from tinygrad.renderer.isa.isa import IselContext, Register
|
||||
from tinygrad.helpers import SPEC
|
||||
|
||||
@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
import itertools
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.uop import X86GroupOp
|
||||
from tinygrad.renderer.isa import X86GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, cast, TYPE_CHECKING
|
|||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
|
||||
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo, OpType
|
||||
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
|
@ -23,7 +23,7 @@ class Estimates:
|
|||
def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
|
||||
flops: sint = 0
|
||||
lds: sint = 0
|
||||
mem: dict[tuple[UOp, OpType], sint] = {}
|
||||
mem: dict[tuple[UOp, Ops], sint] = {}
|
||||
mults: sint = 1
|
||||
mult_stack: list[sint] = []
|
||||
dont_count: set[UOp] = set()
|
||||
|
|
|
|||
126
tinygrad/renderer/isa/__init__.py
Normal file
126
tinygrad/renderer/isa/__init__.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
from tinygrad.uop.ops import Ops, auto
|
||||
|
||||
# ***** X86 *****
|
||||
|
||||
# NOTE: mypy doesn't allow extending enums even though our meta class does, so we ignore it here
|
||||
class X86Ops(Ops): # type: ignore
|
||||
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
|
||||
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
|
||||
DEFINE_REG = auto(); FRAME_INDEX = auto()
|
||||
# const
|
||||
IMM = auto()
|
||||
# index
|
||||
LEA = auto()
|
||||
# register / memory / immediate moves
|
||||
MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
|
||||
VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
|
||||
VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
|
||||
# casts
|
||||
MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
|
||||
VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
|
||||
VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
|
||||
VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
|
||||
VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
|
||||
VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
|
||||
VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
|
||||
VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
|
||||
VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
|
||||
# bitcasts
|
||||
VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
|
||||
# comparisons
|
||||
VUCOMISS = auto(); VUCOMISD = auto()
|
||||
VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
|
||||
VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
|
||||
VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
|
||||
SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
|
||||
# where
|
||||
CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
|
||||
VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
|
||||
# jumps
|
||||
JNE = auto(); JE = auto(); JL = auto(); JB = auto()
|
||||
# vectorize / gep
|
||||
VSHUFPS = auto(); VINSERTPS = auto()
|
||||
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
|
||||
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
|
||||
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
|
||||
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
|
||||
# int division
|
||||
IDIV = auto(); DIV = auto()
|
||||
CBW = auto(); CWD = auto(); CDQ = auto(); CQO = auto()
|
||||
# int binary
|
||||
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
|
||||
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
|
||||
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
|
||||
# float unary (sometimes not unary)
|
||||
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
|
||||
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
|
||||
# float scalar / vector binary
|
||||
VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
|
||||
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
|
||||
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
|
||||
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
|
||||
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
|
||||
VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
|
||||
# int vector binary
|
||||
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
|
||||
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
|
||||
VPMULLW = auto(); VPMULLD = auto()
|
||||
# packed bitwise TODO: might also want vandp cause of different execution ports
|
||||
VPAND = auto(); VPOR = auto(); VPXOR = auto()
|
||||
# packed variable shifts
|
||||
VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
|
||||
# fused multiply add TODO: add other variants to fuse more loads
|
||||
VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
|
||||
# return
|
||||
RET = auto()
|
||||
|
||||
# TODO: add commutative groupop to fuse more loads
|
||||
class X86GroupOp:
|
||||
# X86Ops whose first src is also the destination
|
||||
TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
|
||||
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
|
||||
X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
|
||||
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}
|
||||
|
||||
# X86Ops whose first src can read from memory
|
||||
ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
|
||||
X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
|
||||
X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
|
||||
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
|
||||
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
|
||||
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
|
||||
X86Ops.CMPi, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV, X86Ops.LEA}
|
||||
|
||||
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first
|
||||
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
|
||||
X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
|
||||
X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
|
||||
X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
|
||||
X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
|
||||
X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
|
||||
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
|
||||
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
|
||||
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
|
||||
X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
|
||||
X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
# X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second
|
||||
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
|
||||
|
||||
# X86Ops that can write to memory
|
||||
WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
|
||||
X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
|
||||
X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}
|
||||
|
||||
# X86Ops that read flags
|
||||
ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
|
||||
X86Ops.JE, X86Ops.JNE}
|
||||
|
||||
# X86Ops that write flags or can modify flags to undefined values
|
||||
WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
|
||||
X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
|
||||
X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
All = set(X86Ops)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import itertools, heapq
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop import X86Ops, X86GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, UPat, Ops
|
||||
from tinygrad.codegen import line_rewrite
|
||||
|
|
@ -37,7 +36,7 @@ isel_fixup = PatternMatcher([
|
|||
|
||||
# TODO: this will eventually be a proper scheduler
|
||||
def isa_linearize(sink:UOp) -> list[UOp]:
|
||||
from tinygrad.renderer.x86 import RSP
|
||||
from tinygrad.renderer.isa.x86 import RSP, X86Ops, X86GroupOp
|
||||
# this is a toposort with priority
|
||||
lst = list(sink.toposort())
|
||||
out_degree:defaultdict[UOp, int] = defaultdict(int)
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
import sys, struct, functools
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, DType, truncate
|
||||
from tinygrad.uop import Ops, X86Ops, GroupOp, X86GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext
|
||||
from tinygrad.uop.ops import Ops, GroupOp, UOp, UPat, PatternMatcher
|
||||
from tinygrad.renderer.isa import X86Ops, X86GroupOp
|
||||
from tinygrad.renderer.isa.isa import ISARenderer, IselContext
|
||||
from tinygrad.codegen.late.regalloc import Register, assign
|
||||
from tinygrad.helpers import getenv, CPU_COUNT
|
||||
|
||||
|
|
@ -8,7 +8,7 @@ from tinygrad.runtime.support.hcq import CLikeArgsState
|
|||
from tinygrad.renderer.cstyle import ClangJITRenderer
|
||||
from tinygrad.renderer.llvmir import CPULLVMRenderer
|
||||
from tinygrad.renderer.nir import LVPRenderer
|
||||
from tinygrad.renderer.x86 import X86Renderer
|
||||
from tinygrad.renderer.isa.x86 import X86Renderer
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, X86Compiler
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.uop.ops import sint
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, OpType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
|
|
@ -442,7 +442,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
|||
|
||||
def find_bufs(x:UOp):
|
||||
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX]
|
||||
read_from: dict[UOp, OpType] = {}
|
||||
read_from: dict[UOp, Ops] = {}
|
||||
if any((buf:=idx.buf_uop).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs):
|
||||
raise RuntimeError(f"cycle detected while indexing {buf}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
from enum import auto, IntEnum, Enum
|
||||
from enum import auto, IntEnum, Enum, EnumType
|
||||
|
||||
# wrapper around EnumType to allow extending enums with members
|
||||
class ExtensibleEnumType(EnumType):
|
||||
@classmethod
|
||||
def _check_for_existing_members_(mcls, class_name, bases): return
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
class FastEnum(IntEnum):
|
||||
class FastEnum(IntEnum, metaclass=ExtensibleEnumType):
|
||||
def __str__(self): return Enum.__str__(self)
|
||||
def __repr__(x): return str(x)
|
||||
@staticmethod
|
||||
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
|
||||
|
||||
OpType = FastEnum
|
||||
|
||||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(OpType):
|
||||
class Ops(FastEnum):
|
||||
# ** 1 -- defines/special **
|
||||
|
||||
# define GLOBAL/VAR are ptrs to outside the Kernel
|
||||
|
|
@ -137,124 +140,3 @@ class GroupOp:
|
|||
|
||||
All = set(Ops)
|
||||
|
||||
# **** backend specific ops ****
|
||||
|
||||
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
|
||||
class X86Ops(OpType):
|
||||
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
|
||||
DEFINE_REG = auto(); FRAME_INDEX = auto()
|
||||
# const
|
||||
IMM = auto()
|
||||
# index
|
||||
LEA = auto()
|
||||
# register / memory / immediate moves
|
||||
MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
|
||||
VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
|
||||
VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
|
||||
# casts
|
||||
MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
|
||||
VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
|
||||
VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
|
||||
VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
|
||||
VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
|
||||
VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
|
||||
VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
|
||||
VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
|
||||
VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
|
||||
# bitcasts
|
||||
VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
|
||||
# comparisons
|
||||
VUCOMISS = auto(); VUCOMISD = auto()
|
||||
VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
|
||||
VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
|
||||
VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
|
||||
SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
|
||||
# where
|
||||
CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
|
||||
VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
|
||||
# jumps
|
||||
JNE = auto(); JE = auto(); JL = auto(); JB = auto()
|
||||
# vectorize / gep
|
||||
VSHUFPS = auto(); VINSERTPS = auto()
|
||||
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
|
||||
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
|
||||
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
|
||||
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
|
||||
# int division
|
||||
IDIV = auto(); DIV = auto()
|
||||
CBW = auto(); CWD = auto(); CDQ = auto(); CQO = auto()
|
||||
# int binary
|
||||
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
|
||||
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
|
||||
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
|
||||
# float unary (sometimes not unary)
|
||||
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
|
||||
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
|
||||
# float scalar / vector binary
|
||||
VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
|
||||
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
|
||||
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
|
||||
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
|
||||
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
|
||||
VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
|
||||
# int vector binary
|
||||
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
|
||||
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
|
||||
VPMULLW = auto(); VPMULLD = auto()
|
||||
# packed bitwise TODO: might also want vandp cause of different execution ports
|
||||
VPAND = auto(); VPOR = auto(); VPXOR = auto()
|
||||
# packed variable shifts
|
||||
VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
|
||||
# fused multiply add TODO: add other variants to fuse more loads
|
||||
VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
|
||||
# return
|
||||
RET = auto()
|
||||
|
||||
# TODO: add commutative groupop to fuse more loads
|
||||
class X86GroupOp:
|
||||
# X86Ops whose first src is also the destination
|
||||
TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
|
||||
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
|
||||
X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
|
||||
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}
|
||||
|
||||
# X86Ops whose first src can read from memory
|
||||
ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
|
||||
X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
|
||||
X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
|
||||
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
|
||||
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
|
||||
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
|
||||
X86Ops.CMPi, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV, X86Ops.LEA}
|
||||
|
||||
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first
|
||||
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
|
||||
X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
|
||||
X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
|
||||
X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
|
||||
X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
|
||||
X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
|
||||
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
|
||||
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
|
||||
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
|
||||
X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
|
||||
X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
# X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second
|
||||
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
|
||||
|
||||
# X86Ops that can write to memory
|
||||
WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
|
||||
X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
|
||||
X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}
|
||||
|
||||
# X86Ops that read flags
|
||||
ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
|
||||
X86Ops.JE, X86Ops.JNE}
|
||||
|
||||
# X86Ops that write flags or can modify flags to undefined values
|
||||
WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
|
||||
X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
|
||||
X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
|
||||
|
||||
All = set(X86Ops)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable,
|
|||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp, X86GroupOp, OpType
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
|
||||
|
|
@ -26,7 +26,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
|
|||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
|
||||
|
||||
range_start:dict[OpType, int] = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
|
@ -83,7 +83,7 @@ def pretty_print(x:UOp, cache=None, d=0)->str:
|
|||
|
||||
class UOpMetaClass(type):
|
||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:OpType, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
||||
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
||||
|
|
@ -120,7 +120,7 @@ from tinygrad.mixin import OpMixin
|
|||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
op:OpType
|
||||
op:Ops
|
||||
dtype:DType = dtypes.void
|
||||
src:tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
|
|
@ -297,6 +297,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return input_shapes[0]
|
||||
|
||||
# backend ops don't have a shape
|
||||
from tinygrad.renderer.isa.x86 import X86GroupOp
|
||||
if self.op in X86GroupOp.All: return None
|
||||
|
||||
# all Ops must be explicitly handled
|
||||
|
|
@ -913,11 +914,11 @@ def get_location() -> tuple[str, int]:
|
|||
|
||||
class UPat(OpMixin):
|
||||
__slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
|
||||
def __init__(self, op:OpType|tuple[OpType, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
|
||||
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
|
||||
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
|
||||
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
|
||||
assert op is None or isinstance(op, (OpType, tuple, set)), "op must be Ops or tuple of Ops"
|
||||
self.op: tuple[OpType, ...]|None = (op,) if isinstance(op, OpType) else (tuple(op) if isinstance(op, set) else op)
|
||||
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
|
||||
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
|
||||
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
self.src: Any = None
|
||||
|
|
@ -1041,7 +1042,7 @@ class PatternMatcher:
|
|||
# if this comes from a pickle, we reconstruct the lambda functions here
|
||||
self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: dict[OpType, list[list]] = {}
|
||||
self.pdict: dict[Ops, list[list]] = {}
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
|
|
@ -1335,7 +1336,7 @@ pm_unbind = PatternMatcher([(UPat(Ops.BIND, name="x"), do_unbind)])
|
|||
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
|
||||
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
|
||||
# comparison operators are not in here because they are chained in python, not left-associative
|
||||
precedence:dict[OpType, int] = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
|
||||
precedence = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
|
||||
def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
|
||||
if x.op not in precedence: return code_for_op(left, right)
|
||||
return code_for_op(strip_parens(left) if precedence.get(x.src[0].op,99)<=precedence[x.op] else left, strip_parens(right) if
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue