mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8395071f77 | ||
|
|
de3e901b71 | ||
|
|
ae2410e10e |
3 changed files with 206 additions and 40 deletions
120
test/unit/test_callify.py
Normal file
120
test/unit/test_callify.py
Normal file
|
|
@ -0,0 +1,120 @@
|
||||||
|
import unittest
|
||||||
|
from tinygrad import Tensor, dtypes
|
||||||
|
|
||||||
|
class TestCallify(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
out = a + b
|
||||||
|
out.callify()
|
||||||
|
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
|
||||||
|
def test_const(self):
|
||||||
|
out = Tensor(2.0) + Tensor(3.0)
|
||||||
|
out.callify()
|
||||||
|
self.assertEqual(out.item(), 5.0)
|
||||||
|
|
||||||
|
def test_sum(self):
|
||||||
|
out = Tensor.ones(16).contiguous().sum()
|
||||||
|
out.callify()
|
||||||
|
self.assertEqual(out.item(), 16.0)
|
||||||
|
|
||||||
|
def test_multi_output(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
c = a + b
|
||||||
|
d = a * b
|
||||||
|
c.callify(d)
|
||||||
|
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
|
||||||
|
|
||||||
|
def test_two_callify_independent(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
c = a + b
|
||||||
|
c.callify()
|
||||||
|
|
||||||
|
d = Tensor([10.,20,30])
|
||||||
|
e = Tensor([1.,1,1])
|
||||||
|
f = d - e
|
||||||
|
f.callify()
|
||||||
|
|
||||||
|
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
self.assertListEqual(f.tolist(), [9.0, 19.0, 29.0])
|
||||||
|
|
||||||
|
def test_two_callify_shared_input(self):
|
||||||
|
a = Tensor([1.,2,3]).contiguous().realize()
|
||||||
|
b = a + 1
|
||||||
|
b.callify()
|
||||||
|
c = a * 2
|
||||||
|
c.callify()
|
||||||
|
self.assertListEqual(b.tolist(), [2.0, 3.0, 4.0])
|
||||||
|
self.assertListEqual(c.tolist(), [2.0, 4.0, 6.0])
|
||||||
|
|
||||||
|
def test_chained_callify(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = a + 1
|
||||||
|
b.callify()
|
||||||
|
b.realize()
|
||||||
|
c = b + 1
|
||||||
|
c.callify()
|
||||||
|
self.assertListEqual(c.tolist(), [3.0, 4.0, 5.0])
|
||||||
|
|
||||||
|
def test_gemm(self):
|
||||||
|
a = Tensor.ones(8, 8).contiguous()
|
||||||
|
b = Tensor.eye(8).contiguous()
|
||||||
|
out = a @ b
|
||||||
|
out.callify()
|
||||||
|
lst = out.tolist()
|
||||||
|
for y in range(8):
|
||||||
|
for x in range(8):
|
||||||
|
self.assertEqual(lst[y][x], 1.0)
|
||||||
|
|
||||||
|
def test_int_dtype(self):
|
||||||
|
a = Tensor([1,2,3], dtype=dtypes.int)
|
||||||
|
b = Tensor([4,5,6], dtype=dtypes.int)
|
||||||
|
out = a + b
|
||||||
|
out.callify()
|
||||||
|
self.assertListEqual(out.tolist(), [5, 7, 9])
|
||||||
|
|
||||||
|
def test_callify_then_schedule(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
out = a + b
|
||||||
|
out.callify()
|
||||||
|
schedule = out.schedule()
|
||||||
|
self.assertGreater(len(schedule), 0)
|
||||||
|
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
|
||||||
|
def test_reduce(self):
|
||||||
|
out = Tensor([1.,2,3,4]).sum()
|
||||||
|
out.callify()
|
||||||
|
self.assertEqual(out.item(), 10.0)
|
||||||
|
|
||||||
|
def test_multiple_ops(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
out = (a + b) * (a - b)
|
||||||
|
out.callify()
|
||||||
|
self.assertListEqual(out.tolist(), [-15.0, -21.0, -27.0])
|
||||||
|
|
||||||
|
def test_double_callify(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
out = a + b
|
||||||
|
out.callify()
|
||||||
|
out.callify()
|
||||||
|
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
|
||||||
|
def test_double_callify_multi_output(self):
|
||||||
|
a = Tensor([1.,2,3])
|
||||||
|
b = Tensor([4.,5,6])
|
||||||
|
c = a + b
|
||||||
|
d = a * b
|
||||||
|
c.callify(d)
|
||||||
|
c.callify(d)
|
||||||
|
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
|
||||||
|
self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import time, inspect
|
import time, inspect
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
|
from tinygrad.uop.ops import UOp, Ops, KernelInfo, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
|
||||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||||
from tinygrad.device import Buffer, MultiBuffer
|
from tinygrad.device import Buffer, MultiBuffer
|
||||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
||||||
from tinygrad.engine.realize import ExecItem
|
from tinygrad.engine.realize import ExecItem
|
||||||
from tinygrad.engine.allocations import transform_to_call
|
|
||||||
|
|
||||||
# **** schedule linearizer
|
# **** schedule linearizer
|
||||||
|
|
||||||
|
|
@ -68,43 +67,48 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
pm_post_sched_cache = PatternMatcher([
|
pm_post_sched_cache = PatternMatcher([
|
||||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
# tag=True prevents re-matching after replacement (needed when PARAMs replace with PARAMs in nested callify)
|
||||||
|
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg].replace(tag=True) if x.tag is None else None),
|
||||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||||
])
|
])
|
||||||
|
|
||||||
schedule_cache: dict[bytes, UOp] = {}
|
schedule_cache: dict[bytes, UOp] = {}
|
||||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
|
||||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
|
||||||
# big_sink srcs are all the Tensors
|
|
||||||
st = time.perf_counter()
|
|
||||||
big_sink, buffer_map = transform_to_call(big_sink)
|
|
||||||
function = big_sink.src[0]
|
|
||||||
|
|
||||||
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
|
def _resolve_params(linear:UOp, params:tuple[UOp, ...]) -> UOp:
|
||||||
if SPEC: type_verify(big_sink, tensor_spec)
|
"""Replace PARAMs in a LINEAR with the given params (BUFFERs or outer PARAMs), also handling LUNIQUE BUFFERs."""
|
||||||
|
from tinygrad.uop.ops import _remove_all_tags
|
||||||
|
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, params), name="params to buffers")
|
||||||
|
return graph_rewrite(linear, _remove_all_tags, name="remove tags")
|
||||||
|
|
||||||
|
def rewrite_call_to_linear(ctx:list, call:UOp) -> UOp|None:
|
||||||
|
"""Rewrite rule: CALL(SINK, *params) -> LINEAR(...) with caching. Only matches top-level CALLs from transform_to_call."""
|
||||||
|
function = call.src[0]
|
||||||
|
if function.op is not Ops.SINK or isinstance(function.arg, KernelInfo): return None
|
||||||
|
# recursively schedule any nested CALLs inside the function (from nested callify)
|
||||||
|
inner_start = len(ctx)
|
||||||
|
function = graph_rewrite(function, pm_schedule, ctx=ctx, name="schedule nested calls")
|
||||||
|
if not SCACHE or (linear:=schedule_cache.get(function.key, None)) is None:
|
||||||
|
if SPEC: type_verify(call.replace(src=(function,)+call.src[1:]), tensor_spec)
|
||||||
linear = create_schedule(get_kernel_graph(function))
|
linear = create_schedule(get_kernel_graph(function))
|
||||||
if SCACHE: schedule_cache[function.key] = linear
|
if SCACHE: schedule_cache[function.key] = linear
|
||||||
else:
|
# late apply params to buffers (tag=True prevents PARAM->PARAM cycles in nested callify)
|
||||||
# schedule cache hit
|
linear = _resolve_params(linear, call.src[1:])
|
||||||
linear = sc_ret
|
# resolve remaining PARAMs in inner LINEARs from nested CALLs using this call's params
|
||||||
|
for i in range(inner_start, len(ctx)):
|
||||||
|
inner_call, inner_linear = ctx[i]
|
||||||
|
ctx[i] = (inner_call, _resolve_params(inner_linear, call.src[1:]))
|
||||||
|
ctx.append((call, linear))
|
||||||
|
return linear
|
||||||
|
|
||||||
# it's a call that we late apply
|
pm_schedule = PatternMatcher([
|
||||||
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
(UPat(Ops.CALL, name="call"), rewrite_call_to_linear),
|
||||||
|
# strip AFTER(buf, LINEAR) -> buf after scheduling
|
||||||
|
(UPat(Ops.AFTER, src=(UPat(name="buf"), UPat(Ops.LINEAR))), lambda ctx,buf: buf),
|
||||||
|
])
|
||||||
|
|
||||||
# vars used in the schedule
|
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
||||||
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
|
"""Convert a LINEAR UOp to a list of ExecItems."""
|
||||||
# get var_vals
|
|
||||||
var_vals: dict[str, int] = {}
|
|
||||||
for b in big_sink.src[1:]:
|
|
||||||
if b.op is Ops.BIND:
|
|
||||||
nm = b.src[0].expr
|
|
||||||
if nm not in used_vars: continue
|
|
||||||
val = b.src[1].arg
|
|
||||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
|
||||||
var_vals[nm] = val
|
|
||||||
|
|
||||||
# convert LINEAR to ExecItems
|
|
||||||
schedule: list[ExecItem] = []
|
schedule: list[ExecItem] = []
|
||||||
for si in linear.src:
|
for si in linear.src:
|
||||||
ast, buf_uops = si.src[0], si.src[1:]
|
ast, buf_uops = si.src[0], si.src[1:]
|
||||||
|
|
@ -122,6 +126,34 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||||
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
|
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
|
||||||
else:
|
else:
|
||||||
schedule.append(ExecItem(ast, list(ubufs), metadata))
|
schedule.append(ExecItem(ast, list(ubufs), metadata))
|
||||||
|
return schedule
|
||||||
|
|
||||||
|
# strip AFTER(buf, LINEAR) -> buf, used by _apply_map_to_tensors to clean up scope tensors after scheduling
|
||||||
|
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||||
|
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[UOp], list[ExecItem], dict[str, int]]:
|
||||||
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
# rewrite CALLs to LINEARs and strip AFTERs
|
||||||
|
call_linear_pairs: list[tuple[UOp, UOp]] = []
|
||||||
|
graph_rewrite(big_sink, pm_schedule, ctx=call_linear_pairs, name="schedule calls")
|
||||||
|
|
||||||
|
# collect ExecItems from all LINEARs
|
||||||
|
schedule: list[ExecItem] = []
|
||||||
|
for _, linear in call_linear_pairs:
|
||||||
|
schedule.extend(linear_to_schedule(linear))
|
||||||
|
|
||||||
|
# get var_vals from CALL params
|
||||||
|
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for _, linear in call_linear_pairs for si in linear.src])
|
||||||
|
var_vals: dict[str, int] = {}
|
||||||
|
for call, _ in call_linear_pairs:
|
||||||
|
for b in call.src[1:]:
|
||||||
|
if b.op is Ops.BIND:
|
||||||
|
nm = b.src[0].expr
|
||||||
|
if nm not in used_vars: continue
|
||||||
|
val = b.src[1].arg
|
||||||
|
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||||
|
var_vals[nm] = val
|
||||||
|
|
||||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||||
|
|
||||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||||
|
|
@ -131,7 +163,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||||
else:
|
else:
|
||||||
frm = None
|
frm = None
|
||||||
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
|
|
||||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||||
|
|
||||||
return buffer_map, schedule, var_vals
|
return [call for call, _ in call_linear_pairs], schedule, var_vals
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,11 @@ from tinygrad.gradient import compute_gradient
|
||||||
from tinygrad.mixin import OpMixin
|
from tinygrad.mixin import OpMixin
|
||||||
from tinygrad.mixin.movement import _align_left
|
from tinygrad.mixin.movement import _align_left
|
||||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
|
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
|
||||||
|
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||||
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
|
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
|
||||||
from tinygrad.device import Device, Buffer
|
from tinygrad.device import Device, Buffer
|
||||||
from tinygrad.engine.realize import run_schedule
|
from tinygrad.engine.realize import run_schedule
|
||||||
|
from tinygrad.engine.allocations import transform_to_call
|
||||||
|
|
||||||
# TODO: this should be the only usage of Device
|
# TODO: this should be the only usage of Device
|
||||||
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||||
|
|
@ -25,7 +27,8 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||||
|
|
||||||
all_tensors: dict[weakref.ref[Tensor], None] = {}
|
all_tensors: dict[weakref.ref[Tensor], None] = {}
|
||||||
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
|
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
|
||||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
_pm_strip_after_noop = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(name="buf"), UPat(Ops.NOOP))), lambda ctx,buf: buf)])
|
||||||
|
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str, extra_pm:PatternMatcher|None=None) -> None:
|
||||||
with cpu_profile(TracingKey(name), "TINY"):
|
with cpu_profile(TracingKey(name), "TINY"):
|
||||||
# get tensors in scope
|
# get tensors in scope
|
||||||
in_scope: dict[UOp, bool] = {}
|
in_scope: dict[UOp, bool] = {}
|
||||||
|
|
@ -34,7 +37,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||||
|
|
||||||
# get all Tensors and apply the map
|
# get all Tensors and apply the map
|
||||||
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
||||||
new_sink = sink.substitute(applied_map, name=f"substitute {name}")
|
new_sink = sink.substitute(applied_map, name=f"substitute {name}", extra_pm=extra_pm)
|
||||||
|
|
||||||
# set the relevant uop to the realized UOps
|
# set the relevant uop to the realized UOps
|
||||||
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
|
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
|
||||||
|
|
@ -249,17 +252,25 @@ class Tensor(OpMixin):
|
||||||
"""
|
"""
|
||||||
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||||
|
|
||||||
|
def callify(self, *lst:Tensor) -> Tensor:
|
||||||
|
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||||
|
big_sink, buffer_map = transform_to_call(big_sink)
|
||||||
|
_apply_map_to_tensors({x:y.after(big_sink) for x,y in buffer_map.items()}, name="callify")
|
||||||
|
return self
|
||||||
|
|
||||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||||
"""
|
"""
|
||||||
Creates the schedule needed to realize these Tensor(s), with Variables.
|
Creates the schedule needed to realize these Tensor(s), with Variables.
|
||||||
|
|
||||||
NOTE: A Tensor can only be scheduled once.
|
NOTE: A Tensor can only be scheduled once.
|
||||||
"""
|
"""
|
||||||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
# collect existing CALLs before callify (so we can clean them up in other tensors that share them)
|
||||||
|
pre_calls = {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.CALL}
|
||||||
# this is where the schedule cache should go
|
self.callify(*lst)
|
||||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
calls, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(*[x.uop for x in (self,)+lst]))
|
||||||
_apply_map_to_tensors(becomes_map, name="buffers")
|
# replace scheduled CALLs with NOOP so AFTER(buf, CALL) -> AFTER(buf, NOOP) -> buf in scope tensors
|
||||||
|
# include pre-existing CALLs too (they were reconstructed inside callify, but other tensors still reference the originals)
|
||||||
|
_apply_map_to_tensors({c:UOp(Ops.NOOP) for c in set(calls) | pre_calls}, name="buffers", extra_pm=_pm_strip_after_noop)
|
||||||
return schedule, var_vals
|
return schedule, var_vals
|
||||||
|
|
||||||
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
||||||
|
|
@ -278,8 +289,12 @@ class Tensor(OpMixin):
|
||||||
# recursively realize pending assigns that this assign's value depends on
|
# recursively realize pending assigns that this assign's value depends on
|
||||||
for u in assign_uop.toposort():
|
for u in assign_uop.toposort():
|
||||||
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
|
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
|
||||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
|
sink = UOp.sink(assign_uop)
|
||||||
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
|
call, buffer_map = transform_to_call(sink)
|
||||||
|
callified_sink = UOp.sink(*[buffer_map.get(s, s).after(call) for s in sink.src])
|
||||||
|
calls, schedule, var_vals = complete_create_schedule_with_vars(callified_sink)
|
||||||
|
becomes_map = {**buffer_map, **{c:UOp(Ops.NOOP) for c in calls}}
|
||||||
|
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign", extra_pm=_pm_strip_after_noop)
|
||||||
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
||||||
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
|
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
|
||||||
if becomes_map:
|
if becomes_map:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue