new memory scheduler with explicit refcounts (#4198)

* new memory scheduler with explict refcounts

* move central memory planner

* typo + use central memory planner in openpilot

* cleanups

* include lb_refcount in pickle

* replace PlaceHolder with memory planner

* cleaner
This commit is contained in:
George Hotz 2024-04-17 08:46:47 +04:00 committed by GitHub
commit 8564e28a1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 67 additions and 46 deletions

View file

@ -14,7 +14,7 @@ from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import LoadOps, ScheduleItem
Device.DEFAULT = "GPU"
@ -107,6 +107,7 @@ if __name__ == "__main__":
run_schedule(schedule_independent)
run_schedule(schedule_input)
schedule = memory_planner(schedule)
with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")

View file

@ -12,14 +12,16 @@ class BufferOptions:
nolru: bool = False
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
initial_value:Optional[bytes]=None, lb_refcount=0):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.options = device, size, dtype, options
self.device, self.size, self.dtype, self.options, self.lb_refcount = device, size, dtype, options, lb_refcount
if opaque is not None: self.allocate(opaque)
if initial_value is not None:
self.allocate()
self.copyin(memoryview(initial_value))
def is_allocated(self) -> bool: return hasattr(self, '_buf')
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
def allocate(self, opaque=None) -> Buffer:
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
@ -30,11 +32,11 @@ class Buffer:
return self
def __reduce__(self):
buf = None
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options)
if hasattr(self, '_buf'):
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
if self.is_allocated():
buf = bytearray(self.nbytes)
self.copyout(memoryview(buf))
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
@property
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
@ -52,10 +54,12 @@ class Buffer:
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyin to unallocated buffer"
self.allocator.copyin(self._buf, mv)
return self
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyout unallocated buffer"
self.allocator.copyout(mv, self._buf)
return mv

View file

@ -1,17 +1,16 @@
from __future__ import annotations
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
import functools, itertools, operator
from dataclasses import dataclass
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.engine.realize import ExecItem, capturing
from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planner
from tinygrad.nn.state import get_parameters
from weakref import ref, WeakKeyDictionary
from weakref import WeakKeyDictionary
# TODO: these graph functions probably shouldn't exist here
@ -71,44 +70,25 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
class PlaceHolder:
placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
def __init__(self, buf:Buffer):
self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options
def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options)
def __hash__(self): return hash(self.to_tuple())
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
@staticmethod
def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]:
if found:=PlaceHolder.placeholders.get(buf, None): return found
if hasattr(buf, '_buf'): return buf
PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here?
return ret
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate()
return buffer_cache[self]
@dataclass(frozen=True)
class WeakExecItem:
prg: Runner
rawbufs: List[Union[PlaceHolder, Buffer]]
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):
self.fxn = fxn
self.reset()
def add_buffer(self, b:Buffer) -> Buffer:
if found:=self.buffer_replace.get(b, None): return found
if b.is_allocated() or b.lb_refcount > 0: return b
self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
return ret
def add(self, ei:ExecItem):
self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None]))
self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None]))
def reset(self):
self._cc: List[WeakExecItem] = []
self.jit_cache: List[ExecItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
self.cnt: int = 0
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
@ -140,13 +120,14 @@ class TinyJit(Generic[ReturnType]):
self.ret = self.fxn(*args, **kwargs)
Tensor.corealize(get_parameters(self.ret))
capturing.clear()
assert len(self._cc), "didn't JIT anything!"
buffer_cache: Dict[PlaceHolder, Buffer] = {}
self.jit_cache = \
[ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc]
del self._cc
del self.buffer_replace
assert len(self.jit_cache), "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# memory planning (optional)
assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ")
self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache]
# Condense the items into a graph executor.
if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)

View file

@ -1,6 +1,8 @@
from typing import List, Dict, Optional, cast, Generator
from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, Iterable
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.helpers import colored, getenv
from tinygrad.dtype import DType
from tinygrad.helpers import colored, getenv, dedup, DEBUG
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
from tinygrad.buffer import Buffer
@ -41,6 +43,34 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non
capturing: List = [] # put classes with an add method in here
def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") -> Dict[Buffer, Buffer]:
last_appearance = {}
for i,u in enumerate(buffers):
for buf in u: last_appearance[buf] = i
# LRU algorithm
assigned: Dict[Buffer, Buffer] = {}
local_cache: DefaultDict[Tuple[str, int, DType], List[Buffer]] = defaultdict(list)
for i,u in enumerate(buffers):
for buf in u:
# all unallocated unparented buffers are fair game to replace
if buf.is_allocated() or buf.lb_refcount > 0: continue
key = (buf.device, buf.size, buf.dtype)
if buf not in assigned:
if len(ll:=local_cache[key]): assigned[buf] = ll.pop()
else: assigned[buf] = Buffer(*key)
if i == last_appearance[buf]:
local_cache[key].append(assigned[buf])
if DEBUG >= 1 and len(ak:=dedup(assigned.keys())) != len(av:=dedup(assigned.values())):
print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB to {sum([x.nbytes for x in av])/1e6:.2f} MB")
return assigned
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
assigned = _internal_memory_planner([si.outputs+si.inputs for si in schedule])
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.outputs),
tuple(assigned.get(x, x) for x in si.inputs)) for si in schedule]
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
for ei in lower_schedule(schedule):
if len(capturing): capturing[0].add(ei)

View file

@ -33,6 +33,7 @@ class LazyBuffer:
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
self.buffer: Buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
self.buffer.lb_refcount += 1
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.forced_realize = False
else:
@ -40,6 +41,9 @@ class LazyBuffer:
assert base.base == base, "base must be a base itself"
self._base = base
def __del__(self):
if hasattr(self, 'buffer'): self.buffer.lb_refcount -= 1
def __repr__(self) -> str:
return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"

View file

@ -15,7 +15,7 @@ from tinygrad.ops import LoadOps
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Device
from tinygrad.shape.symbolic import sint
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import create_schedule_with_vars
# **** start with two base classes, Tensor and Function ****
@ -145,7 +145,8 @@ class Tensor:
if getenv("FUZZ_SCHEDULE"):
from test.external.fuzz_schedule import fuzz_schedule
fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst])))
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
run_schedule(memory_planner(schedule), var_vals)
def realize(self) -> Tensor:
"""Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""