mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
param_call
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e336f3cf8c |
10 changed files with 43 additions and 31 deletions
|
|
@ -94,13 +94,13 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||||
if (view := _make_buffer_view(src)) is None: return None
|
if (view := _make_buffer_view(src)) is None: return None
|
||||||
return view.contiguous(tag=c.tag)
|
return view.contiguous(tag=c.tag)
|
||||||
|
|
||||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
def transform_precompiled_function(c:UOp) -> UOp|None:
|
||||||
if not c.arg.precompile: return None
|
if not c.arg.precompile: return None
|
||||||
if c.src[0].op is Ops.SINK: return None
|
if c.src[0].op is Ops.SINK: return None
|
||||||
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}"
|
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled function, got {c.src[0].op}"
|
||||||
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
|
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
|
||||||
|
|
||||||
# add the outputs to the call
|
# add the outputs to the function
|
||||||
srcs = c.src[0].src
|
srcs = c.src[0].src
|
||||||
resolved = [c.gettuple(i) for i in range(len(srcs))]
|
resolved = [c.gettuple(i) for i in range(len(srcs))]
|
||||||
outs = tuple(_buffer_like(r) for r in resolved)
|
outs = tuple(_buffer_like(r) for r in resolved)
|
||||||
|
|
@ -108,7 +108,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||||
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
|
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
|
||||||
|
|
||||||
# create the new thing for the big graph
|
# create the new thing for the big graph
|
||||||
new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None)
|
new_call = c.replace(op=Ops.CALL, src=(fxn, *input_buffers, *outs), tag=None)
|
||||||
rets = tuple(o.after(new_call) for o in outs)
|
rets = tuple(o.after(new_call) for o in outs)
|
||||||
|
|
||||||
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
|
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
|
||||||
|
|
@ -119,8 +119,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||||
|
|
||||||
# NOTE: adding rules to here is bad. these all need to run before the schedule cache
|
# NOTE: adding rules to here is bad. these all need to run before the schedule cache
|
||||||
pm_early_transform_tensor_graph = PatternMatcher([
|
pm_early_transform_tensor_graph = PatternMatcher([
|
||||||
# transform precompiled CALLs
|
# transform precompiled FUNCTIONs -> CALLs
|
||||||
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
(UPat(Ops.FUNCTION, name="c"), transform_precompiled_function),
|
||||||
|
|
||||||
# resolve TUPLE+GETTUPLE (for precompiled calls)
|
# resolve TUPLE+GETTUPLE (for precompiled calls)
|
||||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||||
|
|
|
||||||
|
|
@ -95,15 +95,15 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
||||||
if t0 not in grads or grads[t0].op is Ops.NOOP: continue
|
if t0 not in grads or grads[t0].op is Ops.NOOP: continue
|
||||||
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
|
# GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION
|
||||||
if t0.op is Ops.GETTUPLE:
|
if t0.op is Ops.GETTUPLE:
|
||||||
k = t0.src[0] # the CALL
|
k = t0.src[0] # the FUNCTION
|
||||||
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
|
assert k.op is Ops.FUNCTION and k.src[0].op is Ops.TUPLE
|
||||||
n_outputs = len(k.src[0].src)
|
n_outputs = len(k.src[0].src)
|
||||||
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
|
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
|
||||||
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
|
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
|
||||||
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
|
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
|
||||||
continue
|
continue
|
||||||
# CALL: pass needed param set so backward only computes required gradients
|
# FUNCTION: pass needed param set so backward only computes required gradients
|
||||||
if t0.op is Ops.CALL:
|
if t0.op is Ops.FUNCTION:
|
||||||
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
|
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
|
||||||
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
|
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST
|
from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST
|
||||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
|
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, KernelInfo
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.schedule.allreduce import handle_allreduce
|
from tinygrad.schedule.allreduce import handle_allreduce
|
||||||
|
|
||||||
|
|
@ -116,6 +116,15 @@ def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0
|
||||||
def passthrough_multi(root:UOp, multi:UOp):
|
def passthrough_multi(root:UOp, multi:UOp):
|
||||||
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
||||||
|
|
||||||
|
# TODO: this is all junk
|
||||||
|
|
||||||
|
def should_resolve_call(c:UOp) -> bool:
|
||||||
|
# don't resolve real kernel calls, sink or program
|
||||||
|
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
|
||||||
|
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}: return False
|
||||||
|
if c.arg.precompile: return False
|
||||||
|
return True
|
||||||
|
|
||||||
def rewrite_into_call(call:UOp):
|
def rewrite_into_call(call:UOp):
|
||||||
if not should_resolve_call(call): return None
|
if not should_resolve_call(call): return None
|
||||||
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
|
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
||||||
import itertools
|
import itertools
|
||||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
|
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element
|
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
|
||||||
from tinygrad.uop.symbolic import symbolic
|
from tinygrad.uop.symbolic import symbolic
|
||||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||||
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
||||||
|
|
@ -126,8 +126,7 @@ mop_cleanup = PatternMatcher([
|
||||||
])
|
])
|
||||||
|
|
||||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||||
def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||||
if not should_resolve_call(c): return None
|
|
||||||
params: list[UOp] = []
|
params: list[UOp] = []
|
||||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
|
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
|
||||||
params = sorted(params, key=lambda x: x.arg)
|
params = sorted(params, key=lambda x: x.arg)
|
||||||
|
|
@ -150,8 +149,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
||||||
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
||||||
|
|
||||||
# resolve calls
|
# resolve functions
|
||||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
(UPat(Ops.FUNCTION, name="c"), resolve_function),
|
||||||
|
|
||||||
# resolve TUPLE+GETTUPLE
|
# resolve TUPLE+GETTUPLE
|
||||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||||
|
|
|
||||||
|
|
@ -222,7 +222,7 @@ class Tensor(OpMixin):
|
||||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||||
return Tensor(param)
|
return Tensor(param)
|
||||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||||
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).function(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
||||||
return Tensor(fret.gettuple(0))
|
return Tensor(fret.gettuple(0))
|
||||||
|
|
||||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class Ops(FastEnum):
|
||||||
|
|
||||||
# uops that aren't rendered
|
# uops that aren't rendered
|
||||||
NOOP = auto(); REWRITE_ERROR = auto()
|
NOOP = auto(); REWRITE_ERROR = auto()
|
||||||
PARAM = auto(); CALL = auto()
|
PARAM = auto(); CALL = auto(); FUNCTION = auto()
|
||||||
|
|
||||||
# renderer
|
# renderer
|
||||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||||
|
|
|
||||||
|
|
@ -215,17 +215,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
# late ops don't have shape
|
# late ops don't have shape
|
||||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE:
|
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
case Ops.GETTUPLE:
|
case Ops.GETTUPLE:
|
||||||
# GETTUPLE extracts from a TUPLE (possibly through a CALL)
|
# GETTUPLE extracts from a TUPLE (possibly through a FUNCTION)
|
||||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
|
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
|
||||||
assert in_tuple.op is Ops.TUPLE
|
assert in_tuple.op is Ops.TUPLE
|
||||||
inner_shape = in_tuple.src[self.arg]._shape
|
inner_shape = in_tuple.src[self.arg]._shape
|
||||||
if inner_shape is None: return None
|
if inner_shape is None: return None
|
||||||
# if through a CALL, substitute internal PARAMs in the shape with corresponding args
|
# if through a FUNCTION, substitute internal PARAMs in the shape with corresponding args
|
||||||
if self.src[0].op is Ops.CALL:
|
if self.src[0].op is Ops.FUNCTION:
|
||||||
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
|
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
|
||||||
return inner_shape
|
return inner_shape
|
||||||
|
|
||||||
|
|
@ -262,8 +262,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
|
|
||||||
case Ops.CALL: return None
|
|
||||||
|
|
||||||
# TODO: disallow shape changing bitcast
|
# TODO: disallow shape changing bitcast
|
||||||
case Ops.BITCAST:
|
case Ops.BITCAST:
|
||||||
ps = self.src[0]._shape
|
ps = self.src[0]._shape
|
||||||
|
|
@ -421,8 +419,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||||
return UOp(Ops.TUPLE, dtypes.void, srcs)
|
return UOp(Ops.TUPLE, dtypes.void, srcs)
|
||||||
def gettuple(self, idx:int) -> UOp:
|
def gettuple(self, idx:int) -> UOp:
|
||||||
in_tuple = self.src[0] if self.op is Ops.CALL else self
|
in_tuple = self.src[0] if self.op is Ops.FUNCTION else self
|
||||||
assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}"
|
assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}"
|
||||||
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
|
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
|
||||||
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
|
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
|
||||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||||
|
|
@ -941,6 +939,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
# value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
|
# value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
|
||||||
body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
|
body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
|
||||||
return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||||
|
def function(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
|
||||||
|
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
|
||||||
|
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||||
|
return UOp(Ops.FUNCTION, dtypes.void, (UOp.maketuple(self),)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||||
|
|
|
||||||
|
|
@ -134,12 +134,13 @@ _tensor_spec = PatternMatcher([
|
||||||
|
|
||||||
# allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void
|
# allow CALL/FUNCTION/PARAM/CUSTOM_FUNCTION — CALL and FUNCTION dtype is always void
|
||||||
(UPat(Ops.CALL, dtypes.void), lambda: True),
|
(UPat(Ops.CALL, dtypes.void), lambda: True),
|
||||||
|
(UPat(Ops.FUNCTION, dtypes.void), lambda: True),
|
||||||
(UPat(Ops.PARAM), lambda: True),
|
(UPat(Ops.PARAM), lambda: True),
|
||||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||||
|
|
||||||
# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
|
# TUPLE must have void dtype, GETTUPLE can only appear on FUNCTION or TUPLE
|
||||||
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
|
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
|
||||||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||||
|
|
||||||
# ** for custom kernels **
|
# ** for custom kernels **
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
||||||
width = Math.max(width, ctx.measureText(line).width);
|
width = Math.max(width, ctx.measureText(line).width);
|
||||||
height += lineHeight;
|
height += lineHeight;
|
||||||
}
|
}
|
||||||
const callNode = label.startsWith("CALL\n");
|
const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n");
|
||||||
if (callNode) callCount++;
|
if (callNode) callCount++;
|
||||||
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode});
|
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode});
|
||||||
// add edges
|
// add edges
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
Ops.FUNCTION: "#C07788", Ops.CALL: "#00B7C8",
|
||||||
|
Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||||
|
|
||||||
|
|
@ -136,7 +137,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||||
label += f"\n({multirange_str(rngs, color=True)})"
|
label += f"\n({multirange_str(rngs, color=True)})"
|
||||||
if u._shape is not None:
|
if u._shape is not None:
|
||||||
label += f"\n{shape_to_str(u.shape)}"
|
label += f"\n{shape_to_str(u.shape)}"
|
||||||
if u.op is Ops.CALL:
|
if u.op in {Ops.CALL, Ops.FUNCTION}:
|
||||||
label += f"\n{u.src[0].key.hex()[:8]}"
|
label += f"\n{u.src[0].key.hex()[:8]}"
|
||||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||||
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue