mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix the need for explicit assign. track pending assigns for each buffer, and run those before the main realize in order
387 lines
14 KiB
Python
387 lines
14 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
JIT Footguns: Documenting unexpected behavior changes when using @TinyJit
|
|
|
|
Each test shows behavior that works without JIT but changes with JIT.
|
|
Comments marked "should be X!" indicate the intuitively expected value.
|
|
|
|
SILENT MISMATCHES (highest priority - wrong results, no error):
|
|
class_method_shared_across_instances EASY could check if first arg is self and warn
|
|
slice_assign_requires_realize MED assign graph not connected to read during JIT replay
|
|
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
|
symbolic_pad_view_frozen MED pad view BIND values baked in at capture time
|
|
python_constants_frozen HARD inherent to tracing JITs
|
|
conditional_branches_frozen HARD inherent to tracing JITs
|
|
|
|
ERRORS RAISED (lower priority - at least users know):
|
|
item_bakes_in_values EASY raises JitError if .item()/.data() accessed during capture
|
|
unrealized_const_input_error EASY raises JitError for unrealized const inputs
|
|
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
|
|
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
|
|
duplicate_inputs_fail MED would need to handle aliasing in input_replace
|
|
nested_jit_fails_on_second_call MED could fail on first call instead of second
|
|
"""
|
|
import unittest
|
|
import numpy as np
|
|
from tinygrad import Tensor, TinyJit
|
|
from tinygrad.engine.jit import JitError
|
|
from tinygrad.helpers import JIT
|
|
|
|
class TestJitFootguns(unittest.TestCase):
  """Cases where wrapping a function in @TinyJit changes what the caller observes.

  Assertions carrying a "should be X!" comment pin the current (surprising) value,
  with X being the value one would intuitively expect.
  """

  def test_output_buffer_reuse(self):
    """After capture the returned tensors alias one buffer, so an older handle sees the newest value."""
    @TinyJit
    def total(x):
      return x.sum().realize()

    warm = total(Tensor([1, 1]))      # warmup
    captured = total(Tensor([2, 2]))  # capture
    replayed = total(Tensor([3, 3]))  # jit exec

    self.assertEqual(warm.item(), 2)      # warmup result independent
    self.assertEqual(replayed.item(), 6)  # latest is correct
    self.assertEqual(captured.item(), 6)  # should be 4! (overwritten by the replayed call)

  def test_output_buffer_workaround(self):
    """Calling .clone().realize() on each result yields independent copies."""
    @TinyJit
    def total(x):
      return x.sum().realize()

    # every result is cloned before the next call can overwrite the shared output
    copies = [total(Tensor([v, v])).clone().realize() for v in (1, 2, 3)]
    self.assertEqual([c.item() for c in copies], [2, 4, 6])

  def test_graph_input_output_aliasing(self):
    """JIT must handle input=output aliasing correctly; this simulates the LLM generate pattern.

    The LLM generate pattern:
    1. First "session": several iterations where each output is fed back as the next input
    2. Second "session": starts from a NEW input tensor (not the previous output)

    The bug: GraphRunner computes input_replace during _first_run. If the input buffer == output buffer
    at that moment (aliasing), the output position is wrongly included in input_replace. When a DIFFERENT
    input is passed later, the output position is overwritten with the input, corrupting the computation.

    More than one kernel is needed to trigger it because single-kernel JITs don't get graphed
    ("only one kernel doesn't graph").
    """
    from tinygrad import Device
    no_graph = Device[Device.DEFAULT].graph is None
    if no_graph or JIT != 1:
      self.skipTest("test requires JIT graph support")

    # Two realizes -> two kernels, so they get batched into a GraphRunner
    @TinyJit
    def step(x):
      incremented = (x + 1).realize()   # kernel 1
      return (incremented * 2).realize()  # kernel 2

    # Phase 1: warmup and capture
    warm_in = Tensor([10]).contiguous().realize()
    step(warm_in)  # warmup (cnt=0)
    capture_in = Tensor([20]).contiguous().realize()
    chained = step(capture_in)  # capture (cnt=1), chained = (20+1)*2 = 42

    # Phase 2: first "session" - output becomes the next input (triggers _first_run with aliasing)
    for _ in range(3):
      chained = step(chained)  # (42+1)*2=86, (86+1)*2=174, (174+1)*2=350
    self.assertEqual(chained.item(), 350)

    # Phase 3: second "session" - NEW input tensor (simulates a new generate() call)
    # The bug: GraphRunner's input_replace incorrectly includes the output position
    # When the new input is passed, it overwrites the output buffer, using old value (350) instead of new (100)
    fresh = Tensor([100]).contiguous().realize()
    for _ in range(3):
      fresh = step(fresh)  # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
    self.assertEqual(fresh.item(), 814)  # fails with 1406 if bug exists (uses 350 instead of 100)

  def test_multiple_outputs_same_intermediate(self):
    """Several outputs built from one intermediate - JIT copies aliased inputs to prevent hazard."""
    @TinyJit
    def roll(window, incoming):
      shifted = window[1:].cat(incoming, dim=0)
      return shifted.contiguous(), shifted[:1].contiguous()

    window = Tensor([[0], [1], [2]]).contiguous().realize()
    for value in range(10, 14):
      incoming = Tensor([[value]]).contiguous().realize()
      # read the expected head BEFORE the jitted call, while the window is still the old one
      want_head = window[1:2].numpy().item()
      next_window, head = roll(window, incoming)
      self.assertEqual(head.numpy().item(), want_head)
      window = next_window

  def test_slice_assign_works_without_realize(self):
    """Assign into a slice, then read the same buffer - pending assigns are side-realized."""
    from tinygrad import Variable
    v_pos = Variable("pos", 0, 3)
    cache = Tensor.zeros(4, 4).contiguous().realize()

    @TinyJit
    def write_row_then_sum(pos):
      cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
      return cache.sum().realize()

    for row in range(4):
      # reset the cache so each call should see exactly one row of ones
      cache.assign(Tensor.zeros(4, 4)).realize()
      self.assertEqual(write_row_then_sum(v_pos.bind(row)).item(), 4.0)

  def test_symbolic_pad_view_frozen(self):
    """A pad left as a view keeps the BIND values seen at capture. TODO: pad should be captured in jit."""
    from tinygrad import Variable
    source = Tensor.rand(3, 10).realize()

    def first_cols(n):
      # view of the first n columns, n bound to the symbolic variable "i"
      return source[:, :Variable("i", 1, 10).bind(n)]

    def nonzero_in_row0(t):
      return int((t[0] != 0).sum().item())

    # broken: pad is a view, BIND values frozen at capture (i=2)
    @TinyJit
    def f_broken(a):
      return (a+1).pad((None, (0, 10-a.shape[1]))).realize()
    for i in range(1, 5):
      f_broken(first_cols(i))
    self.assertEqual(nonzero_in_row0(f_broken(first_cols(4))), 2)  # should be 4!

    # workaround: contiguous fuses pad into kernel
    @TinyJit
    def f_fixed(a):
      return (a+1).pad((None, (0, 10-a.shape[1]))).contiguous().realize()
    for i in range(1, 5):
      f_fixed(first_cols(i))
    self.assertEqual(nonzero_in_row0(f_fixed(first_cols(4))), 4)

  def test_non_tensor_outputs_error(self):
    """Returning a non-Tensor value alongside a Tensor raises JitError."""
    @TinyJit
    def mixed(x, mult):
      return (x * 2).realize(), mult * 10

    with self.assertRaises(JitError):
      for i in range(3):
        mixed(Tensor([i]), i)

  def test_duplicate_inputs_fail(self):
    """Passing one tensor as two different arguments is rejected by the JIT."""
    @TinyJit
    def add(a, b):
      return (a + b).realize()

    same = Tensor([1, 2, 3])
    with self.assertRaises(JitError):
      add(same, same)

  def test_tensors_in_containers(self):
    """A tensor passed inside a list argument is still picked up on every call."""
    @TinyJit
    def add_first(a, arr):
      return (a + arr[0]).realize()

    for i in range(4):
      ones = Tensor([1, 1, 1]).realize()
      fill = Tensor([i, i, i]).realize()
      np.testing.assert_array_equal(add_first(ones, [fill]).numpy(), [1+i] * 3)

  def test_nested_jit_fails_on_second_call(self):
    """A JIT inside a JIT works on the first call and raises on the second."""
    @TinyJit
    def inner(t):
      return t + 1

    @TinyJit
    def outer(t):
      return inner(t) * 3

    first = outer(Tensor([1])).realize()
    self.assertEqual(first.item(), 6)  # works!
    with self.assertRaises(RuntimeError):
      outer(Tensor([2])).realize()  # fails

  def test_implicit_inputs_need_realize(self):
    """Tensors captured by closure have to be realized before the JIT call."""
    shared = Tensor([0])

    @TinyJit
    def doubled():
      return (shared * 2).realize()

    for i in range(5):
      shared.assign(Tensor([i])).realize()  # must realize!
      self.assertEqual(doubled().item(), i * 2)

  def test_views_with_different_offsets_fail(self):
    """The JIT needs the same tensor view on every call."""
    @TinyJit
    def plus_one(a):
      return (a + 1).realize()

    base = Tensor.randn(10, 10).realize()
    with self.assertRaises(JitError):
      for start in range(1, 5):
        plus_one(base[:, start:start+2])  # different offset each time

  def test_shape_change_after_capture_fails(self):
    """Input shapes are fixed once the JIT has captured."""
    @TinyJit
    def add(a, b):
      return (a + b).realize()

    # first pass is the warmup, second pass is the capture
    for _ in range(2):
      add(Tensor.randn(10, 10), Tensor.randn(10, 10))

    with self.assertRaises(JitError):
      add(Tensor.randn(20, 20), Tensor.randn(20, 20))

  def test_python_constants_frozen(self):
    """Python variables read inside the JIT keep their capture-time values."""
    factor = 1

    @TinyJit
    def scaled(x):
      return (x * factor).realize()

    results = []
    # the loop target rebinds the closed-over variable before every call
    for factor in range(1, 6):
      results.append(scaled(Tensor([10])).item())

    self.assertEqual(results[0], 10)  # warmup, factor=1
    self.assertEqual(results[1], 20)  # capture, factor=2
    self.assertEqual(results[2], 20)  # should be 30!
    self.assertEqual(results[3], 20)  # should be 40!

  def test_unrealized_const_input_error(self):
    """Const tensors have no buffer to replace, so the JIT raises. Even explicit .realize() doesn't help."""
    @TinyJit
    def mul(a, b):
      return (a * b).realize()

    # unrealized const fails
    with self.assertRaises(JitError):
      mul(Tensor([1, 2, 3]).realize(), Tensor(2))

    # explicit .realize() on const still fails - const cannot be realized to have a buffer
    @TinyJit
    def mul_again(a, b):
      return (a * b).realize()

    with self.assertRaises(JitError):
      mul_again(Tensor([1, 2, 3]).realize(), Tensor(2).realize())

  def test_conditional_branches_frozen(self):
    """Whichever branch ran during capture is the only one that runs afterwards."""
    @TinyJit
    def square_or_double(x, use_square):
      picked = (x * x) if use_square else (x * 2)
      return picked.realize()

    square_or_double(Tensor([3]), True)   # warmup
    square_or_double(Tensor([3]), False)  # capture (False branch)

    got = square_or_double(Tensor([3]), True)  # passing True but False branch runs
    self.assertEqual(got.item(), 6)  # should be 9!

  def test_positional_kwargs_cannot_mix(self):
    """The calling convention used at capture must be used afterwards too."""
    @TinyJit
    def f(a, b):
      return (a + b).realize()

    f(Tensor([1]), Tensor([2]))  # warmup with positional
    f(Tensor([1]), Tensor([2]))  # capture with positional

    with self.assertRaises(JitError):
      f(a=Tensor([3]), b=Tensor([4]))  # kwargs fail

  def test_class_method_shared_across_instances(self):
    """A JIT placed on an instance method is shared at class level."""
    class Model:
      def __init__(self, scale):
        self.scale = Tensor([scale])

      @TinyJit
      def forward(self, x):
        return (x * self.scale).realize()

    m1 = Model(2)
    m2 = Model(3)

    m1.forward(Tensor([5]))  # warmup
    m1.forward(Tensor([5]))  # capture with m1.scale=2

    self.assertEqual(m1.forward(Tensor([5])).item(), 10)
    self.assertEqual(m2.forward(Tensor([5])).item(), 10)  # should be 15!

  def test_side_effects_only_during_capture(self):
    """The Python function body does not run during JIT replay."""
    executions = 0

    @TinyJit
    def doubled(x):
      nonlocal executions
      executions += 1
      return (x * 2).realize()

    doubled(Tensor([1]))  # warmup
    doubled(Tensor([2]))  # capture
    self.assertEqual(executions, 2)

    for v in (3, 4, 5):
      doubled(Tensor([v]))
    self.assertEqual(executions, 2)  # still 2, not 5!

  def test_nothing_realized_fails(self):
    """At least one kernel must be JITted."""
    @TinyJit
    def noop(a, b):
      return None

    with self.assertRaises(JitError):
      for _ in range(3):
        noop(Tensor([1]), Tensor([2]))

  def test_item_creates_unrealized_return(self):
    """Using .item() to compute a shape raises during JIT capture."""
    @TinyJit
    def zeros_of_sum(x):
      return Tensor.zeros(x.sum().item())

    zeros_of_sum(Tensor([1, 1, 1]))  # warmup
    with self.assertRaises(JitError):
      zeros_of_sum(Tensor([1, 1, 1]))  # capture - .item() raises

  def test_item_bakes_in_values(self):
    """.item() during JIT capture raises an error (it would bake in the value)."""
    @TinyJit
    def select(x, mask):
      return x.masked_select(mask)

    select(Tensor([1, 2, 3, 4]), Tensor([True, False, True, False]))  # warmup
    with self.assertRaises(JitError):
      select(Tensor([1, 2, 3, 4]), Tensor([True, False, True, False]))  # capture - .item() raises

  def test_tolist_bakes_in_values(self):
    """.tolist() raises during JIT capture (it would bake in the values)."""
    @TinyJit
    def roundtrip(x):
      return Tensor(x.tolist())

    roundtrip(Tensor([1, 2, 3]))  # warmup
    with self.assertRaises(JitError):
      roundtrip(Tensor([1, 2, 3]))  # capture - .tolist() raises
|
|
|
|
|
|
class TestJitCorrectBehavior(unittest.TestCase):
  """Behaviors that work correctly - documented for clarity."""

  def test_random_regenerates(self):
    """Random tensors are regenerated on every call."""
    @TinyJit
    def noisy(x):
      return (x + Tensor.rand(3)).realize()

    noisy(Tensor([0, 0, 0]))  # warmup
    noisy(Tensor([0, 0, 0]))  # capture

    # five replays must give five distinct outputs
    seen = set()
    for _ in range(5):
      seen.add(tuple(noisy(Tensor([0, 0, 0])).numpy().tolist()))
    self.assertEqual(len(seen), 5)

  def test_unrealized_return_auto_realized(self):
    """A returned tensor that was never realized is auto-realized."""
    @TinyJit
    def add(a, b):
      return a + b  # no explicit realize

    for _ in range(5):
      lhs, rhs = Tensor.randn(10), Tensor.randn(10)
      got = add(lhs, rhs).numpy()
      want = lhs.numpy() + rhs.numpy()
      np.testing.assert_allclose(got, want, atol=1e-5)

  def test_kwargs_order_doesnt_matter(self):
    """Kwargs are sorted by name, so the order they are passed in doesn't matter."""
    @TinyJit
    def f(first, second):
      return (first / second).realize()

    for _ in range(3):
      a, b = Tensor.randn(10), Tensor.randn(10) + 1
      # same arguments, passed in both keyword orders
      for kwargs in ({"second": b, "first": a}, {"first": a, "second": b}):
        np.testing.assert_allclose(f(**kwargs).numpy(), a.numpy() / b.numpy(), atol=1e-4)

  def test_input_mutation_consistent(self):
    """Mutating an input through assign behaves the same on every call."""
    @TinyJit
    def bump(x):
      x += 1
      x.realize()
      return x

    counter = Tensor([0]).contiguous().realize()
    for _ in range(5):
      bump(counter)
    self.assertEqual(counter.item(), 5)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
|