mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove CUSTOM_KERNEL / directly construct it (#14604)
* remove CUSTOM_KERNEL / directly construct it
* clean that up
* simpler multi
* custom kernel spec
* remove Kernel
* fix multi
* use sharded shape
* explicit regression test
This commit is contained in:
parent
e29a88ca09
commit
183d38b128
14 changed files with 60 additions and 80 deletions
|
|
@ -28,7 +28,7 @@ def custom_matmul(output: UOp, inp: UOp, weight: UOp) -> UOp:
|
|||
return store_op.sink(arg=KernelInfo(name=f"fp8_matmul_{inp.shape}x{weight.shape}"))
|
||||
|
||||
def custom_matmul_backward(gradient: UOp, kernel: UOp) -> tuple[UOp, UOp]:
|
||||
_, input_uop, weight_uop = kernel.src
|
||||
_, input_uop, weight_uop = kernel.src[1:]
|
||||
input_tensor = Tensor(input_uop, device=input_uop.device)
|
||||
grad_tensor = Tensor(gradient, device=gradient.device)
|
||||
weight_tensor = Tensor(weight_uop, device=weight_uop.device)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
|||
# ** backward gemm, might use the asm gemm
|
||||
|
||||
def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
assert all_same([gradient.device, a.device, b.device, out.device])
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
grad_a = (g_t @ b_t.T).uop
|
||||
|
|
|
|||
|
|
@ -103,9 +103,9 @@ class Attention:
|
|||
def fa_custom_backward(out_q:UOp, out_k:UOp, out_v:UOp, grad:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
||||
return UOp.sink(arg=KernelInfo(name="fa_custom_backward"))
|
||||
def fa_backward(grad:UOp, kernel:UOp) -> tuple[None, UOp, UOp, UOp]:
|
||||
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[1]))
|
||||
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[2]))
|
||||
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[3]))
|
||||
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[2]))
|
||||
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[3]))
|
||||
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[4]))
|
||||
ck = Tensor.custom_kernel(grad_q, grad_k, grad_v, Tensor(grad), q, k, v, fxn=fa_custom_backward)[:3]
|
||||
return (None, ck[0].uop, ck[1].uop, ck[2].uop)
|
||||
attn = Tensor.empty_like(attn).custom_kernel(xq, keys, values, fxn=fa_custom_forward, grad_fxn=fa_backward)[0]
|
||||
|
|
|
|||
|
|
@ -88,13 +88,13 @@ def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp:
|
|||
# **** backward callbacks ****
|
||||
|
||||
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
grad_a = (Tensor(gradient) @ Tensor(b).T).uop
|
||||
grad_b = (Tensor(a).T @ Tensor(gradient)).uop
|
||||
return (None, grad_a, grad_b)
|
||||
|
||||
def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
|
||||
grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
|
||||
return (None, grad_a, grad_b)
|
||||
|
|
@ -128,6 +128,14 @@ class TestCustomKernel(unittest.TestCase):
|
|||
out = c.flatten().tolist()
|
||||
assert all(x == 2 for x in out), "all 2"
|
||||
|
||||
def test_sharded_add_one(self):
|
||||
# PYTHON backend explicitly checks for OOB access for wrong multi shape regression
|
||||
devs = ("PYTHON:0", "PYTHON:1")
|
||||
a = Tensor.ones(4, 4).contiguous().shard(devs, axis=0)
|
||||
c = Tensor(Tensor.empty(2, 4, device=devs).uop.multi(0), device=devs)
|
||||
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
|
||||
assert (c == 2).all().item()
|
||||
|
||||
def test_multioutput(self):
|
||||
a = Tensor.full((16, 16), 3.).contiguous()
|
||||
b = Tensor.full((16, 16), 3.).contiguous()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from hypothesis import assume, given, settings, strategies as strat
|
|||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
|
|
@ -677,17 +677,6 @@ class TestSchedule(unittest.TestCase):
|
|||
c = (a.sum(2).contiguous() + b).contiguous()
|
||||
check_schedule(c, 2)
|
||||
|
||||
# TODO: this requires supporting multiple stores in the AST
|
||||
@unittest.expectedFailure
|
||||
def test_multioutput_ast(self):
|
||||
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
|
||||
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
|
||||
c = Tensor.arange(4).realize().uop
|
||||
kernel = UOp(Ops.CALL, src=(a.base, b.base, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
|
||||
run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
|
||||
self.assertEqual(a.buffer.numpy(), [7])
|
||||
self.assertEqual(b.buffer.numpy(), [12])
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_double_from(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.gemm.asm.cdna.gemm import asm_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
|
|
@ -31,7 +31,8 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i
|
|||
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
|
||||
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
|
||||
|
||||
SCALE = 128 if CI else 1
|
||||
# 128x smaller than usual
|
||||
SCALE = 128
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
class TestGemm(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# gradient on CALL: use provided grad_fxn or auto-differentiate
|
||||
(UPat(Ops.CALL, name="k"), call_gradient),
|
||||
# there's no gradient for bitcast
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import cast
|
|||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device
|
||||
|
||||
# *** allreduce implementation ***
|
||||
|
|
@ -214,13 +215,14 @@ multi_pm = PatternMatcher([
|
|||
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
|
||||
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
|
||||
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
|
||||
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
|
||||
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
|
||||
# after CALL
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
|
||||
])+replace_allreduce
|
||||
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
|
|
|
|||
|
|
@ -74,10 +74,6 @@ mop_cleanup = PatternMatcher([
|
|||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
])
|
||||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return ck.arg.fxn(*placeholders).call(*ck.src)
|
||||
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
|
|
@ -99,9 +95,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
|
|
@ -538,7 +531,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
|||
if x.dtype.scalar() == dtypes.index: return None
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = PatternMatcher([
|
||||
add_tags = pm_gate_kernel_sink+PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
|
|
|
|||
|
|
@ -234,8 +234,7 @@ class Tensor(OpMixin):
|
|||
|
||||
def as_param(self, slot:int):
|
||||
if self.uop.axis is not None:
|
||||
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
|
||||
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
|
|
@ -756,8 +755,7 @@ class Tensor(OpMixin):
|
|||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
|
||||
sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
|
||||
|
||||
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ class Ops(FastEnum):
|
|||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
|
||||
CUSTOM_KERNEL = auto()
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
return None
|
||||
|
||||
|
|
@ -798,6 +798,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@property
|
||||
def shard_shape(self):
|
||||
if self.axis is None: return self.shape
|
||||
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
|
||||
|
||||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
|
|
@ -806,7 +811,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||
return UOp.placeholder(self.shape, self.dtype, slot)
|
||||
return UOp.placeholder(self.shard_shape, self.dtype, slot)
|
||||
|
||||
# set is store+end+after
|
||||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||
|
|
@ -824,7 +829,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
|
||||
return [s.after(kernel) for s in contig_srcs]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -838,14 +844,6 @@ class KernelInfo:
|
|||
@property
|
||||
def function_name(self): return to_function_name(self.name)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CustomKernel:
|
||||
fxn: Callable
|
||||
grad_fxn: Callable|None = None
|
||||
# sadly CustomKernel can't be pickled or reconstructed as a str
|
||||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallInfo:
|
||||
grad_fxn: Callable|None = None
|
||||
|
|
@ -854,12 +852,6 @@ class CallInfo:
|
|||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
grad_fxn: Callable|None = None
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
def safe_exp2(x):
|
||||
|
|
@ -1436,7 +1428,7 @@ def pyrender(ast:UOp) -> str:
|
|||
for s in u.src: to_render.add(s)
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
|
||||
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
|
||||
if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
|
||||
if u.op in not_rendered: continue
|
||||
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
|
||||
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
from typing import cast, Any
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
|
||||
|
|
@ -73,9 +73,6 @@ movement_ops = PatternMatcher([
|
|||
|
||||
# AFTER on Movement Op
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
|
||||
|
||||
# custom kernels allowed here
|
||||
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
|
||||
])
|
||||
|
||||
_tensor_spec = PatternMatcher([
|
||||
|
|
@ -139,12 +136,19 @@ _tensor_spec = PatternMatcher([
|
|||
# allow CALL/PARAM
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
# no tags allowed in tensor graph
|
||||
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
|
||||
])+_tensor_spec
|
||||
# ** for custom kernels **
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
# ***** UOp spec in codegen shared between kernel and program *****
|
||||
|
||||
|
|
@ -204,6 +208,11 @@ kernel_spec = PatternMatcher([
|
|||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
|
||||
])+movement_ops+shared_codegen_spec+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
# no tags allowed in tensor graph
|
||||
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
|
||||
])+_tensor_spec+kernel_spec
|
||||
|
||||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
|
|
@ -264,16 +273,6 @@ full_spec = PatternMatcher([
|
|||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
|
@ -301,8 +300,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
|
|||
# late imports to avoid circular import
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.schedule.rangeify import BufferizeOpts
|
||||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
|
||||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
|
||||
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
|
||||
"ConstFloat": ConstFloat}
|
||||
def eval_pyrender(code:str) -> UOp:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
|||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue