mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
jit: captures linears, not execitems (#15399)
* jit: captures linears, not execitems * x * um * etsts * mockcuda
This commit is contained in:
parent
c13d9d29ff
commit
9656d97d97
7 changed files with 35 additions and 31 deletions
|
|
@ -618,7 +618,7 @@ class TestJitFree(unittest.TestCase):
|
|||
|
||||
expected_savings = (len(inp) * inp.dtype.itemsize * 2) + dtypes.float32.itemsize # (t1 and t2) + out
|
||||
|
||||
self.assertEqual(savings_after_free, expected_savings)
|
||||
self.assertGreaterEqual(savings_after_free, expected_savings)
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 136)
|
||||
|
||||
|
|
@ -628,7 +628,7 @@ class TestJitFree(unittest.TestCase):
|
|||
fxn.captured.free_intermediates() # 2nd time to validate
|
||||
savings_after_free = pre_free - GlobalCounters.mem_used
|
||||
|
||||
self.assertEqual(savings_after_free, expected_savings)
|
||||
self.assertGreaterEqual(savings_after_free, expected_savings)
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 136)
|
||||
|
||||
|
|
|
|||
8
test/external/external_test_opt.py
vendored
8
test/external/external_test_opt.py
vendored
|
|
@ -6,14 +6,18 @@ import torch
|
|||
from tinygrad import GlobalCounters, Tensor, Device
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.engine.realize import capturing, run_schedule
|
||||
from tinygrad.engine.schedule import linear_to_schedule
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
|
||||
class CLCache:
|
||||
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):
|
||||
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
|
||||
self.count = 0
|
||||
def add(self, ei): self.count += 1
|
||||
def add_linear(self, linear, var_vals):
|
||||
schedule = linear_to_schedule(linear)
|
||||
self.count += len(schedule)
|
||||
run_schedule(schedule, var_vals)
|
||||
def __enter__(self):
|
||||
if self.preclear:
|
||||
gc.collect()
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ def cuMemHostAlloc(pp, bytesize: int, flags: int) -> int:
|
|||
def cuMemFreeHost(p: ctypes.c_void_p) -> int: return cuMemFree_v2(p)
|
||||
|
||||
def cuMemcpyDtoDAsync_v2(dst, src, bytesize: int, stream: Any) -> int:
|
||||
ctypes.memmove(dst.value, src.value, bytesize)
|
||||
ctypes.memmove(dst if isinstance(dst, int) else dst.value, src if isinstance(src, int) else src.value, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuFuncSetAttribute(hfunc, attrib: int, value: int) -> int:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import unittest
|
|||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.tensor import _METADATA
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.engine.schedule import linear_to_schedule
|
||||
from tinygrad.helpers import Context
|
||||
|
||||
@unittest.skip("tensor metadata is no longer supported")
|
||||
|
|
@ -94,10 +95,11 @@ class TestTensorMetadata(unittest.TestCase):
|
|||
self.assertEqual(si.metadata, ())
|
||||
|
||||
def _has_metadata(self, h, name):
|
||||
items = []
|
||||
capturing.append(type("", (), {"add": lambda _, ei: items.append(ei)})())
|
||||
linears = []
|
||||
capturing.append(type("", (), {"add_linear": lambda _, linear, var_vals: linears.append(linear)})())
|
||||
try: h.realize()
|
||||
finally: capturing.clear()
|
||||
items = [ei for linear in linears for ei in linear_to_schedule(linear)]
|
||||
return any(m.name == name for ei in items for m in ei.metadata)
|
||||
|
||||
def test_metadata_survives_realize_pending_assign(self):
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from tinygrad.dtype import DType
|
|||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.engine.schedule import linear_to_schedule
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.schedule.rangeify import mop_cleanup
|
||||
from dataclasses import dataclass, replace
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
class GraphException(Exception): pass
|
||||
class JitError(Exception): pass
|
||||
|
|
@ -280,17 +280,7 @@ class TinyJit(Generic[ReturnType]):
|
|||
self.prune = prune
|
||||
self.optimize = optimize
|
||||
|
||||
def add_buffer(self, b:Buffer) -> Buffer:
|
||||
if found:=self._buffer_replace.get(b, None): return found
|
||||
if b.is_allocated() or b.uop_refcount > 0: return b
|
||||
if b._base is not None:
|
||||
self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
|
||||
else:
|
||||
self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
|
||||
return ret
|
||||
|
||||
def add(self, ei:ExecItem):
|
||||
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
|
||||
def add_linear(self, linear:UOp, var_vals:dict[str, int]): self._linears.append(linear)
|
||||
|
||||
def reset(self):
|
||||
assert self.fxn is not None, "can't reset without function"
|
||||
|
|
@ -321,20 +311,20 @@ class TinyJit(Generic[ReturnType]):
|
|||
# jit capture
|
||||
assert self.fxn is not None
|
||||
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
|
||||
self._jit_cache: list[ExecItem] = []
|
||||
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
|
||||
# TODO: should we always disable the memory planner here? it must be off for prune
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
|
||||
self._linears: list[UOp] = []
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
jit_cache = self._jit_cache
|
||||
del self._buffer_replace, self._jit_cache
|
||||
if not len(jit_cache): raise JitError("didn't JIT anything!")
|
||||
if not len(self._linears): raise JitError("didn't JIT anything!")
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
|
||||
|
||||
# combine all captured linears into one and convert to ExecItems
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))))]
|
||||
del self._linears
|
||||
|
||||
# track inputs that are views of buffers
|
||||
# TODO: eventually expected_buffers should live in ExecItem
|
||||
|
|
@ -367,7 +357,9 @@ class TinyJit(Generic[ReturnType]):
|
|||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
# set this for next run
|
||||
# exec
|
||||
for ei in jit_cache: ei.run(var_vals)
|
||||
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
|
||||
if self.optimize: self.captured.replan_buffers_memory_layout()
|
||||
elif self.cnt >= 2:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import cast, Callable
|
||||
import time, pprint, random, itertools, math
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
|
||||
from tinygrad.helpers import EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
|
||||
|
|
@ -187,12 +187,11 @@ class ExecItem:
|
|||
|
||||
# **************** main run function ****************
|
||||
|
||||
capturing: list = [] # put classes with an add method in here
|
||||
capturing: list = [] # put classes with an add_linear method in here
|
||||
|
||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
while len(schedule):
|
||||
ei = schedule.pop(0).lower()
|
||||
if len(capturing) and CAPTURING: capturing[0].add(ei)
|
||||
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
|
||||
# copy in allocated buffers from the GPU
|
||||
bufs = [b for b in ei.bufs if b is not None]
|
||||
|
|
|
|||
|
|
@ -84,7 +84,9 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
|||
return schedule
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.schedule.rangeify import get_kernel_graph
|
||||
from tinygrad.helpers import CAPTURING
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||
|
||||
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
||||
|
|
@ -156,6 +158,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], di
|
|||
if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
|
||||
var_vals[nm] = val
|
||||
|
||||
# jit captures this schedule, no need to execute.
|
||||
if len(capturing) and CAPTURING:
|
||||
capturing[0].add_linear(linear, var_vals)
|
||||
return [], var_vals
|
||||
|
||||
# convert LINEAR to ExecItems
|
||||
schedule: list[ExecItem] = linear_to_schedule(linear)
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue