jit: captures linears, not execitems (#15399)

* jit: captures linears, not execitems

* x

* um

* etsts

* mockcuda
This commit is contained in:
nimlgen 2026-03-21 16:32:12 +08:00 committed by GitHub
commit 9656d97d97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 35 additions and 31 deletions

View file

@ -618,7 +618,7 @@ class TestJitFree(unittest.TestCase):
expected_savings = (len(inp) * inp.dtype.itemsize * 2) + dtypes.float32.itemsize # (t1 and t2) + out
self.assertEqual(savings_after_free, expected_savings)
self.assertGreaterEqual(savings_after_free, expected_savings)
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 136)
@ -628,7 +628,7 @@ class TestJitFree(unittest.TestCase):
fxn.captured.free_intermediates() # 2nd time to validate
savings_after_free = pre_free - GlobalCounters.mem_used
self.assertEqual(savings_after_free, expected_savings)
self.assertGreaterEqual(savings_after_free, expected_savings)
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 136)

View file

@ -6,14 +6,18 @@ import torch
from tinygrad import GlobalCounters, Tensor, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import capturing
from tinygrad.engine.realize import capturing, run_schedule
from tinygrad.engine.schedule import linear_to_schedule
from tinygrad.tensor import _to_np_dtype
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
self.count = 0
def add(self, ei): self.count += 1
def add_linear(self, linear, var_vals):
schedule = linear_to_schedule(linear)
self.count += len(schedule)
run_schedule(schedule, var_vals)
def __enter__(self):
if self.preclear:
gc.collect()

View file

@ -154,7 +154,7 @@ def cuMemHostAlloc(pp, bytesize: int, flags: int) -> int:
def cuMemFreeHost(p: ctypes.c_void_p) -> int: return cuMemFree_v2(p)
def cuMemcpyDtoDAsync_v2(dst, src, bytesize: int, stream: Any) -> int:
ctypes.memmove(dst.value, src.value, bytesize)
ctypes.memmove(dst if isinstance(dst, int) else dst.value, src if isinstance(src, int) else src.value, bytesize)
return orig_cuda.CUDA_SUCCESS
def cuFuncSetAttribute(hfunc, attrib: int, value: int) -> int:

View file

@ -2,6 +2,7 @@ import unittest
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing
from tinygrad.engine.schedule import linear_to_schedule
from tinygrad.helpers import Context
@unittest.skip("tensor metadata is no longer supported")
@ -94,10 +95,11 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(si.metadata, ())
def _has_metadata(self, h, name):
items = []
capturing.append(type("", (), {"add": lambda _, ei: items.append(ei)})())
linears = []
capturing.append(type("", (), {"add_linear": lambda _, linear, var_vals: linears.append(linear)})())
try: h.realize()
finally: capturing.clear()
items = [ei for linear in linears for ei in linear_to_schedule(linear)]
return any(m.name == name for ei in items for m in ei.metadata)
def test_metadata_survives_realize_pending_assign(self):

View file

@ -7,10 +7,10 @@ from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.engine.schedule import linear_to_schedule
from tinygrad.nn.state import get_parameters
from tinygrad.schedule.rangeify import mop_cleanup
from dataclasses import dataclass, replace
from weakref import WeakKeyDictionary
class GraphException(Exception): pass
class JitError(Exception): pass
@ -280,17 +280,7 @@ class TinyJit(Generic[ReturnType]):
self.prune = prune
self.optimize = optimize
def add_buffer(self, b:Buffer) -> Buffer:
if found:=self._buffer_replace.get(b, None): return found
if b.is_allocated() or b.uop_refcount > 0: return b
if b._base is not None:
self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
else:
self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
return ret
def add(self, ei:ExecItem):
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear)
def reset(self):
assert self.fxn is not None, "can't reset without function"
@ -321,20 +311,20 @@ class TinyJit(Generic[ReturnType]):
# jit capture
assert self.fxn is not None
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
self._jit_cache: list[ExecItem] = []
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
# TODO: should we always disable the memory planner here? it must be off for prune
with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
self._linears: list[UOp] = []
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
jit_cache = self._jit_cache
del self._buffer_replace, self._jit_cache
if not len(jit_cache): raise JitError("didn't JIT anything!")
if not len(self._linears): raise JitError("didn't JIT anything!")
_check_no_non_tensor_return(ret)
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
# combine all captured linears into one and convert to ExecItems
jit_cache = [ei.lower() for ei in linear_to_schedule(UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))))]
del self._linears
# track inputs that are views of buffers
# TODO: eventually expected_buffers should live in ExecItem
@ -367,7 +357,9 @@ class TinyJit(Generic[ReturnType]):
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
# set this for next run
# exec
for ei in jit_cache: ei.run(var_vals)
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
if self.optimize: self.captured.replan_buffers_memory_layout()
elif self.cnt >= 2:

View file

@ -1,7 +1,7 @@
from typing import cast, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
from tinygrad.helpers import EMULATED_DTYPES
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
@ -187,12 +187,11 @@ class ExecItem:
# **************** main run function ****************
capturing: list = [] # put classes with an add method in here
capturing: list = [] # put classes with an add_linear method in here
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
while len(schedule):
ei = schedule.pop(0).lower()
if len(capturing) and CAPTURING: capturing[0].add(ei)
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
# copy in allocated buffers from the GPU
bufs = [b for b in ei.bufs if b is not None]

View file

@ -84,7 +84,9 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
return schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.realize import capturing
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.helpers import CAPTURING
from tinygrad.uop.ops import PatternMatcher, UPat
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
@ -156,6 +158,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di
if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
var_vals[nm] = val
# jit captures this schedule, no need to execute.
if len(capturing) and CAPTURING:
capturing[0].add_linear(linear, var_vals)
return [], var_vals
# convert LINEAR to ExecItems
schedule: list[ExecItem] = linear_to_schedule(linear)
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)