fixup op order (#13128)

* fixup op order

* more order

* move a few more

* more

* DEBUG_LINEARIZE
This commit is contained in:
George Hotz 2025-11-06 08:50:04 -08:00 committed by GitHub
commit 07b415e831
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 80 additions and 52 deletions

View file

@ -1,7 +1,7 @@
import heapq
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
from tinygrad.helpers import prod
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.helpers import prod, getenv
def linearize(u:UOp) -> list[UOp]:
# this is a toposort with priority
@ -43,6 +43,10 @@ def linearize(u:UOp) -> list[UOp]:
in_degree[v] -= 1
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
if getenv("DEBUG_LINEARIZE"):
for i,u in enumerate(newlst):
print(f"{i:4d} {str(u.op):20s} {multirange_str(u.ranges, color=True, pad=10)} {priorities[u]}")
return newlst
class CFGContext:

View file

@ -1,3 +1,5 @@
# flake8: noqa: E702
# allow semicolons to put multiple ops on one line
from enum import auto, IntEnum, Enum
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
@ -9,9 +11,21 @@ class FastEnum(IntEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# ** 1 -- defines/special **
# TODO: unify these ops into the levels of the memory hierarchy
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto()
# this is for symbolic shapes
DEFINE_VAR = auto(); BIND = auto()
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# ** 2 -- non op uops **
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
SENTINEL = auto()
NOOP = auto(); SINK = auto(); PRECAST = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
AFTER = auto()
@ -19,64 +33,70 @@ class Ops(FastEnum):
# GROUP is a NOOP that just merges things together
GROUP = auto()
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
# vector creation / item selection
GEP = auto(); VECTORIZE = auto()
# create buffer
BUFFERIZE = auto()
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto() # noqa: E702
# movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
# this is for symbolic shapes
DEFINE_VAR = auto(); BIND = auto() # noqa: E702
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
# optimization helper ops
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# load/store before math
LOAD = auto(); STORE = auto() # noqa: E702
ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
# tensor core math op, not elementwise
WMMA = auto()
# ** 3 -- load/store **
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
# load/store before math
LOAD = auto(); STORE = auto()
# ** 4 -- math **
# tensor core math op, not elementwise
WMMA = auto()
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto()
SQRT = auto(); RECIPROCAL = auto(); NEG = auto(); TRUNC = auto()
# BinaryOps
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
XOR = auto(); OR = auto(); AND = auto() # noqa: E702
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto()
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto()
XOR = auto(); OR = auto(); AND = auto()
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto()
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
WHERE = auto(); MULACC = auto()
# ** 5 -- control flow / consts / custom **
# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto()
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702
VCONST = auto(); CONST = auto()
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
CUSTOM = auto(); CUSTOMI = auto()
# ** 6 -- ops that don't exist in programs **
# tensor graph ops
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto()
ASSIGN = auto()
# buffer ops
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto()
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
# movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
MULTI = auto() # MULTI is really a movement op
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
# errors/placeholders
REWRITE_ERROR = auto(); SENTINEL = auto()
# expander ops
UNROLL = auto(); CONTRACT = auto(); CAT = auto(); PTRCAT = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}

View file

@ -48,6 +48,11 @@ def range_str(u:UOp, color=False) -> str:
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
return colored(ret, axis_colors[u.arg[-1]]) if color else ret
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
if pad is not None: ret += " " * (pad-ansilen(ret))
return ret
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
for u in lst:
@ -853,8 +858,7 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
def print_uops(uops:list[UOp]):
for i,u in enumerate(uops):
formatted_srcs = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
formatted_range = ','.join([range_str(r, color=True) for r in sorted(u.ranges, key=lambda x: x.arg)])
print(f"{i:4d} {str(u.op):20s}: {(formatted_range)+' '*(10-ansilen(formatted_range))} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
# ***** pattern matcher *****

View file

@ -134,10 +134,6 @@ shared_codegen_spec = PatternMatcher([
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# VECTORIZE/GEP
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
@ -166,6 +162,10 @@ kernel_spec = PatternMatcher([
# index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),

View file

@ -8,7 +8,7 @@ from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
@ -78,7 +78,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
try:
if len(rngs:=u.ranges):
label += f"\n({','.join([range_str(x, color=True) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
label += f"\n({multirange_str(rngs, color=True)})"
if u.op not in {Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u._shape is not None:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}: