tinygrad/tinygrad/engine/schedule.py
Louis Novy 2ac5aec66b
Fix exponential complexity in _is_padding_okay [pr] (#7008)
* preliminary test

* missed Optional

* don't check for cache during recursion

* match style from st_fixup... may be marginally faster?

* pathological test case: strongly connected DAG

* move to test_schedule as this isn't really a fusion

* oops this shouldn't be edited

* Revert "oops this shouldn't be edited"

This reverts commit 487cb027dc.

* Revert "move to test_schedule as this isn't really a fusion"

This reverts commit 48d8c550ce.

* move to test_schedule as this isn't really a fusion

* ok no more merge error funny business
2024-10-14 02:34:47 +03:00

450 lines
25 KiB
Python

import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
graph_rewrite, track_rewrites, Variable, sint
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from tinygrad.shape.view import View, strides_for_shape
# creation can recurse a lot
# (several helpers in this file are recursive, so the interpreter limit is raised up front)
sys.setrecursionlimit(10000)
# per-device cap on the number of buffers in a single ScheduleItem, keyed by device name.
# create_schedule_with_vars raises a RuntimeError once a kernel's buffer count reaches the cap.
BUF_LIMIT = {"METAL":32}
# MetaOps that _lower_lazybuffer turns directly into a single UOp of the mapped UOps kind
METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
# *** ScheduleItem return type ***
@dataclass(frozen=True)
class ScheduleItem:
  """An AST together with the Buffers it touches: output buffers first in `bufs`, then the inputs."""
  ast: UOp
  bufs: Tuple[Buffer, ...]
  metadata: Optional[Tuple[Metadata, ...]]
  @property
  def outputs(self) -> Tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule."""
    # a SINK ast has one src per output; every other ast has a single output in slot 0
    if self.ast.op is UOps.SINK: return self.bufs[:len(self.ast.src)]
    return self.bufs[:1]
  @property
  def inputs(self) -> Tuple[Buffer, ...]:
    """Read only buffers in the schedule."""
    # everything after the output slots
    if self.ast.op is UOps.SINK: return self.bufs[len(self.ast.src):]
    return self.bufs[1:]
@dataclass(frozen=True)
class LBScheduleItem:
  """Same layout as ScheduleItem, but `bufs` holds LazyBuffers instead of Buffers."""
  ast: UOp
  bufs: Tuple[LazyBuffer, ...]
  metadata: Optional[Tuple[Metadata, ...]]
  @property
  def outputs(self) -> Tuple[LazyBuffer, ...]:
    """LazyBuffers written by this item: one per SINK src, otherwise only the first buf."""
    if self.ast.op is UOps.SINK: return self.bufs[:len(self.ast.src)]
    return self.bufs[:1]
  @property
  def inputs(self) -> Tuple[LazyBuffer, ...]:
    """LazyBuffers that follow the output slots in `bufs`."""
    if self.ast.op is UOps.SINK: return self.bufs[len(self.ast.src):]
    return self.bufs[1:]
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
# ** helpers for doing movementops on uops
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
  """
  Rebuild the graph under `u` with `apply_to_st` applied to the arg of every VIEW that is reached.

  `cache` memoizes the result for non-VIEW nodes whose sources were visited; the same UOp object is
  returned wherever nothing changed.
  """
  hit = cache.get(u)
  if hit is not None: return hit
  if u.op is UOps.VIEW:
    moved = apply_to_st(u.arg)
    if u.arg == moved: return u
    return UOp(UOps.VIEW, dtypes.void, (), moved)
  # leaves have nothing to fix
  if len(u.src) == 0: return u
  # stop early when the transform leaves this node's own ShapeTracker unchanged
  if u.st is not None and u.st == apply_to_st(u.st): return u
  fixed_srcs = tuple(st_fixup(s, apply_to_st, cache) for s in u.src)
  if fixed_srcs == u.src: ret = u
  else: ret = UOp(u.op, u.dtype, fixed_srcs, u.arg)
  cache[u] = ret
  return ret
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
  """Permute `input_st` so the axes in `axis` come last; return the permuted tracker and the shape of those trailing axes."""
  kept_axes = tuple(i for i in range(len(input_st.shape)) if i not in axis)
  permuted = input_st.permute(kept_axes+axis)
  return permuted, permuted.shape[-len(axis):]
# ** reduceop fusor
def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
  """
  Rewrite VIEW(REDUCE_AXIS(src)) so the swizzle is applied to the reduce's source instead.

  Returns None (no rewrite) when the swizzle's ShapeTracker is contiguous. Otherwise returns a new
  REDUCE_AXIS over the fixed-up source, viewed with a fresh ShapeTracker of the swizzle's shape.
  """
  swizzle_st = unwrap(swizzle.st)
  if swizzle_st.contiguous: return None
  rsrc = reduceop.src[0]
  tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.axis_arg)
  prshape, rstrides = prod(rshape), strides_for_shape(rshape)
  def _with_reduce_dims(v:View) -> View:
    # append the reduced dims to the view: scale the existing strides/offset and leave the new dims unmasked
    mask = None if v.mask is None else v.mask+tuple((0,s) for s in rshape)
    return View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+rstrides, v.offset*prshape, mask)
  # update input_st and axis
  new_input_st = tmp + ShapeTracker(tuple(_with_reduce_dims(v) for v in swizzle_st.views))
  _, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
  ndim = len(new_input_st.shape)
  new_axis = tuple(range(ndim-len(new_rshape), ndim))
  new_src = st_fixup(rsrc, lambda st:st+new_input_st, {})
  new_reduce = UOp(UOps.REDUCE_AXIS, reduceop.dtype, (new_src,), (reduceop.arg[0], new_axis))
  return new_reduce.view(ShapeTracker.from_shape(swizzle_st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
  """
  Rewrite REDUCE_AXIS(VIEW(x)) as VIEW(REDUCE_AXIS(x)).

  Asserts that the swizzle is contiguous and does not change the element count. The new reduce axes
  are the positions where the source shape differs from the reduced output shape.
  """
  swizzle_st = unwrap(swizzle.st)
  src_st = unwrap(swizzle.src[0].st)
  assert swizzle_st.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
  assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
  output_shape = swizzle_st.reduce(root.axis_arg)
  new_axis = tuple(i for i,dims in enumerate(zip(src_st.shape, output_shape)) if dims[0] != dims[1])
  new_reduce = UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (root.arg[0], new_axis))
  return new_reduce.view(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
  """
  Move VIEW sources (those that themselves have a source) from above `root` to below it.

  Returns None when `root` has no such source. Swizzled sources are replaced by their inner UOp; the
  remaining sources are reshaped to the swizzles' input shape. The result is wrapped in a VIEW of
  the swizzles' output shape unless it is a STORE.
  """
  swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0]
  if not swizzles: return None
  swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
  assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
  new_shape, new_input_shape = swizzle_shapes[0]
  # one cache shared by all non-swizzled sources so common subgraphs are fixed up once
  fixup_cache: Dict[UOp, UOp] = {}
  new_srcs: List[UOp] = []
  for x in root.src:
    if x in swizzles: new_srcs.append(x.src[0])
    else: new_srcs.append(st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache))
  ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
  if ret.op is UOps.STORE: return ret
  return ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  """Collapse REDUCE_AXIS(REDUCE_AXIS(x)) into one REDUCE_AXIS over x whose axes are root's followed by the inner reduce's."""
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
  merged_axis = root.axis_arg+first_reduce.axis_arg
  return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], merged_axis))
# rewrite rules applied to the whole ast by full_ast_rewrite (via graph_rewrite).
# each entry pairs a UPat with the function that builds the replacement; the named sub-patterns
# ("root", "swizzle", "reduceop", ...) are the keyword arguments those functions receive.
reduceop_fusor = PatternMatcher([
# SWIZZLE on VALID merges the views
(UPat(UOps.VIEW, src=(UPat(UOps.ALU, src=(UPat(UOps.VALID), UPat.var(), UPat.var()), name="alu", arg=TernaryOps.WHERE),), name="root"),
lambda root,alu: UOp(UOps.VALID, dtypes.bool, (root.st.to_uop(),)).where(*alu.src[1:]) if root.st != alu.st else alu),
# push a SWIZZLE up to LOAD, through a reduce (eg. expands)
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
# a REDUCE_AXIS directly on top of another REDUCE_AXIS is merged into a single reduce
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# replaces every UOps.BUFFER with a DEFINE_GLOBAL whose arg is the position of the buffer's arg[0] in ctx
# (ctx is the tuple of buffer numbers that full_ast_rewrite passes to graph_rewrite)
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.index(x.arg[0])))])
# (input sink, buffer numbers, rewritten sink) triples recorded by full_ast_rewrite
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
if getenv("RUN_PROCESS_REPLAY"):
  @atexit.register
  def save_process_replay():
    """At interpreter exit, write every captured rewrite to the disk cache, keyed by the input sink's key."""
    for captured in PROCESS_REPLAY_CAPTURE:
      base_sink, ctx, ret = captured
      diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))
@track_rewrites
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
  """
  Run the scheduler's two rewrite passes over a kernel's sink and return the rewritten sink.

  First `reduceop_fusor` is applied to `base_sink`, then `enumerate_bufs` is applied with `bufs` as
  its context, so each BUFFER becomes a DEFINE_GLOBAL numbered by its position in `bufs`.

  The (input, bufs, output) triple is recorded in PROCESS_REPLAY_CAPTURE only when RUN_PROCESS_REPLAY
  is set. That is the only case in which the atexit hook that consumes the list is registered, so
  recording unconditionally would just keep every AST alive for the lifetime of the process.
  """
  sink = graph_rewrite(base_sink, reduceop_fusor)
  ret = graph_rewrite(sink, enumerate_bufs, bufs)
  if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
  return ret
# *** List[LazyBuffer] lowering to ScheduleItem ***
def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
                   buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer],
                   cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
  """
  recursively create the UOp for `buf` as seen through `st`

  Buffers that have an entry in `buf_uops` and are not kernel outputs become loads (or a view of
  their const UOp), and are appended to `inputs` unless they are assign targets. Everything else is
  lowered op by op. `var_vals` is updated in place with any unbound variables; results for computed
  ops are memoized in `cache` under (base, st).
  """
  # views fold their ShapeTracker into st and continue on the base
  if buf is not buf.base:
    st = buf.st+st
    buf = buf.base
  key = (buf, st)
  if key in cache: return cache[key]
  assert buf.op is not None, "base must be a base itself"
  dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype

  # buffer ops define ShapeTracker
  # if it's realized, it's a load and we add it to the inputs
  ubuf = buf_uops.get(buf.buffer)
  if ubuf is not None and buf not in outputs:
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    if buf.op is MetaOps.CONST:
      val = buf.arg
      if isinstance(val, UOp): var_vals.update([val.unbind()])
      return ubuf.view(unbound_st)
    is_assign_target = buf in assign_targets
    if is_assign_target and not unbound_st.contiguous:
      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
      views = unbound_st.views
      masked_contig = len(views) == 1 and views[0].mask is not None and \
        ShapeTracker.from_shape(unbound_st.shape).shrink(views[0].mask) == unbound_st.shrink(views[0].mask)
      if not masked_contig:
        raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                           +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
    if not is_assign_target and buf not in inputs: inputs.append(buf)
    return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))

  # only reduceop changes shape
  is_reduce = buf.op in ReduceOps
  src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if is_reduce else st
  src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs]
  if is_reduce:
    ret = UOp(UOps.REDUCE_AXIS, dtype, tuple(src), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st)
  elif buf.op is MetaOps.CONTIGUOUS:
    assert buf in outputs, f"{buf.op} must be writable"
    ret = src[0]
  elif buf.op is MetaOps.ASSIGN:
    ret = UOp(UOps.ASSIGN, dtype, (src[1].src[0], src[0]))
  elif buf.op is UnaryOps.CAST:
    ret = src[0].cast(dtype)
  elif buf.op is UnaryOps.BITCAST:
    ret = src[0].bitcast(dtype)
  else:
    ret = UOp(UOps.ALU, dtype, tuple(src), buf.op)
  cache[key] = ret
  return ret
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
  """
  describe the computation for a group of output LazyBuffers as an LBScheduleItem plus its var_vals

  When the first output's op is in METAOPS, the item is a single meta UOp and var_vals is empty.
  Otherwise one STORE is built per output, the stores are joined under a sink, and the sink goes
  through full_ast_rewrite.
  """
  first = outs[0]
  if first.op in METAOPS:
    meta_ast = UOp(METAOPS[cast(MetaOps, first.op)], first.dtype, (), first.arg)
    meta_bufs = (first,)+tuple(x.base for x in first.srcs)
    meta_md = None if first.metadata is None else (first.metadata,)
    return LBScheduleItem(meta_ast, meta_bufs, meta_md), {}
  # create the stores
  var_vals = merge_dicts([o.st.var_vals.copy() for o in outs])
  assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
  cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
  stores: List[UOp] = []
  inputs: List[LazyBuffer] = []
  outputs = tuple(outs)
  for out in outs:
    output_st = ShapeTracker.from_shape(out.shape)
    src = _recursive_uop(out, output_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache=cache)
    # an ASSIGN with a non-empty arg stores through the ShapeTracker in arg[0] instead
    if out.op is MetaOps.ASSIGN and out.arg:
      assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
      output_st = out.arg[0]
    output_st, vv = output_st.simplify().unbind()
    var_vals.update(vv)
    stores.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
  all_bufs = outs+inputs
  sink = full_ast_rewrite(stores[0].sink(*stores[1:]), tuple(buf_uops[x.buffer].arg[0] for x in all_bufs))
  # metadata of every lowered buffer that is not a kernel input
  metadata = tuple(dedup([lb.metadata for lb,_ in cache if lb.metadata and lb not in inputs]))
  return LBScheduleItem(sink, tuple(all_bufs), metadata), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
                double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
  """
  recursively search the entire graph for all LazyBuffers, insert realizes after expands

  All the dict arguments are filled in place: `allbufs` collects every unrealized base that is
  visited, `children` maps an unrealized base to the buffers that consume it, and `realizes`,
  `simple_pads`, `assign_targets` and `double_reduces` record the decisions made along the way.
  """
  if buf in allbufs: return None
  if buf.base.realized is not None: return realizes.setdefault(buf.base)
  if GRAPH:
    from tinygrad.engine.graph import log_lazybuffer
    log_lazybuffer(buf, scheduled)
  # check if we need to realize views
  if buf is not buf.base:
    base, views = buf.base, buf.st.views
    base_shape = base.st.shape
    single_masked = len(views) == 1 and views[-1].mask is not None
    # fuse some pads
    if single_masked and all_int(base_shape) and resolve(prod(base_shape) >= prod([hi-lo for lo,hi in views[-1].mask])):
      simple_pads[base] = None
    # realize all expands
    elif resolve(prod(base_shape) < prod(buf.st.shape)):
      # this was causing "test_lil_model" to fail
      image_to_image_cast = base.op is UnaryOps.CAST and isinstance(base.srcs[0].dtype, ImageDType) and isinstance(base.arg, ImageDType)
      if image_to_image_cast: simple_pads[base] = None # don't realize image to image casts. this is part of a larger problem
      else: realizes[base] = None
    # check all other pads for safe fusion
    elif any(v.mask is not None for v in views):
      simple_pads[base] = None
    return _recurse_lb(base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
  # a reduce whose source is a view of a base with the same reduce op
  if buf.op in ReduceOps:
    rsrc = buf.srcs[0]
    if rsrc.base.op is buf.op and rsrc is not rsrc.base: double_reduces[buf] = None
  allbufs[buf] = None
  if buf.forced_realize or buf.op in MetaOps: realizes[buf] = None
  if buf.op is MetaOps.ASSIGN:
    assert buf.srcs[1].base is buf.srcs[1], f"assign must be to base {buf.srcs[1]}"
    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
    assign_targets[buf.srcs[1]] = buf
  if buf.op is MetaOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes[buf.srcs[0].base] = None
  if buf.op is MetaOps.VIEW:
    realizes[buf.srcs[0].base] = None
  for parent in buf.srcs:
    if parent.base.realized is None: children[parent.base][buf] = None
    _recurse_lb(parent, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, bool]) -> bool:
  """
  Return whether padding `buf` is safe: True at realized buffers, False as soon as an op in
  UNSAFE_PAD_OPS is reached, otherwise the conjunction over the bases of all sources.

  `cache` memoizes the verdict for buffers whose sources had to be visited, so a buffer shared by
  many paths is only expanded once.
  """
  known = cache.get(buf)
  if known is not None: return known
  if buf in realizes: return True
  # NOTE: this broke to_image_idx and coder with JIT
  if buf.op in UNSAFE_PAD_OPS: return False
  verdict = True
  for src in buf.srcs:
    # stop at the first unsafe source
    if not _is_padding_okay(src.base, realizes, cache):
      verdict = False
      break
  cache[buf] = verdict
  return verdict
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
                     cache:Dict[Tuple[LazyBuffer, ShapeTracker], None]) -> None:
  """
  recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group

  Results are accumulated in `group`; `r` itself is added to `group` when some path cannot be fused.
  `cache` holds the (buffer, ShapeTracker) pairs already visited.
  """
  visit = (tr, st)
  if visit in cache: return
  cache.setdefault(visit)
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    unfusable = not st.contiguous or st.size != r.st.size or tr in reduce_for_op
    if unfusable: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if tr_next.op in ReduceOps: return group.setdefault(r)
    # can only fuse contiguous
    st_childs = dedup(s for s in tr_next.srcs if s.base == tr)
    if len(st_childs) > 1: return group.setdefault(r)
    _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],\
                           realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]:
  """
  Return `group`, extended with its groupable descendants, or {} if another reduce feeds the group.

  First the unrealized parents of `group` (never walking through `r`) are searched for a ReduceOp.
  Then descendants are collected with _recursive_group; they are only merged in when none of them
  is already a member of `group`.
  """
  pending, seen = deque(group), set()
  while pending:
    p = pending.pop()
    if p in seen: continue
    seen.add(p)
    # max one reduceop per kernel
    if p.op in ReduceOps: return {}
    pending.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
  # search descendants of the reduceop that can cleanly group
  descendants: Dict[LazyBuffer, None] = {}
  for tr in group:
    _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
  overlaps = any(tr in group for tr in descendants)
  return merge_dicts([group, {} if overlaps else descendants])
def _get_output_groups(outs:List[LazyBuffer]) -> \
  Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]],  # these are the output groups
        Dict[Buffer, UOp],                          # this is a map of realized Buffers to UOps.BUFFER
        Dict[LazyBuffer, LazyBuffer]]:              # these are the buffers we ASSIGN to in this schedule
  """
  find all the realizes in the graph, group the output LazyBuffers into kernels.

  The steps, in order: walk the graph from `outs`, realize unsafe pads, pair each unrealized reduce
  with the buffers that will store it, apply the optional FUSE_CONV_BW / FUSE_ARANGE un-realizes,
  and finally build the output groups and one UOp per buffer in `realizes`.
  """
  # start by just realizing the buffers passed in
  realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Dict[LazyBuffer, None] = {}
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
  double_reduces: Dict[LazyBuffer, None] = {}
  for out in outs:
    _recurse_lb(out.base, realizes, allbufs, simple_pads, children, assign_targets, double_reduces, scheduled=True)

  # check if we have to realize pads
  # (a fresh cache per pad: realizes grows inside this loop, so earlier verdicts are not reused)
  for pad in simple_pads:
    if _is_padding_okay(pad, realizes, {}): continue
    realizes[pad] = None

  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  reduce_of_const: List[LazyBuffer] = []
  for r in allbufs:
    if r.op not in ReduceOps or r in realizes: continue

    group: Dict[LazyBuffer, None] = {}
    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
      group = _get_isolated_children(r, reduce_for_op, children, realizes, group)
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        p = parents.pop().base
        if p.realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in group:
            forced_realize = True
            can_chase = False
          continue
        parents.extend(p.srcs)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr]))
          st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
          if len(st_childs) > 1: break
          child_st = st_childs[0].st
          if st.size != child_st.size: break
          st = st + child_st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
          tr = tr.srcs[0].base
      reduce_for_op[tr] = r
      realizes[tr] = None
    else:
      reduce_for_op.update((tr, r) for tr in group)
    if FUSE_ARANGE and r.op is ReduceOps.SUM and r.srcs[0].base.op is MetaOps.CONST: reduce_of_const.append(r)

  # fuse double reduces with no other child
  if FUSE_CONV_BW:
    for reduceop in double_reduces:
      top_reduce = reduceop.base.srcs[0].base
      if len(children[top_reduce]) == 1: del realizes[top_reduce]

  for r in reduce_of_const:
    group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
    DEBUG_ARANGE = getenv("DEBUG_ARANGE")
    if DEBUG_ARANGE: print(f"checking {r} {group=}")
    if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
    kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.VIEW}}
    if not kernel_children: continue
    if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
    for tr in group: del realizes[tr]

  output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  buf_uops: Dict[Buffer, UOp] = {}
  for buf in realizes:
    if buf.realized is None and buf.op is not MetaOps.CONST:
      # with MULTIOUTPUT, buffers that store the same reduce share one group keyed by that reduce
      group_key = reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf
      output_groups[group_key].append(buf)

      # make things that can't be images not images
      if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                                not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
        if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
        buf.dtype = dtypes.float32
        # hack the underlying buffer too
        if buf.base is buf:
          assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
          buf.buffer.dtype = dtypes.float32
          buf.buffer.options = None
    if buf.op is MetaOps.CONST:
      const = UOp.const(buf.dtype.scalar(), buf.arg)
      uop = UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(const, const.const_like(0))
    else:
      # NOTE: UOps.BUFFER creation must come after the ImageDType fixup
      uop = UOp(UOps.BUFFER, buf.buffer.dtype.ptr(), (), (len(buf_uops), (buf.buffer.device, buf.buffer.size, buf.buffer.dtype)))
    buf_uops.setdefault(buf.buffer, uop)
  return output_groups, buf_uops, assign_targets
# (graph, in_degree) pairs appended by _graph_schedule when SAVE_SCHEDULE is set; pickled to disk at exit
SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
def _graph_schedule(outs:List[LazyBuffer]) -> \
  Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]],  # this is the graph
        DefaultDict[LBScheduleItem, int],                   # this is the in-degree of the graph
        Dict[Variable, int]]:                               # this has all the var values of the schedule
  """
  create a graph for realizing the outputs

  Every output group is lowered to an LBScheduleItem, then edges are added so that an item runs
  after the items producing its inputs and before any item that assigns to one of its inputs.
  """
  output_groups, buf_uops, assign_targets = _get_output_groups(outs)
  # preschedule all buffers in realizes
  prescheduled: List[LBScheduleItem] = []
  var_vals: Dict[Variable, int] = {}
  for group in output_groups.values():
    lsi, group_vars = _lower_lazybuffer(group, buf_uops)
    prescheduled.append(lsi)
    var_vals = merge_dicts([var_vals, group_vars])
  # which item produces each output LazyBuffer
  schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}

  graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
  in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
  for lsi in prescheduled:
    # make sure items with no incoming edge still show up with degree 0
    in_degree.setdefault(lsi, 0)
    # realize outputs after all parents are realized
    for producer in dedup(schedule_targets[x] for x in lsi.inputs if x in schedule_targets):
      graph[producer].append(lsi)
      in_degree[lsi] += 1
    # realize outputs before a parent is assigned to
    for assign in dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets):
      graph[lsi].append(assign)
      in_degree[assign] += 1

  if SAVE_SCHEDULE:
    def _save():
      fp = getenv("SAVE_SCHEDULE_PATH", "schedule.pkl")
      print(f"saving {len(SCHEDULES)} schedule graphs to", fp)
      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
    # register the exit hook only once, on the first saved schedule
    if not SCHEDULES: atexit.register(_save)
    SCHEDULES.append((graph, in_degree))
  return graph, in_degree, var_vals
# *** DAG ordering: breadth first search ***
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
  """
  Order the schedule graph for `outs` breadth first and return the ScheduleItems with the var values.

  Scheduling an item deletes `srcs` from its output LazyBuffers. Raises RuntimeError when an item
  reaches its device's BUF_LIMIT, or when not every prescheduled item could be ordered.
  """
  graph, in_degree, var_vals = _graph_schedule(outs)
  ready = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
  schedule: List[ScheduleItem] = []
  kernel_number = GlobalCounters.kernel_count
  while ready:
    lsi = ready.popleft()
    if GRAPH:
      kernel_number += 1
      from tinygrad.engine.graph import realized_lazybuffer
      for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
    for out in lsi.outputs: del out.srcs  # can only schedule once
    # zero-sized LazyBuffers contribute no Buffer to the item
    si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata)
    schedule.append(si)
    device = si.outputs[0].device
    limit = BUF_LIMIT.get(device)
    if limit and len(si.bufs) >= limit:
      if DEBUG >= 3: print(si)
      raise RuntimeError(f"Kernel for {si.metadata} exceeded the {limit} buffer count limit for {device} with {len(si.bufs)} buffers.")
    for child in graph[lsi]:
      in_degree[child] -= 1
      if in_degree[child] == 0: ready.append(child)
  # confirm everything was scheduled correctly
  leftover = any(degree != 0 for degree in in_degree.values())
  if leftover or len(in_degree) != len(schedule):
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(in_degree)} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, var_vals
def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
  """Return only the schedule for `outs`; asserts that the schedule carries no variable values."""
  result = create_schedule_with_vars(outs)
  schedule, var_vals = result
  assert len(var_vals) == 0
  return schedule