mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
jit: don't memplan buffers reachable from live tensors (#16588)
The memory planner was suballocating BUFFERs created during JIT capture that are still referenced by external lazy tensor graphs, like the .grad tensors assigned by backward(). The replay then only writes the arena slices, so realizing such a tensor after the call reads freshly allocated memory and silently returns zeros. Hold every BUFFER reachable from a live Tensor instead of only the parameters of the return value; true internals are still planned. Fixes #16571.
This commit is contained in:
parent
2bfdf85f87
commit
76c10cd635
2 changed files with 25 additions and 2 deletions
|
|
@ -368,6 +368,28 @@ class TestJit(unittest.TestCase):
|
|||
c = f(a, b)
|
||||
self.assertEqual(c.item(), i+1)
|
||||
|
||||
def test_jit_lazy_grad_after_replay(self):
|
||||
# the lazy .grad created during capture is read outside the JIT, the memory planner must not suballocate its buffers (issue #16571)
|
||||
from tinygrad import nn
|
||||
def step(conv, x, y):
|
||||
out = conv(x.permute(0, 3, 1, 2).contiguous()).relu().flatten(1)
|
||||
loss = (out * y).sum(axis=1) # per-example loss
|
||||
loss.sum().backward()
|
||||
conv.weight.grad = None
|
||||
(loss * 0.5).sum().backward()
|
||||
return loss.mean().realize()
|
||||
|
||||
Tensor.manual_seed(42)
|
||||
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
|
||||
x, y = Tensor.randn(4, 8, 8, 3).realize(), Tensor.randn(4, 4*8*8).realize()
|
||||
|
||||
step(conv, x, y)
|
||||
ref = conv.weight.grad.numpy()
|
||||
jit_step = TinyJit(step)
|
||||
for _ in range(4):
|
||||
jit_step(conv, x, y)
|
||||
np.testing.assert_allclose(conv.weight.grad.numpy(), ref, atol=1e-4, rtol=1e-5)
|
||||
|
||||
class TestJitPrune(unittest.TestCase):
|
||||
def test_simple_prune(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import TypeVar, Generic, Callable, Any
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, all_tensors
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
|
|
@ -296,7 +296,8 @@ class TinyJit(Generic[ReturnType]):
|
|||
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
|
||||
run_linear(onetime_linear, var_vals)
|
||||
|
||||
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
|
||||
# hold all buffers reachable from live Tensors (e.g. lazy .grad created during capture), the memory planner can't suballocate those
|
||||
held_bufs = set(buffers) | {u for tref in list(all_tensors) if (t:=tref()) is not None for u in t.uop.toposort() if u.op is Ops.BUFFER}
|
||||
linear = jit_lower(big_linear, held_bufs, input_buf_uops)
|
||||
self.captured = CapturedJit(ret, linear, names, expected_input_info)
|
||||
ret = self.captured(input_buf_uops, var_vals)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue