tinygrad/tinygrad/realize.py
chenyu b1d9e54ea3
regenerate kernel ast dataset (#2968)
added back the log ast function and removed hacks that work around the old dataset
2024-01-01 20:26:17 -05:00

49 lines
2.5 KiB
Python

from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import prod, colored, getenv
from tinygrad.shape.symbolic import Variable
# *** schedule running ***
class CustomOp(JITRunner):
  """Runner that executes an arbitrary Python callable on the raw buffers of a schedule item."""

  def __init__(self, fxn):
    """Remember `fxn`, the callable to invoke, then run the base-class initializer."""
    self.fxn = fxn
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    """Call the wrapped function with `rawbufs` unpacked as positional arguments.

    `var_vals`, `wait` and `jit` are accepted but not used here; nothing is returned.
    """
    self.fxn(*rawbufs)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  """Choose the runner that will execute the schedule item `si`.

  Returns None for LoadOps.EMPTY, BufferCopy for LoadOps.COPY, a CustomOp wrapping `si.ast.arg` for LoadOps.CUSTOM,
  and otherwise the runner that the output's device provides for the AST.
  """
  op = si.ast.op
  # only a COPY is allowed to have inputs living on a device other than the output's
  assert op is LoadOps.COPY or all(si.out.device == x.device for x in si.inputs), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501
  if op is LoadOps.EMPTY:
    runner = None
  elif op is LoadOps.COPY:
    runner = BufferCopy
  elif op is LoadOps.CUSTOM:
    runner = CustomOp(si.ast.arg)
  else:
    runner = Device[si.out.device].get_runner(si.ast)
  return runner
# optional AST log: when getenv("LOGOPS", "") yields a non-empty value it is used as a file path and opened once,
# in append mode, at import time; otherwise logging stays disabled (logops is None)
if getenv("LOGOPS", ""):
  logops = open(getenv("LOGOPS", ""), "a")
else:
  logops = None
def run_schedule(schedule:List[ScheduleItem]):
  """Execute every item of `schedule` in order, consuming the list.

  Items are popped from the front one at a time, so the caller's list is empty when this returns. For each item the
  output buffer is chosen (reused or freshly allocated), the lowered program is run on it (or an "empty" stat is
  recorded when there is no program), and the output is reported through realized_lazybuffer.
  """
  while schedule:
    item = schedule.pop(0)
    ast, out = item.ast, item.out

    # when logging is enabled, append the AST of every item whose op is not a LoadOps
    if logops and ast.op not in LoadOps: logops.write(str(ast)+"\n")

    # lower the item to something runnable (a falsy result means there is nothing to execute)
    prg = lower_schedule_item(item)

    # drop the preassigned output buffer if an input backed by that same buffer is loaded through a non-contiguous view
    # (inputs are matched against LOAD args numbered from 1; the output comes first in the list handed to exec below)
    if out.output_buffer is not None:
      for buf_idx, inp in enumerate(item.inputs, start=1):
        if inp.realized == out.output_buffer and \
           not all(ld.arg.st.contiguous for ld in ast.lazyops if ld.op == BufferOps.LOAD and ld.arg.idx == buf_idx):
          out.output_buffer = None
          break

    # reuse the surviving output buffer; otherwise allocate one, sizing non-int (symbolic) dims by their max
    if out.output_buffer is not None:
      out.realized = out.output_buffer
    else:
      numel = prod(dim if isinstance(dim, int) else dim.max for dim in out.shape)
      out.realized = Buffer(out.device, numel, out.dtype)
    del out.srcs

    # every input must already have a buffer before the program can run
    missing = [x for x in item.inputs if x.realized is None]
    assert not missing, f"can't run, some inputs aren't realized {missing}"

    # run the function (put it in JIT), output buffer first, then the inputs in order
    if prg:
      prg.exec([out.realized, *(cast(Buffer, x.realized) for x in item.inputs)], item.var_vals)
    else:
      update_stats(colored(f"empty {out.st.size():10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
    realized_lazybuffer(out, GlobalCounters.kernel_count)