Compare commits

...

6 commits

Author SHA1 Message Date
George Hotz
771a395240
Merge branch 'master' into kernel_is_call 2026-02-06 09:15:03 +08:00
George Hotz
d83ddc05c8 resolve_call 2026-02-05 12:57:31 +08:00
George Hotz
8e8cac4b0f don't use tag, use KernelInfo 2026-02-05 12:31:14 +08:00
George Hotz
57199fd9de keep the all buffers on same device check 2026-02-05 12:17:32 +08:00
George Hotz
2193d0edfa fix arg order 2026-02-05 12:03:45 +08:00
George Hotz
77adccb925 use call for kernel 2026-02-05 11:48:50 +08:00
10 changed files with 92 additions and 47 deletions

View file

@ -1,7 +1,7 @@
import time import time
from typing import cast from typing import cast
from collections import deque from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
@ -28,8 +28,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}" assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0) in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}" for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
for s in k.src[0].src if k.op is Ops.END else k.src:
match (s := _unwrap_src(s)).op: match (s := _unwrap_src(s)).op:
case Ops.AFTER: case Ops.AFTER:
children.setdefault(s.src[1], []).append(k) children.setdefault(s.src[1], []).append(k)
@ -54,14 +53,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
while len(queue): while len(queue):
k = rk = queue.popleft() k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0] if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}" assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k) if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL: elif k.op is Ops.CALL:
ast = (kernel:=cast(Kernel, k.arg)).ast ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND) buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE) bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges) schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
schedule.append(k)
if rk.op is Ops.END: schedule.append(rk) if rk.op is Ops.END: schedule.append(rk)
for x in children.get(rk, []): for x in children.get(rk, []):
in_degree[x] -= 1 in_degree[x] -= 1

View file

@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp): def call_gradient(ctx:UOp, k:UOp):
if k.arg is not None: return (None,) + k.arg(ctx, k) if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
# auto-differentiate the function # auto-differentiate the function
fxn, args = k.src[0], k.src[1:] fxn, args = k.src[0], k.src[1:]
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)

View file

@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC} Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it # if it's a kernel, we don't realize it
if a.src[1].op is not Ops.KERNEL: ctx[a] = None if a.src[1].op is not Ops.CALL: ctx[a] = None
pm_generate_realize_map = PatternMatcher([ pm_generate_realize_map = PatternMatcher([
# always realize SINK src # always realize SINK src
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0] if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp): def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
if assign.src[1].op is Ops.KERNEL: return None if assign.src[1].op is Ops.CALL: return None
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))])) to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
ret = assign.replace(src=assign.src+(to_mop,)) ret = assign.replace(src=assign.src+(to_mop,))
ctx.range_map[ret] = ctx.range_map[assign] ctx.range_map[ret] = ctx.range_map[assign]
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
# no ranges on kernels, they are internal # no ranges on kernels, they are internal
if x.op is Ops.KERNEL: continue if x.op is Ops.CALL: continue
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this? if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], []) ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

View file

@ -2,10 +2,10 @@ from dataclasses import dataclass, field, replace
import itertools import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
def resolve_custom_kernel(ck:UOp) -> UOp: def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)] placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders))) return ck.arg.fxn(*placeholders).call(*ck.src)
def resolve_call(c:UOp) -> UOp: def resolve_call(c:UOp) -> UOp|None:
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
if c.src[0].op is Ops.SINK: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:] args = c.src[1:]
# TODO: this check belongs in spec, not here # TODO: this check belongs in spec, not here
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works... # just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls # resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
(UPat(Ops.CALL, name="c"), resolve_call), (UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels # resolve custom kernels
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"), (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)), lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
# remove any RESHAPEs on KERNEL # remove any RESHAPEs on CALL
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))), (UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
]) ])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([ pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@ -515,16 +517,17 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
else: raise RuntimeError(f"unknown kernel type {ret.op}") else: raise RuntimeError(f"unknown kernel type {ret.op}")
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else: else:
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else KernelInfo())
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
return kernel return kernel
split_kernels = PatternMatcher([ split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store), (UPat((Ops.STORE, Ops.END), name="x"), split_store),
# if it's a Kernel, stop
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
]) ])
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp): def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
@ -537,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
return x.replace(tag=(len(ctx[0])-1,)) return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([ add_tags = PatternMatcher([
# don't tag BUFFERs, they are global # don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END, (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop), Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)), (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
]) ])
@ -581,7 +584,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# bufferize -> store # bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1 lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store") tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels") tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels", bottom_up=True)
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {} kernel_assign: dict[UOp, UOp] = {}

View file

@ -240,7 +240,7 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.shape, self.device) param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device) return Tensor(param, device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device) return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
""" """

View file

@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property @recursive_property
def trace_num(self): def trace_num(self):
num = next(ucount) num = next(ucount)
# KERNEL also has a UOp in the arg # KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ()) uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num return num
@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),)) src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
return UOp(Ops.PARAM, dtype, src, arg=slot) return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg) def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn)) kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
@ -843,6 +844,14 @@ class CustomKernel:
def __reduce__(self): return (CustomKernel, (panic,)) def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return f"CustomKernel({id(self.fxn)})" def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class CallInfo:
grad_fxn: Callable|None = None
metadata: tuple[Metadata, ...] = ()
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
@dataclass(frozen=True) @dataclass(frozen=True)
class Kernel: class Kernel:
ast: UOp ast: UOp
@ -1405,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
# NOTE: you can remove pm_pyrender_extra and it'll still be correct # NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([ pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"), (UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")), (UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
]) ])
@ -1415,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst) cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE} not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE, always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN} Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast} to_render: set[UOp] = {ast}
for u in lst: for u in lst:

View file

@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"), (UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))), lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER # CALL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), # src[0] is the function (SINK), src[1:] are buffers/bindings
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns # ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
# vectorized index # vectorized index
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True), (UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs # linearizer: outputs + intermediate CALLs
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True), (UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
# Invalid must have type Index # Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),

View file

@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects # only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))), tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0] # after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# VECTORIZE/CONST # VECTORIZE/CONST

View file

@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
}); });
const createToggle = (id, text) => { const createToggle = (id, text) => {
const label = d3.create("label").text(text).node(); const label = d3.create("label").style("display", "block").text(text).node();
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node(); const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
label.prepend(toggle); label.prepend(toggle);
return { toggle, label }; return { toggle, label };
} }
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)"); const showIndexing = createToggle("show-indexing", "Show indexing (r)");
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
const showGraph = createToggle("show-graph", "Show graph (g)"); const showGraph = createToggle("show-graph", "Show graph (g)");
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph"); showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
@ -893,11 +894,13 @@ async function main() {
// ** center graph // ** center graph
const data = ret[currentRewrite]; const data = ret[currentRewrite];
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 }); const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked }); const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked }); render(getOpts());
showIndexing.toggle.onchange = () => render(getOpts());
showCallSrc.toggle.onchange = () => render(getOpts());
// ** right sidebar metadata // ** right sidebar metadata
metadata.innerHTML = ""; metadata.innerHTML = "";
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel); if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true })); if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
if (step.trace) { if (step.trace) {
const trace = d3.create("pre").append("code").classed("hljs", true); const trace = d3.create("pre").append("code").classed("hljs", true);
@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
document.getElementById("zoom-to-fit-btn").click(); document.getElementById("zoom-to-fit-btn").click();
} }
// r key toggles indexing // r key toggles indexing
if (event.key === "r") toggle.click(); if (event.key === "r") showIndexing.toggle.click();
// c key toggles CALL src
if (event.key === "c") showCallSrc.toggle.click();
// g key toggles graph // g key toggles graph
if (event.key === "g") showGraph.toggle.click(); if (event.key === "g") showGraph.toggle.click();
}); });

View file

@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}}); for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (change?.includes(parseInt(k))) g.setParent(k, "overlay"); if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
} }
// optionally hide nodes from the layuot // optionally hide nodes from the layout
if (!opts.showIndexing) { if (!opts.showIndexing) {
for (const n of g.nodes()) { for (const n of g.nodes()) {
const node = g.node(n); const node = g.node(n);
if (node.label.includes("dtypes.index")) g.removeNode(n); if (node.label.includes("dtypes.index")) g.removeNode(n);
} }
} }
if (!opts.showCallSrc) {
// remove edges from src[0] to CALL nodes, track affected nodes
const disconnected = new Set();
for (const n of g.nodes()) {
const node = g.node(n);
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
for (const pred of (g.predecessors(n) || [])) {
const edge = g.edge(pred, n);
if (edge?.label?.text === 0) {
g.removeEdge(pred, n);
disconnected.add(pred);
}
}
}
}
// remove nodes that are now disconnected (no successors), only from affected subtree
let changed = true;
while (changed) {
changed = false;
for (const n of disconnected) {
if (!g.hasNode(n)) continue;
if ((g.successors(n) || []).length === 0) {
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
g.removeNode(n);
changed = true;
}
}
}
}
dagre.layout(g); dagre.layout(g);
// remove overlay node if it's empty // remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay"); if (!g.node("overlay")?.width) g.removeNode("overlay");