# mirror of https://github.com/tinygrad/tinygrad.git
# synced 2026-06-24 02:14:17 +00:00
# 49 lines, 2.5 KiB, Python
from typing import List, Dict, Optional, cast
|
|
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
|
|
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
|
|
from tinygrad.graph import print_tree, realized_lazybuffer
|
|
from tinygrad.helpers import prod, colored, getenv
|
|
from tinygrad.shape.symbolic import Variable
|
|
|
|
# *** schedule running ***
|
|
|
|
class CustomOp(JITRunner):
  """A JITRunner subclass whose call simply invokes a stored callable on the raw buffers."""

  def __init__(self, fxn):
    # keep the callable first, then run the base-class initializer
    self.fxn = fxn
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    # only the buffers are forwarded, unpacked as positional arguments;
    # var_vals, wait and jit are accepted but not used here
    self.fxn(*rawbufs)
|
|
|
|
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  """Choose the runner for a single schedule item.

  Returns None for LoadOps.EMPTY, BufferCopy for LoadOps.COPY, a CustomOp built from
  si.ast.arg for LoadOps.CUSTOM, and otherwise the result of get_runner(si.ast) on the
  Device entry for the output's device.

  Asserts that every input shares the output's device unless the item is a COPY.
  """
  op = si.ast.op
  # the message is built only when the check fails, so print_tree runs only on failure
  assert all(si.out.device == x.device for x in si.inputs) or op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501

  if op is LoadOps.EMPTY:
    return None
  if op is LoadOps.COPY:
    return BufferCopy
  if op is LoadOps.CUSTOM:
    return CustomOp(si.ast.arg)
  return Device[si.out.device].get_runner(si.ast)
|
|
|
|
# optional op log: stays None unless getenv("LOGOPS", "") is non-empty, in which case
# that value is used as a file path opened in append mode for the life of the module
logops = None
if getenv("LOGOPS", ""): logops = open(getenv("LOGOPS", ""), "a")
|
|
def run_schedule(schedule:List[ScheduleItem]):
  """Run every item of `schedule` in order, popping each one off the front of the list.

  The passed-in list is consumed: it is empty when this function returns. For each item the
  AST is optionally logged, a runner is obtained from lower_schedule_item, the output buffer
  is reused or allocated, the runner is executed (or stats are recorded when there is no
  runner), and the realized output is reported through realized_lazybuffer.
  """
  while schedule:
    si = schedule.pop(0)
    # only ASTs whose op is not a LoadOps member are written to the op log
    if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")

    # get the program
    prg = lower_schedule_item(si)

    # invalidate the output buffer if there's a non contig usage of it in inputs
    if si.out.output_buffer is not None:
      # LOAD ops are matched by arg.idx, which is the input's position plus one
      for load_idx, inp in enumerate(si.inputs, start=1):
        if inp.realized == si.out.output_buffer and any(
            not lop.arg.st.contiguous for lop in si.ast.lazyops if lop.op == BufferOps.LOAD and lop.arg.idx == load_idx):
          si.out.output_buffer = None
          break

    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
    if si.out.output_buffer is not None:
      si.out.realized = si.out.output_buffer
    else:
      # non-int dimensions contribute their .max to the element count
      max_numel = prod((s if isinstance(s, int) else s.max for s in si.out.shape))
      si.out.realized = Buffer(si.out.device, max_numel, si.out.dtype)
    del si.out.srcs

    # run the function (put it in JIT)
    assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
    if prg:
      # output buffer first, followed by the input buffers in order
      rawbufs = [si.out.realized, *(cast(Buffer, x.realized) for x in si.inputs)]
      prg.exec(rawbufs, si.var_vals)
    else:
      update_stats(colored(f"empty {si.out.st.size():10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
    realized_lazybuffer(si.out, GlobalCounters.kernel_count)
|