remove CUSTOM_KERNEL / directly construct it (#14604)

* remove CUSTOM_KERNEL / directly construct it

* clean that up

* simpler multi

* custom kernel spec

* remove Kernel

* fix multi

* use sharded shape

* explicit regression test
This commit is contained in:
George Hotz 2026-02-08 18:43:33 +08:00 committed by GitHub
commit 183d38b128
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 60 additions and 80 deletions

View file

@ -28,7 +28,7 @@ def custom_matmul(output: UOp, inp: UOp, weight: UOp) -> UOp:
return store_op.sink(arg=KernelInfo(name=f"fp8_matmul_{inp.shape}x{weight.shape}"))
def custom_matmul_backward(gradient: UOp, kernel: UOp) -> tuple[UOp, UOp]:
_, input_uop, weight_uop = kernel.src
_, input_uop, weight_uop = kernel.src[1:]
input_tensor = Tensor(input_uop, device=input_uop.device)
grad_tensor = Tensor(gradient, device=gradient.device)
weight_tensor = Tensor(weight_uop, device=weight_uop.device)

View file

@ -62,7 +62,7 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp):
out, a, b = kernel.src
out, a, b = kernel.src[1:]
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop

View file

@ -103,9 +103,9 @@ class Attention:
def fa_custom_backward(out_q:UOp, out_k:UOp, out_v:UOp, grad:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
return UOp.sink(arg=KernelInfo(name="fa_custom_backward"))
def fa_backward(grad:UOp, kernel:UOp) -> tuple[None, UOp, UOp, UOp]:
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[1]))
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[2]))
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[3]))
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[2]))
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[3]))
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[4]))
ck = Tensor.custom_kernel(grad_q, grad_k, grad_v, Tensor(grad), q, k, v, fxn=fa_custom_backward)[:3]
return (None, ck[0].uop, ck[1].uop, ck[2].uop)
attn = Tensor.empty_like(attn).custom_kernel(xq, keys, values, fxn=fa_custom_forward, grad_fxn=fa_backward)[0]

View file

@ -88,13 +88,13 @@ def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp:
# **** backward callbacks ****
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
out, a, b = kernel.src
out, a, b = kernel.src[1:]
grad_a = (Tensor(gradient) @ Tensor(b).T).uop
grad_b = (Tensor(a).T @ Tensor(gradient)).uop
return (None, grad_a, grad_b)
def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
out, a, b = kernel.src
out, a, b = kernel.src[1:]
grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
return (None, grad_a, grad_b)
@ -128,6 +128,14 @@ class TestCustomKernel(unittest.TestCase):
out = c.flatten().tolist()
assert all(x == 2 for x in out), "all 2"
def test_sharded_add_one(self):
# regression test for a wrong multi shape: the PYTHON backend explicitly checks for OOB access
devs = ("PYTHON:0", "PYTHON:1")
a = Tensor.ones(4, 4).contiguous().shard(devs, axis=0)
c = Tensor(Tensor.empty(2, 4, device=devs).uop.multi(0), device=devs)
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
assert (c == 2).all().item()
def test_multioutput(self):
a = Tensor.full((16, 16), 3.).contiguous()
b = Tensor.full((16, 16), 3.).contiguous()

View file

@ -10,7 +10,7 @@ from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import CompiledRunner, run_schedule
@ -677,17 +677,6 @@ class TestSchedule(unittest.TestCase):
c = (a.sum(2).contiguous() + b).contiguous()
check_schedule(c, 2)
# TODO: this requires supporting multiple stores in the AST
@unittest.expectedFailure
def test_multioutput_ast(self):
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
c = Tensor.arange(4).realize().uop
kernel = UOp(Ops.CALL, src=(a.base, b.base, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
self.assertEqual(a.buffer.numpy(), [7])
self.assertEqual(b.buffer.numpy(), [12])
@unittest.skip("no longer supported")
def test_double_from(self):
x = Tensor([1,2,3,4])

View file

@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, CI
from tinygrad.helpers import getenv
from extra.gemm.asm.cdna.gemm import asm_gemm
from test.helpers import needs_second_gpu
@ -31,7 +31,8 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
SCALE = 128 if CI else 1
# 128x smaller than usual
SCALE = 128
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
class TestGemm(unittest.TestCase):

View file

@ -52,7 +52,6 @@ pm_gradient = PatternMatcher([
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# gradient on CALL: use provided grad_fxn or auto-differentiate
(UPat(Ops.CALL, name="k"), call_gradient),
# there's no gradient for bitcast

View file

@ -2,6 +2,7 @@ from typing import cast
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
from tinygrad.dtype import dtypes
from tinygrad.device import Device
# *** allreduce implementation ***
@ -214,13 +215,14 @@ multi_pm = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
# AFTER on a CALL: unwrap the MULTI from the first source, then re-wrap the AFTER with the same axis
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
])+replace_allreduce
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:

View file

@ -74,10 +74,6 @@ mop_cleanup = PatternMatcher([
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return ck.arg.fxn(*placeholders).call(*ck.src)
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
@ -99,9 +95,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@ -538,7 +531,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
if x.dtype.scalar() == dtypes.index: return None
ctx[0].append(x)
return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([
add_tags = pm_gate_kernel_sink+PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),

View file

@ -234,8 +234,7 @@ class Tensor(OpMixin):
def as_param(self, slot:int):
if self.uop.axis is not None:
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
@ -756,8 +755,7 @@ class Tensor(OpMixin):
dtype = kwargs.pop("dtype", self.dtype)
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:

View file

@ -80,7 +80,6 @@ class Ops(FastEnum):
# tensor graph ops
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
CUSTOM_KERNEL = auto()
# local unique
LUNIQUE = auto()

View file

@ -207,7 +207,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None
@ -798,6 +798,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop high level syntactic sugar ***
@property
def shard_shape(self):
if self.axis is None: return self.shape
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
@ -806,7 +811,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.shape, self.dtype, slot)
return UOp.placeholder(self.shard_shape, self.dtype, slot)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
@ -824,7 +829,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
return [s.after(kernel) for s in contig_srcs]
@dataclass(frozen=True)
@ -838,14 +844,6 @@ class KernelInfo:
@property
def function_name(self): return to_function_name(self.name)
@dataclass(frozen=True)
class CustomKernel:
fxn: Callable
grad_fxn: Callable|None = None
# sadly CustomKernel can't be pickled or reconstructed as a str
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class CallInfo:
grad_fxn: Callable|None = None
@ -854,12 +852,6 @@ class CallInfo:
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
grad_fxn: Callable|None = None
# ******** ops in python ********
def safe_exp2(x):
@ -1436,7 +1428,7 @@ def pyrender(ast:UOp) -> str:
for s in u.src: to_render.add(s)
if u.op is Ops.STORE: to_render.add(u.src[1])
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
if u.op in not_rendered: continue
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue

View file

@ -1,6 +1,6 @@
import math
from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
@ -73,9 +73,6 @@ movement_ops = PatternMatcher([
# AFTER on Movement Op
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
# custom kernels allowed here
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
])
_tensor_spec = PatternMatcher([
@ -139,12 +136,19 @@ _tensor_spec = PatternMatcher([
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),
])+movement_ops+shared_spec
tensor_spec = PatternMatcher([
# no tags allowed in tensor graph
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
])+_tensor_spec
# ** for custom kernels **
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
])+movement_ops+shared_spec
# ***** UOp spec in codegen shared between kernel and program *****
@ -204,6 +208,11 @@ kernel_spec = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
])+movement_ops+shared_codegen_spec+shared_spec
tensor_spec = PatternMatcher([
# no tags allowed in tensor graph
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
])+_tensor_spec+kernel_spec
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
@ -264,16 +273,6 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),
@ -301,8 +300,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
# late imports to avoid circular import
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.schedule.rangeify import BufferizeOpts
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
"ConstFloat": ConstFloat}
def eval_pyrender(code:str) -> UOp:

View file

@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.CUSTOM_KERNEL: "#3ebf55",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",