jit: don't memplan buffers reachable from live tensors (#16588)

The memory planner was suballocating BUFFERs created during JIT capture that are still referenced by external lazy tensor graphs, like the .grad tensors assigned by backward(). The replay then only writes the arena slices, so realizing such a tensor after the call reads freshly allocated memory and silently returns zeros. Hold every BUFFER reachable from a live Tensor instead of only the parameters of the return value; true internals are still planned. Fixes #16571.
This commit is contained in:
Philip Sinitsin 2026-06-12 15:51:54 +01:00 committed by GitHub
commit 76c10cd635
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 2 deletions

View file

@ -368,6 +368,28 @@ class TestJit(unittest.TestCase):
c = f(a, b)
self.assertEqual(c.item(), i+1)
def test_jit_lazy_grad_after_replay(self):
# the lazy .grad created during capture is read outside the JIT, the memory planner must not suballocate its buffers (issue #16571)
from tinygrad import nn
def step(conv, x, y):
out = conv(x.permute(0, 3, 1, 2).contiguous()).relu().flatten(1)
loss = (out * y).sum(axis=1) # per-example loss
loss.sum().backward()
conv.weight.grad = None
(loss * 0.5).sum().backward()
return loss.mean().realize()
Tensor.manual_seed(42)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
x, y = Tensor.randn(4, 8, 8, 3).realize(), Tensor.randn(4, 4*8*8).realize()
step(conv, x, y)
ref = conv.weight.grad.numpy()
jit_step = TinyJit(step)
for _ in range(4):
jit_step(conv, x, y)
np.testing.assert_allclose(conv.weight.grad.numpy(), ref, atol=1e-4, rtol=1e-5)
class TestJitPrune(unittest.TestCase):
def test_simple_prune(self):
weights = Tensor.rand(16).realize()

View file

@ -1,6 +1,6 @@
from typing import TypeVar, Generic, Callable, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, all_tensors
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes
@ -296,7 +296,8 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
run_linear(onetime_linear, var_vals)
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
# hold all buffers reachable from live Tensors (e.g. lazy .grad created during capture), the memory planner can't suballocate those
held_bufs = set(buffers) | {u for tref in list(all_tensors) if (t:=tref()) is not None for u in t.uop.toposort() if u.op is Ops.BUFFER}
linear = jit_lower(big_linear, held_bufs, input_buf_uops)
self.captured = CapturedJit(ret, linear, names, expected_input_info)
ret = self.captured(input_buf_uops, var_vals)