mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fixup op order (#13128)
* fixup op order * more order * move a few more * more * DEBUG_LINEARIZE
This commit is contained in:
parent
b9b68bf437
commit
07b415e831
5 changed files with 80 additions and 52 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import heapq
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv
|
||||
|
||||
def linearize(u:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
|
|
@ -43,6 +43,10 @@ def linearize(u:UOp) -> list[UOp]:
|
|||
in_degree[v] -= 1
|
||||
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
|
||||
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
|
||||
|
||||
if getenv("DEBUG_LINEARIZE"):
|
||||
for i,u in enumerate(newlst):
|
||||
print(f"{i:4d} {str(u.op):20s} {multirange_str(u.ranges, color=True, pad=10)} {priorities[u]}")
|
||||
return newlst
|
||||
|
||||
class CFGContext:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# flake8: noqa: E702
|
||||
# allow semicolons to put multiple ops on one line
|
||||
from enum import auto, IntEnum, Enum
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
|
|
@ -9,9 +11,21 @@ class FastEnum(IntEnum):
|
|||
|
||||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# ** 1 -- defines/special **
|
||||
|
||||
# TODO: unify these ops into the levels of the memory hierarchy
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto()
|
||||
|
||||
# this is for symbolic shapes
|
||||
DEFINE_VAR = auto(); BIND = auto()
|
||||
|
||||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
||||
# ** 2 -- non op uops **
|
||||
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
|
||||
SENTINEL = auto()
|
||||
NOOP = auto(); SINK = auto(); PRECAST = auto()
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
AFTER = auto()
|
||||
|
|
@ -19,64 +33,70 @@ class Ops(FastEnum):
|
|||
# GROUP is a NOOP that just merges things together
|
||||
GROUP = auto()
|
||||
|
||||
# buffer ops
|
||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
# vector creation / item selection
|
||||
GEP = auto(); VECTORIZE = auto()
|
||||
|
||||
# create buffer
|
||||
BUFFERIZE = auto()
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto() # noqa: E702
|
||||
|
||||
# movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||
MULTI = auto() # MULTI is really a movement op
|
||||
|
||||
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
|
||||
|
||||
# this is for symbolic shapes
|
||||
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
|
||||
|
||||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||
|
||||
# optimization helper ops
|
||||
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto() # noqa: E702
|
||||
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
|
||||
|
||||
# tensor core math op, not elementwise
|
||||
WMMA = auto()
|
||||
# ** 3 -- load/store **
|
||||
|
||||
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
|
||||
INDEX = auto()
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto()
|
||||
|
||||
# ** 4 -- math **
|
||||
|
||||
# tensor core math op, not elementwise
|
||||
WMMA = auto()
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto()
|
||||
SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
|
||||
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
|
||||
XOR = auto(); OR = auto(); AND = auto() # noqa: E702
|
||||
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto()
|
||||
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto()
|
||||
XOR = auto(); OR = auto(); AND = auto()
|
||||
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto()
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
WHERE = auto(); MULACC = auto()
|
||||
|
||||
# ** 5 -- control flow / consts / custom **
|
||||
|
||||
# control flow ops
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto()
|
||||
|
||||
# consts. VCONST is a vectorized const
|
||||
VCONST = auto(); CONST = auto() # noqa: E702
|
||||
VCONST = auto(); CONST = auto()
|
||||
|
||||
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||
CUSTOM = auto(); CUSTOMI = auto()
|
||||
|
||||
# ** 6 -- ops that don't exist in programs **
|
||||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto()
|
||||
ASSIGN = auto()
|
||||
|
||||
# buffer ops
|
||||
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto()
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
|
||||
|
||||
# movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
MULTI = auto() # MULTI is really a movement op
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
|
||||
# errors/placeholders
|
||||
REWRITE_ERROR = auto(); SENTINEL = auto()
|
||||
|
||||
# expander ops
|
||||
UNROLL = auto(); CONTRACT = auto(); CAT = auto(); PTRCAT = auto()
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,11 @@ def range_str(u:UOp, color=False) -> str:
|
|||
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
|
||||
return colored(ret, axis_colors[u.arg[-1]]) if color else ret
|
||||
|
||||
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
|
||||
if pad is not None: ret += " " * (pad-ansilen(ret))
|
||||
return ret
|
||||
|
||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
for u in lst:
|
||||
|
|
@ -853,8 +858,7 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
|||
def print_uops(uops:list[UOp]):
|
||||
for i,u in enumerate(uops):
|
||||
formatted_srcs = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
||||
formatted_range = ','.join([range_str(r, color=True) for r in sorted(u.ranges, key=lambda x: x.arg)])
|
||||
print(f"{i:4d} {str(u.op):20s}: {(formatted_range)+' '*(10-ansilen(formatted_range))} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
|
||||
|
||||
# ***** pattern matcher *****
|
||||
|
||||
|
|
|
|||
|
|
@ -134,10 +134,6 @@ shared_codegen_spec = PatternMatcher([
|
|||
# WMMA has a <a, b, acc>
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# VECTORIZE/GEP
|
||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
|
@ -166,6 +162,10 @@ kernel_spec = PatternMatcher([
|
|||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# UNROLL/CONTRACT is used here for WMMA
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# END can end multiple axes here
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from urllib.parse import parse_qs, urlparse
|
|||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
||||
from tinygrad.uop.ops import print_uops, range_start
|
||||
from tinygrad.uop.ops import print_uops, range_start, multirange_str
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
|
@ -78,7 +78,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
|
||||
try:
|
||||
if len(rngs:=u.ranges):
|
||||
label += f"\n({','.join([range_str(x, color=True) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
|
||||
label += f"\n({multirange_str(rngs, color=True)})"
|
||||
if u.op not in {Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u._shape is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue