mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
kernel_is_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771a395240 |
||
|
|
d83ddc05c8 | ||
|
|
8e8cac4b0f | ||
|
|
57199fd9de | ||
|
|
2193d0edfa | ||
|
|
77adccb925 |
10 changed files with 92 additions and 47 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import time
|
import time
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
|
||||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||||
from tinygrad.device import Buffer, MultiBuffer
|
from tinygrad.device import Buffer, MultiBuffer
|
||||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||||
|
|
@ -28,8 +28,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||||
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
||||||
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||||
in_degree.setdefault(k, 0)
|
in_degree.setdefault(k, 0)
|
||||||
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
|
||||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
|
||||||
match (s := _unwrap_src(s)).op:
|
match (s := _unwrap_src(s)).op:
|
||||||
case Ops.AFTER:
|
case Ops.AFTER:
|
||||||
children.setdefault(s.src[1], []).append(k)
|
children.setdefault(s.src[1], []).append(k)
|
||||||
|
|
@ -54,14 +53,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||||
while len(queue):
|
while len(queue):
|
||||||
k = rk = queue.popleft()
|
k = rk = queue.popleft()
|
||||||
if k.op is Ops.END: k = k.src[0]
|
if k.op is Ops.END: k = k.src[0]
|
||||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
|
||||||
if k.op is Ops.RANGE: schedule.append(k)
|
if k.op is Ops.RANGE: schedule.append(k)
|
||||||
elif k.op is Ops.KERNEL:
|
elif k.op is Ops.CALL:
|
||||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
ast = k.src[0]
|
||||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||||
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges)
|
schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
|
||||||
schedule.append(k)
|
|
||||||
if rk.op is Ops.END: schedule.append(rk)
|
if rk.op is Ops.END: schedule.append(rk)
|
||||||
for x in children.get(rk, []):
|
for x in children.get(rk, []):
|
||||||
in_degree[x] -= 1
|
in_degree[x] -= 1
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||||
|
|
||||||
def call_gradient(ctx:UOp, k:UOp):
|
def call_gradient(ctx:UOp, k:UOp):
|
||||||
if k.arg is not None: return (None,) + k.arg(ctx, k)
|
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
|
||||||
# auto-differentiate the function
|
# auto-differentiate the function
|
||||||
fxn, args = k.src[0], k.src[1:]
|
fxn, args = k.src[0], k.src[1:]
|
||||||
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||||
|
|
||||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
|
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||||
|
|
||||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||||
|
|
||||||
|
|
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||||
# if it's a kernel, we don't realize it
|
# if it's a kernel, we don't realize it
|
||||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
if a.src[1].op is not Ops.CALL: ctx[a] = None
|
||||||
|
|
||||||
pm_generate_realize_map = PatternMatcher([
|
pm_generate_realize_map = PatternMatcher([
|
||||||
# always realize SINK src
|
# always realize SINK src
|
||||||
|
|
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||||
|
|
||||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||||
if assign.src[1].op is Ops.KERNEL: return None
|
if assign.src[1].op is Ops.CALL: return None
|
||||||
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
|
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
|
||||||
ret = assign.replace(src=assign.src+(to_mop,))
|
ret = assign.replace(src=assign.src+(to_mop,))
|
||||||
ctx.range_map[ret] = ctx.range_map[assign]
|
ctx.range_map[ret] = ctx.range_map[assign]
|
||||||
|
|
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||||
|
|
||||||
# no ranges on kernels, they are internal
|
# no ranges on kernels, they are internal
|
||||||
if x.op is Ops.KERNEL: continue
|
if x.op is Ops.CALL: continue
|
||||||
|
|
||||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||||
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,10 @@ from dataclasses import dataclass, field, replace
|
||||||
import itertools
|
import itertools
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||||
from tinygrad.uop.symbolic import symbolic
|
from tinygrad.uop.symbolic import symbolic
|
||||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
|
||||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||||
from tinygrad.codegen.opt import Opt
|
from tinygrad.codegen.opt import Opt
|
||||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||||
|
|
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
|
||||||
|
|
||||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
return ck.arg.fxn(*placeholders).call(*ck.src)
|
||||||
|
|
||||||
def resolve_call(c:UOp) -> UOp:
|
def resolve_call(c:UOp) -> UOp|None:
|
||||||
|
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
|
||||||
|
if c.src[0].op is Ops.SINK: return None
|
||||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||||
args = c.src[1:]
|
args = c.src[1:]
|
||||||
# TODO: this check belongs in spec, not here
|
# TODO: this check belongs in spec, not here
|
||||||
|
|
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||||
# just removing it works...
|
# just removing it works...
|
||||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||||
|
|
||||||
# resolve calls
|
# resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
|
||||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||||
|
|
||||||
# resolve custom kernels
|
# resolve custom kernels
|
||||||
|
|
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
||||||
|
|
||||||
# remove any RESHAPEs on KERNEL
|
# remove any RESHAPEs on CALL
|
||||||
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
|
||||||
])
|
])
|
||||||
|
|
||||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||||
|
|
@ -515,16 +517,17 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||||
else:
|
else:
|
||||||
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
|
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else KernelInfo())
|
||||||
|
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
|
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
|
||||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
|
||||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
|
||||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
split_kernels = PatternMatcher([
|
split_kernels = PatternMatcher([
|
||||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||||
|
# if it's a Kernel, stop
|
||||||
|
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
|
||||||
])
|
])
|
||||||
|
|
||||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||||
|
|
@ -537,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||||
return x.replace(tag=(len(ctx[0])-1,))
|
return x.replace(tag=(len(ctx[0])-1,))
|
||||||
add_tags = PatternMatcher([
|
add_tags = PatternMatcher([
|
||||||
# don't tag BUFFERs, they are global
|
# don't tag BUFFERs, they are global
|
||||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
|
||||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||||
])
|
])
|
||||||
|
|
@ -581,7 +584,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
# bufferize -> store
|
# bufferize -> store
|
||||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
|
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels", bottom_up=True)
|
||||||
|
|
||||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||||
kernel_assign: dict[UOp, UOp] = {}
|
kernel_assign: dict[UOp, UOp] = {}
|
||||||
|
|
|
||||||
|
|
@ -240,7 +240,7 @@ class Tensor(OpMixin):
|
||||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||||
return Tensor(param, device=self.device)
|
return Tensor(param, device=self.device)
|
||||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||||
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
|
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
||||||
|
|
||||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
@recursive_property
|
@recursive_property
|
||||||
def trace_num(self):
|
def trace_num(self):
|
||||||
num = next(ucount)
|
num = next(ucount)
|
||||||
# KERNEL also has a UOp in the arg
|
# KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
|
||||||
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
||||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||||
return num
|
return num
|
||||||
|
|
@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||||
|
|
||||||
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
|
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||||
|
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||||
|
|
@ -843,6 +844,14 @@ class CustomKernel:
|
||||||
def __reduce__(self): return (CustomKernel, (panic,))
|
def __reduce__(self): return (CustomKernel, (panic,))
|
||||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CallInfo:
|
||||||
|
grad_fxn: Callable|None = None
|
||||||
|
metadata: tuple[Metadata, ...] = ()
|
||||||
|
# grad_fxn can't be pickled, but metadata can
|
||||||
|
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||||
|
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Kernel:
|
class Kernel:
|
||||||
ast: UOp
|
ast: UOp
|
||||||
|
|
@ -1405,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||||
|
|
||||||
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
||||||
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
||||||
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
|
(UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
|
||||||
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
|
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
@ -1415,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
|
||||||
cmap = consumer_map_from_toposort(lst)
|
cmap = consumer_map_from_toposort(lst)
|
||||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||||
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||||
|
|
||||||
to_render: set[UOp] = {ast}
|
to_render: set[UOp] = {ast}
|
||||||
for u in lst:
|
for u in lst:
|
||||||
|
|
|
||||||
|
|
@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
|
||||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||||
|
|
||||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
# CALL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
# src[0] is the function (SINK), src[1:] are buffers/bindings
|
||||||
|
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
|
||||||
|
|
||||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||||
|
|
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
|
||||||
# vectorized index
|
# vectorized index
|
||||||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||||
|
|
||||||
# linearizer: outputs + intermediate KERNELs
|
# linearizer: outputs + intermediate CALLs
|
||||||
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
|
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||||
|
|
||||||
# Invalid must have type Index
|
# Invalid must have type Index
|
||||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||||
# only RANGE/IF/STORE/KERNEL have side effects
|
# only RANGE/IF/STORE/KERNEL have side effects
|
||||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||||
# after with 1 src is just src[0]
|
# after with 1 src is just src[0]
|
||||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||||
# VECTORIZE/CONST
|
# VECTORIZE/CONST
|
||||||
|
|
|
||||||
|
|
@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
|
||||||
});
|
});
|
||||||
|
|
||||||
const createToggle = (id, text) => {
|
const createToggle = (id, text) => {
|
||||||
const label = d3.create("label").text(text).node();
|
const label = d3.create("label").style("display", "block").text(text).node();
|
||||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
||||||
label.prepend(toggle);
|
label.prepend(toggle);
|
||||||
return { toggle, label };
|
return { toggle, label };
|
||||||
}
|
}
|
||||||
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
|
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
|
||||||
|
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
|
||||||
const showGraph = createToggle("show-graph", "Show graph (g)");
|
const showGraph = createToggle("show-graph", "Show graph (g)");
|
||||||
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
||||||
|
|
||||||
|
|
@ -893,11 +894,13 @@ async function main() {
|
||||||
// ** center graph
|
// ** center graph
|
||||||
const data = ret[currentRewrite];
|
const data = ret[currentRewrite];
|
||||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||||
render({ showIndexing:toggle.checked });
|
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
|
||||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
render(getOpts());
|
||||||
|
showIndexing.toggle.onchange = () => render(getOpts());
|
||||||
|
showCallSrc.toggle.onchange = () => render(getOpts());
|
||||||
// ** right sidebar metadata
|
// ** right sidebar metadata
|
||||||
metadata.innerHTML = "";
|
metadata.innerHTML = "";
|
||||||
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel);
|
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
|
||||||
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
|
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
|
||||||
if (step.trace) {
|
if (step.trace) {
|
||||||
const trace = d3.create("pre").append("code").classed("hljs", true);
|
const trace = d3.create("pre").append("code").classed("hljs", true);
|
||||||
|
|
@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
|
||||||
document.getElementById("zoom-to-fit-btn").click();
|
document.getElementById("zoom-to-fit-btn").click();
|
||||||
}
|
}
|
||||||
// r key toggles indexing
|
// r key toggles indexing
|
||||||
if (event.key === "r") toggle.click();
|
if (event.key === "r") showIndexing.toggle.click();
|
||||||
|
// c key toggles CALL src
|
||||||
|
if (event.key === "c") showCallSrc.toggle.click();
|
||||||
// g key toggles graph
|
// g key toggles graph
|
||||||
if (event.key === "g") showGraph.toggle.click();
|
if (event.key === "g") showGraph.toggle.click();
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
||||||
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
||||||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||||
}
|
}
|
||||||
// optionally hide nodes from the layuot
|
// optionally hide nodes from the layout
|
||||||
if (!opts.showIndexing) {
|
if (!opts.showIndexing) {
|
||||||
for (const n of g.nodes()) {
|
for (const n of g.nodes()) {
|
||||||
const node = g.node(n);
|
const node = g.node(n);
|
||||||
if (node.label.includes("dtypes.index")) g.removeNode(n);
|
if (node.label.includes("dtypes.index")) g.removeNode(n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (!opts.showCallSrc) {
|
||||||
|
// remove edges from src[0] to CALL nodes, track affected nodes
|
||||||
|
const disconnected = new Set();
|
||||||
|
for (const n of g.nodes()) {
|
||||||
|
const node = g.node(n);
|
||||||
|
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
|
||||||
|
for (const pred of (g.predecessors(n) || [])) {
|
||||||
|
const edge = g.edge(pred, n);
|
||||||
|
if (edge?.label?.text === 0) {
|
||||||
|
g.removeEdge(pred, n);
|
||||||
|
disconnected.add(pred);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// remove nodes that are now disconnected (no successors), only from affected subtree
|
||||||
|
let changed = true;
|
||||||
|
while (changed) {
|
||||||
|
changed = false;
|
||||||
|
for (const n of disconnected) {
|
||||||
|
if (!g.hasNode(n)) continue;
|
||||||
|
if ((g.successors(n) || []).length === 0) {
|
||||||
|
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
|
||||||
|
g.removeNode(n);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
dagre.layout(g);
|
dagre.layout(g);
|
||||||
// remove overlay node if it's empty
|
// remove overlay node if it's empty
|
||||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue