mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
resolve_call
This commit is contained in:
parent
8e8cac4b0f
commit
d83ddc05c8
6 changed files with 32 additions and 36 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
|
|
@ -29,7 +29,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
|
||||
k = u.src[1]
|
||||
in_degree.setdefault(k, 0)
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
|
|
@ -55,13 +55,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
||||
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges))
|
||||
elif k.op is Ops.CALL:
|
||||
ast = k.src[0]
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
|||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
# if it's a kernel, we don't realize it
|
||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||
if a.src[1].op is not Ops.CALL: ctx[a] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
|
|
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
|||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||
if assign.src[1].op is Ops.KERNEL: return None
|
||||
if assign.src[1].op is Ops.CALL: return None
|
||||
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
|
||||
ret = assign.replace(src=assign.src+(to_mop,))
|
||||
ctx.range_map[ret] = ctx.range_map[assign]
|
||||
|
|
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op is Ops.KERNEL: continue
|
||||
if x.op is Ops.CALL: continue
|
||||
|
||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
|||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
|
||||
|
|
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
|
|||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
return ck.arg.fxn(*placeholders).call(*ck.src)
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
|
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve calls
|
||||
# resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
|
|
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
||||
|
||||
# remove any RESHAPEs on KERNEL
|
||||
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
# remove any RESHAPEs on CALL
|
||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
|
||||
])
|
||||
|
||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
|
|
@ -519,20 +521,13 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel[1:].src)}")
|
||||
return kernel
|
||||
|
||||
# old
|
||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
|
||||
return kernel
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
# if it's a Kernel, stop
|
||||
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate) if isinstance(sink.arg, KernelInfo) else None),
|
||||
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
|
||||
])
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||
|
|
@ -545,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
|||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@recursive_property
|
||||
def trace_num(self):
|
||||
num = next(ucount)
|
||||
# KERNEL also has a UOp in the arg
|
||||
# KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
|
||||
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||
return num
|
||||
|
|
@ -1414,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
|
|||
|
||||
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
||||
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
|
||||
(UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
|
||||
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
|
||||
])
|
||||
|
||||
|
|
@ -1424,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
|
|||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
|
||||
to_render: set[UOp] = {ast}
|
||||
for u in lst:
|
||||
|
|
|
|||
|
|
@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
|
|||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
# CALL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
# src[0] is the function (SINK), src[1:] are buffers/bindings
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
|
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
|
|||
# vectorized index
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
|
||||
# linearizer: outputs + intermediate CALLs
|
||||
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# VECTORIZE/CONST
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue