Compare commits

...

6 commits

Author SHA1 Message Date
George Hotz
7ab5b5b865 outer 2025-11-14 16:22:18 -08:00
George Hotz
3b123b4d53 outer fold still broken 2025-11-14 16:16:02 -08:00
George Hotz
80f1ff18ad fold works 2025-11-14 15:27:05 -08:00
George Hotz
09c02868a5 works 2025-11-14 15:18:50 -08:00
George Hotz
9d638a0202
Merge branch 'master' into tiny_scan 2025-11-14 14:15:10 -08:00
George Hotz
b815c6a3df start work on SCAN op 2025-11-14 12:03:41 -08:00
10 changed files with 92 additions and 16 deletions

31
test/test_fold.py Normal file
View file

@ -0,0 +1,31 @@
import numpy as np
import unittest
from tinygrad import Tensor, UOp
from tinygrad.uop.ops import AxisType
class TestFold(unittest.TestCase):
def test_reduce_add(self):
a = Tensor.randn(10, 10).realize()
a_red = a.sum(axis=1)
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
def test_fold_add(self):
a = Tensor.randn(10, 10).realize()
init = Tensor.zeros(10, 1).contiguous()
a_red = (init+a).fold(init).reshape(10)
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
#@unittest.skip("no outer fold yet")
def test_fold_matmul(self):
vec = Tensor.randn(1, 10).realize()
mats = Tensor.randn(3, 10, 10).realize()
np_mats = mats.numpy()
np_ref = ((vec.numpy() @ np_mats[0]) @ np_mats[1]) @ np_mats[2]
i = UOp.range(3, -1, AxisType.OUTER)
out = (vec @ mats[i]).contiguous().fold(vec, i)
np.testing.assert_allclose(out.numpy(), np_ref, atol=1e-6)
if __name__ == '__main__':
unittest.main()

View file

@ -308,9 +308,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
if len(reduce_range) == 0: return ret if len(reduce_range) == 0: return ret
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)) return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
def fold_to_store(x:UOp):
_, acc, ranges = x.src[0], x.src[1], x.src[2:]
assert acc.op is Ops.INDEX
buf = acc.src[0]
ret = x.substitute({buf: buf.rtag().after(*ranges)}).substitute({buf.rtag(): buf})
base, acc, ranges = ret.src[0], ret.src[1], ret.src[2:]
return buf.after(acc.store(base).end(*ranges)).index(acc.src[1])
pm_reduce = PatternMatcher([ pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN # REDUCE -> DEFINE_ACC+STORE
(UPat(Ops.REDUCE, name="red"), reduce_to_acc), (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# FOLD -> STORE
(UPat(Ops.FOLD, name="x"), fold_to_store),
# tensor core built in accumulate # tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"), (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),

View file

@ -2,7 +2,8 @@ from __future__ import annotations
import math, itertools import math, itertools
from collections import defaultdict from collections import defaultdict
from typing import cast, Final from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler: class Scheduler:
def __init__(self, ast:UOp, ren:Renderer): def __init__(self, ast:UOp, ren:Renderer):
self.ast, self.ren = ast, ren self.ast, self.ren = ast, ren

View file

@ -54,7 +54,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
new_srcs = [] new_srcs = []
for s in x.src: for i,s in enumerate(x.src):
new_src = s new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
@ -65,7 +65,9 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
# None in the device assigns it a number later # None in the device assigns it a number later
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) if x in ctx.range_map:
# for scan we use the output ranges on the 2nd arg
new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][int(x.op is Ops.FOLD and i == 1)]) if i in realized_ranges])
new_srcs.append(new_src) new_srcs.append(new_src)
# NOTE: do we need this? # NOTE: do we need this?
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
@ -84,6 +86,13 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
ctx.range_map[ret] = ctx.range_map[x] ctx.range_map[ret] = ctx.range_map[x]
return ret return ret
def add_ranges_to_scan(ctx:IndexingContext, x:UOp):
if x not in ctx.range_map: return None
new_ranges = [r for r,ar in zip(*ctx.range_map[x]) if r is not ar and r not in x.src]
ret = x.replace(src=x.src+tuple(new_ranges))
ctx.range_map[ret] = ctx.range_map[x]
return ret
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp): def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0] if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
@ -97,6 +106,8 @@ def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
pm_apply_rangeify = PatternMatcher([ pm_apply_rangeify = PatternMatcher([
# REDUCE_AXIS -> REDUCE # REDUCE_AXIS -> REDUCE
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges), (UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
# SCAN -> SCAN (with new ranges)
(UPat(Ops.FOLD, name="x"), add_ranges_to_scan),
# PAD -> WHERE # PAD -> WHERE
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local), (UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
# add third op to assign # add third op to assign
@ -244,6 +255,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# REDUCE_AXIS creates ranges for the axes it is reducing # REDUCE_AXIS creates ranges for the axes it is reducing
if x.op is Ops.REDUCE_AXIS: if x.op is Ops.REDUCE_AXIS:
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape))) rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
if x.op is Ops.FOLD:
rngs = tuple(rctx.new_range(s, axistype=AxisType.FOLD) if resolve(x.src[1].shape[i] == 1) else r \
for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
if debug: if debug:
realized_ranges = rctx.realize_map.get(x, None) realized_ranges = rctx.realize_map.get(x, None)

View file

@ -396,6 +396,7 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
return buf return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp): def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.arg[-1] == AxisType.OUTER: return None
if r.tag != (): return None if r.tag != (): return None
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None) ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
ctx.range += 1 ctx.range += 1
@ -469,7 +470,10 @@ pm_add_range_tags = PatternMatcher([
]) ])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None: def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len(x.ranges): return None if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
# local kernel rewrite # local kernel rewrite
lctx = LocalAddBufferContext() lctx = LocalAddBufferContext()

View file

@ -1509,6 +1509,9 @@ class Tensor(OpMixin):
# ***** reduce ops ***** # ***** reduce ops *****
def fold(self, init:Tensor, *ranges:UOp) -> Tensor:
return self._apply_uop(UOp.fold, init, extra_args=ranges)
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
if self.ndim == 0: axis = () if self.ndim == 0: axis = ()

View file

@ -92,6 +92,9 @@ class Ops(FastEnum):
# reduce # reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
# scan
FOLD = auto()
# errors/placeholders # errors/placeholders
REWRITE_ERROR = auto(); SENTINEL = auto() REWRITE_ERROR = auto(); SENTINEL = auto()

View file

@ -13,14 +13,20 @@ if TYPE_CHECKING:
class AxisType(Enum): class AxisType(Enum):
def __repr__(self): return str(self) def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); FOLD = auto() # noqa: E702
THREAD = auto() UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); OUTER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.FOLD: "F", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.FOLD: "red", AxisType.UNROLL: "magenta",
AxisType.OUTER: "green"}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1} # NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.FOLD: 4, AxisType.REDUCE: 5, AxisType.UNROLL: 6, AxisType.OUTER: -2}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.FOLD: 2}
# https://en.wikipedia.org/wiki/Identity_element # https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@ -213,6 +219,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]]) case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# shape of init
case Ops.FOLD:
return self.src[1]._shape
# passthrough ops # passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape return self.src[0]._shape
@ -437,6 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype" assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid) return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def fold(self, *src:UOp, **kwargs): return UOp(Ops.FOLD, self.dtype, (self,)+src, **kwargs)
def is_contiguous(self): def is_contiguous(self):
# TODO: this is is_realized # TODO: this is is_realized

View file

@ -40,6 +40,9 @@ shared_spec = PatternMatcher([
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
# FOLD, 2 ops + ranges
(UPat(Ops.FOLD, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[2:])),
]) ])
# ***** UOp spec in the Tensor graph ***** # ***** UOp spec in the Tensor graph *****
@ -172,7 +175,7 @@ kernel_spec = PatternMatcher([
# bufferize can be on anything # bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True), (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
# reduce must be on ranges # reduce/fold must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
])+movement_ops+shared_codegen_spec+shared_spec ])+movement_ops+shared_codegen_spec+shared_spec

View file

@ -17,7 +17,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.FOLD: "#FF7B7B",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",