allow for extending enums and move X86Ops out of uop

This commit is contained in:
ttomsa 2026-02-08 19:08:49 +00:00
commit 5c2b0b2363
11 changed files with 160 additions and 152 deletions

View file

@ -1,9 +1,8 @@
import unittest
from tinygrad.renderer.x86 import X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
from tinygrad.uop import X86Ops, Ops
from tinygrad.uop.ops import UOp
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.helpers import SPEC
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec")
class TestEncodingsX86(unittest.TestCase):

View file

@ -1,8 +1,9 @@
import unittest
from tinygrad.uop import X86Ops, Ops
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.x86 import X86Renderer
from tinygrad.renderer.isa import IselContext, Register
from tinygrad.renderer.isa import X86Ops
from tinygrad.renderer.isa.x86 import X86Renderer
from tinygrad.renderer.isa.isa import IselContext, Register
from tinygrad.helpers import SPEC
@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec")

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import itertools
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.uop import X86GroupOp
from tinygrad.renderer.isa import X86GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType
from dataclasses import dataclass, field

View file

@ -3,7 +3,7 @@ from typing import Callable, cast, TYPE_CHECKING
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo, OpType
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
@ -23,7 +23,7 @@ class Estimates:
def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
flops: sint = 0
lds: sint = 0
mem: dict[tuple[UOp, OpType], sint] = {}
mem: dict[tuple[UOp, Ops], sint] = {}
mults: sint = 1
mult_stack: list[sint] = []
dont_count: set[UOp] = set()

View file

@ -0,0 +1,126 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
from tinygrad.uop.ops import Ops, auto
# ***** X86 *****
# NOTE: mypy doesn't allow extending enums even though our meta class does, so we ignore it here
class X86Ops(Ops): # type: ignore
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
DEFINE_REG = auto(); FRAME_INDEX = auto()
# const
IMM = auto()
# index
LEA = auto()
# register / memory / immediate moves
MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
# casts
MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
# bitcasts
VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
# comparisons
VUCOMISS = auto(); VUCOMISD = auto()
VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
# where
CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
# jumps
JNE = auto(); JE = auto(); JL = auto(); JB = auto()
# vectorize / gep
VSHUFPS = auto(); VINSERTPS = auto()
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
# int division
IDIV = auto(); DIV = auto()
CBW = auto(); CWD = auto(); CDQ = auto(); CQO = auto()
# int binary
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
# float unary (sometimes not unary)
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
# float scalar / vector binary
VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
# int vector binary
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
VPMULLW = auto(); VPMULLD = auto()
# packed bitwise TODO: might also want vandp cause of different execution ports
VPAND = auto(); VPOR = auto(); VPXOR = auto()
# packed variable shifts
VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
# fused multiply add TODO: add other variants to fuse more loads
VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
# return
RET = auto()
# TODO: add commutative groupop to fuse more loads
class X86GroupOp:
# X86Ops whose first src is also the destination
TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}
# X86Ops whose first src can read from memory
ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
X86Ops.CMPi, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV, X86Ops.LEA}
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
# X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
# X86Ops that can write to memory
WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}
# X86Ops that read flags
ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
X86Ops.JE, X86Ops.JNE}
# X86Ops that write flags or can modify flags to undefined values
WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
All = set(X86Ops)

View file

@ -1,7 +1,6 @@
import itertools, heapq
from typing import Any
from collections import defaultdict
from tinygrad.uop import X86Ops, X86GroupOp
from tinygrad.renderer import Renderer
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, UPat, Ops
from tinygrad.codegen import line_rewrite
@ -37,7 +36,7 @@ isel_fixup = PatternMatcher([
# TODO: this will eventually be a proper scheduler
def isa_linearize(sink:UOp) -> list[UOp]:
from tinygrad.renderer.x86 import RSP
from tinygrad.renderer.isa.x86 import RSP, X86Ops, X86GroupOp
# this is a toposort with priority
lst = list(sink.toposort())
out_degree:defaultdict[UOp, int] = defaultdict(int)

View file

@ -1,9 +1,9 @@
import sys, struct, functools
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, DType, truncate
from tinygrad.uop import Ops, X86Ops, GroupOp, X86GroupOp
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
from tinygrad.renderer.isa import ISARenderer, IselContext
from tinygrad.uop.ops import Ops, GroupOp, UOp, UPat, PatternMatcher
from tinygrad.renderer.isa import X86Ops, X86GroupOp
from tinygrad.renderer.isa.isa import ISARenderer, IselContext
from tinygrad.codegen.late.regalloc import Register, assign
from tinygrad.helpers import getenv, CPU_COUNT

View file

@ -8,7 +8,7 @@ from tinygrad.runtime.support.hcq import CLikeArgsState
from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import CPULLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.renderer.x86 import X86Renderer
from tinygrad.renderer.isa.x86 import X86Renderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, X86Compiler
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint

View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, OpType
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@ -442,7 +442,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
def find_bufs(x:UOp):
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX]
read_from: dict[UOp, OpType] = {}
read_from: dict[UOp, Ops] = {}
if any((buf:=idx.buf_uop).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs):
raise RuntimeError(f"cycle detected while indexing {buf}")

View file

@ -1,18 +1,21 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
from enum import auto, IntEnum, Enum
from enum import auto, IntEnum, Enum, EnumType
# wrapper around EnumType to allow extending enums with members
class ExtensibleEnumType(EnumType):
@classmethod
def _check_for_existing_members_(mcls, class_name, bases): return
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum):
class FastEnum(IntEnum, metaclass=ExtensibleEnumType):
def __str__(self): return Enum.__str__(self)
def __repr__(x): return str(x)
@staticmethod
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
OpType = FastEnum
# the order of these Ops controls the order of the toposort
class Ops(OpType):
class Ops(FastEnum):
# ** 1 -- defines/special **
# define GLOBAL/VAR are ptrs to outside the Kernel
@ -137,124 +140,3 @@ class GroupOp:
All = set(Ops)
# **** backend specific ops ****
# NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from
class X86Ops(OpType):
# register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known
DEFINE_REG = auto(); FRAME_INDEX = auto()
# const
IMM = auto()
# index
LEA = auto()
# register / memory / immediate moves
MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto()
VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto()
VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto()
# casts
MOVZX = auto(); MOVSX = auto(); MOVSXD = auto()
VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto()
VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto()
VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto()
VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto()
VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto()
VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto()
VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto()
VCVTTSS2SI = auto(); VCVTTSD2SI = auto()
# bitcasts
VMOVD = auto(); VMOVQ = auto(); VMOVDm = auto(); VMOVQm = auto()
# comparisons
VUCOMISS = auto(); VUCOMISD = auto()
VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto()
VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto()
VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto()
SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto()
# where
CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto()
VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto()
# jumps
JNE = auto(); JE = auto(); JL = auto(); JB = auto()
# vectorize / gep
VSHUFPS = auto(); VINSERTPS = auto()
VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto()
VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto()
VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto()
VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported
# int division
IDIV = auto(); DIV = auto()
CBW = auto(); CWD = auto(); CDQ = auto(); CQO = auto()
# int binary
ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto()
AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto()
SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto()
# float unary (sometimes not unary)
VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto()
VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto()
# float scalar / vector binary
VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto()
VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto()
VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto()
VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto()
VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto()
VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); VMINPD = auto()
# int vector binary
VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto()
VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto()
VPMULLW = auto(); VPMULLD = auto()
# packed bitwise TODO: might also want vandp cause of different execution ports
VPAND = auto(); VPOR = auto(); VPXOR = auto()
# packed variable shifts
VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto()
# fused multiply add TODO: add other variants to fuse more loads
VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto()
# return
RET = auto()
# TODO: add commutative groupop to fuse more loads
class X86GroupOp:
# X86Ops whose first src is also the destination
TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL,
X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi,
X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD,
X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB}
# X86Ops whose first src can read from memory
ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ,
X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ,
X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ,
X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI,
X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD,
X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS,
X86Ops.CMPi, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV, X86Ops.LEA}
# X86Ops whose second src can read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first
ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP,
X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD,
X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD,
X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ,
X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD,
X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD,
X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS,
X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD,
X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB,
X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD,
X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
# X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second
ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD}
# X86Ops that can write to memory
WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm,
X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE,
X86Ops.SETE, X86Ops.SETL, X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ}
# X86Ops that read flags
ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL,
X86Ops.JE, X86Ops.JNE}
# X86Ops that write flags or can modify flags to undefined values
WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV,
X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi,
X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD}
All = set(X86Ops)

View file

@ -3,7 +3,7 @@ from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable,
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp, X86GroupOp, OpType
from tinygrad.uop import Ops, GroupOp
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
@ -26,7 +26,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
range_start:dict[OpType, int] = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@ -83,7 +83,7 @@ def pretty_print(x:UOp, cache=None, d=0)->str:
class UOpMetaClass(type):
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:OpType, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
@ -120,7 +120,7 @@ from tinygrad.mixin import OpMixin
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(OpMixin, metaclass=UOpMetaClass):
op:OpType
op:Ops
dtype:DType = dtypes.void
src:tuple[UOp, ...] = tuple()
arg:Any = None
@ -297,6 +297,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return input_shapes[0]
# backend ops don't have a shape
from tinygrad.renderer.isa.x86 import X86GroupOp
if self.op in X86GroupOp.All: return None
# all Ops must be explicitly handled
@ -913,11 +914,11 @@ def get_location() -> tuple[str, int]:
class UPat(OpMixin):
__slots__ = ("op", "dtype", "arg", "name", "src", "is_any")
def __init__(self, op:OpType|tuple[OpType, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False):
assert op is None or isinstance(op, (OpType, tuple, set)), "op must be Ops or tuple of Ops"
self.op: tuple[OpType, ...]|None = (op,) if isinstance(op, OpType) else (tuple(op) if isinstance(op, set) else op)
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype)
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
@ -1041,7 +1042,7 @@ class PatternMatcher:
# if this comes from a pickle, we reconstruct the lambda functions here
self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: dict[OpType, list[list]] = {}
self.pdict: dict[Ops, list[list]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
@ -1335,7 +1336,7 @@ pm_unbind = PatternMatcher([(UPat(Ops.BIND, name="x"), do_unbind)])
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
# comparison operators are not in here because they are chained in python, not left-associative
precedence:dict[OpType, int] = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
precedence = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
if x.op not in precedence: return code_for_op(left, right)
return code_for_op(strip_parens(left) if precedence.get(x.src[0].op,99)<=precedence[x.op] else left, strip_parens(right) if