mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
many jit tests belong in unit (#16508)
This commit is contained in:
parent
bb407d8b3c
commit
9b0f75622c
4 changed files with 425 additions and 447 deletions
|
|
@ -2,40 +2,15 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import assert_jit_cache_len, call_is_graph, not_support_multi_device, needs_second_gpu
|
||||
from tinygrad import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit, JitError, graph_class
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import Context, JIT, DEV, GlobalCounters
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from extra.models.unet import ResBlock
|
||||
from test.unit.test_jit import _simple_test
|
||||
from tinygrad import Tensor, Variable, TinyJit, Device, dtypes
|
||||
from tinygrad.engine.jit import graph_class
|
||||
from tinygrad.helpers import JIT, DEV, GlobalCounters
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.renderer.isa.x86 import X86Renderer
|
||||
|
||||
def _simple_test(add, extract=lambda x: x, N=10):
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(N, N)
|
||||
b = Tensor.randn(N, N)
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(extract(c).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add, 1)
|
||||
|
||||
class TestJit(unittest.TestCase):
|
||||
|
||||
@settings(deadline=2e4)
|
||||
@unittest.skipUnless(Device.DEFAULT in ["CPU"], f"no support on {Device.DEFAULT}")
|
||||
@given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
|
||||
def test_approx_jit_timeout(self, op):
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
model = [ResBlock(16, 24, 16) for _ in range(4)]
|
||||
@TinyJit
|
||||
def fw_approx(t, t2):
|
||||
for l in model: t = l(t, t2)
|
||||
return op(t).realize()
|
||||
fw_approx(Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24))
|
||||
|
||||
def test_simple_jit(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
|
|
@ -54,40 +29,6 @@ class TestJit(unittest.TestCase):
|
|||
self.assertEqual(y.shape, (2*N,))
|
||||
self.assertEqual(z.shape, (N,))
|
||||
|
||||
def test_jitbeam_triggers_beam(self):
|
||||
from unittest.mock import patch
|
||||
from tinygrad.helpers import getenv as _getenv
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize()
|
||||
with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam:
|
||||
add(a, b)
|
||||
assert mock_beam.call_count == 0
|
||||
with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b)
|
||||
assert mock_beam.call_count == 1
|
||||
|
||||
def test_simple_jit_reset(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
_simple_test(add)
|
||||
add.reset()
|
||||
_simple_test(add, N=20)
|
||||
|
||||
def test_simple_jit_norealize(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b)
|
||||
_simple_test(add)
|
||||
|
||||
def test_simple_jit_norealize_list(self):
|
||||
@TinyJit
|
||||
def add(a, b): return [a+b]
|
||||
_simple_test(add, extract=lambda x: x[0])
|
||||
|
||||
def test_simple_jit_norealize_dict(self):
|
||||
@TinyJit
|
||||
def add(a, b): return {"billy": a+b}
|
||||
_simple_test(add, extract=lambda x: x["billy"])
|
||||
|
||||
def test_jit_input_view(self):
|
||||
@TinyJit
|
||||
def f(x): return (x[2:5].contiguous() + 1).realize()
|
||||
|
|
@ -95,18 +36,6 @@ class TestJit(unittest.TestCase):
|
|||
x = (Tensor.arange(10).float() + i * 10).clone().realize()
|
||||
np.testing.assert_allclose(f(x).numpy(), x.numpy()[2:5] + 1)
|
||||
|
||||
def test_jit_multiple_outputs(self):
|
||||
@TinyJit
|
||||
def f(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c, d, e = f(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(f, 3)
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "estimates are wrong for x86")
|
||||
def test_global_counters_jit(self):
|
||||
@TinyJit
|
||||
|
|
@ -122,78 +51,6 @@ class TestJit(unittest.TestCase):
|
|||
self.assertGreater(GlobalCounters.global_mem, 0)
|
||||
self.assertGreater(GlobalCounters.global_ops, 0)
|
||||
|
||||
def test_nothing_jitted(self):
|
||||
@TinyJit
|
||||
def add(a, b): return None
|
||||
with self.assertRaises(JitError):
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
add(a, b)
|
||||
|
||||
def test_jit_zero_does_not_jit(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
with Context(JIT=0):
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add, 0)
|
||||
|
||||
def test_jit_not_capturing(self):
|
||||
@TinyJit
|
||||
def add(a, b):
|
||||
Tensor.zeros(4, 4).contiguous().realize() # no-op kernel is captured
|
||||
return (a+b).realize()
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add, 2)
|
||||
|
||||
@TinyJit
|
||||
def add2(a, b):
|
||||
with Context(CAPTURING=0): # not captured
|
||||
Tensor.zeros(4, 4).contiguous().realize()
|
||||
return (a+b).realize()
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add2(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add2, 1)
|
||||
|
||||
def test_jit_shape_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
add(a, b)
|
||||
bad = Tensor.randn(20, 20)
|
||||
with self.assertRaises(JitError):
|
||||
add(a, bad)
|
||||
|
||||
def test_jit_shape_views_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a): return (a+1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(1,5):
|
||||
# a has an offset that the kernel doesn't know about
|
||||
a = Tensor.randn(10, 10).realize()[:, i:i+2]
|
||||
add(a)
|
||||
|
||||
def test_jit_duplicate_fail(self):
|
||||
# the jit doesn't support duplicate arguments
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a = Tensor.randn(10, 10)
|
||||
with self.assertRaises(JitError):
|
||||
add(a, a)
|
||||
|
||||
def test_jit_assign(self, dtype=dtypes.float32):
|
||||
@TinyJit
|
||||
def add(a):
|
||||
|
|
@ -205,39 +62,6 @@ class TestJit(unittest.TestCase):
|
|||
|
||||
def test_jit_assign_int8(self): self.test_jit_assign(dtypes.int8)
|
||||
|
||||
def test_kwargs_jit(self):
|
||||
@TinyJit
|
||||
def add_kwargs(first, second): return (first+second).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add_kwargs(first=a, second=b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_kwargs, 1)
|
||||
|
||||
def test_reorder_kwargs_jit(self):
|
||||
@TinyJit
|
||||
def add_kwargs(first, second): return (first/second).realize()
|
||||
for _ in range(2):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add_kwargs(second=b, first=a)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
for _ in range(2):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add_kwargs(first=a, second=b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_kwargs, 1)
|
||||
|
||||
def test_array_jit(self):
|
||||
@TinyJit
|
||||
def add_array(a, arr): return (a+arr[0]).realize()
|
||||
for _ in range(5):
|
||||
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
|
||||
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_array, 1)
|
||||
|
||||
def test_jit_copyin(self):
|
||||
@TinyJit
|
||||
def f(a):
|
||||
|
|
@ -247,183 +71,6 @@ class TestJit(unittest.TestCase):
|
|||
c = f(b)
|
||||
np.testing.assert_allclose(c.numpy(), b.numpy()+[1,2,3], atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_method_jit(self):
|
||||
class Fun:
|
||||
def __init__(self):
|
||||
self.a = Tensor.randn(10, 10)
|
||||
@TinyJit
|
||||
def __call__(self, b:Tensor) -> Tensor:
|
||||
return (self.a+b).realize()
|
||||
fun = Fun()
|
||||
for _ in range(5):
|
||||
b = Tensor.randn(10, 10)
|
||||
c = fun(b)
|
||||
np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(fun.__call__.func.__self__, 1)
|
||||
|
||||
def test_jit_size1_input(self):
|
||||
@TinyJit
|
||||
def f(a, b): return (a+b).realize()
|
||||
a = Tensor([1, 2, 3])
|
||||
for i in range(5):
|
||||
np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(f, 1)
|
||||
|
||||
def test_jit_output_non_tensor_fail(self):
|
||||
@TinyJit
|
||||
def f(a, b, i): return (a+b).realize(), i
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(3):
|
||||
f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
|
||||
|
||||
def test_jit_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize()
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
res = set()
|
||||
for _ in range(5):
|
||||
o1 = jf(a, b)
|
||||
res.add(o1.numpy()[0][0])
|
||||
assert len(res) == 5, "All values should be different, rand works in jit."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf2 = TinyJit(f)
|
||||
res2 = set()
|
||||
for _ in range(5):
|
||||
o1 = jf2(a, b)
|
||||
res2.add(o1.numpy()[0][0])
|
||||
assert len(res2) == 5, "All values should be different, rand works in jit."
|
||||
assert res == res2, "Jit rand is not reproducible with the same seed"
|
||||
|
||||
Tensor.manual_seed(3421)
|
||||
jf3 = TinyJit(f)
|
||||
res3 = set()
|
||||
for _ in range(5):
|
||||
o1 = jf3(a, b)
|
||||
res3.add(o1.numpy()[0][0])
|
||||
assert len(res3) == 5, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_v_nojit_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
without_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = f(a, b)
|
||||
without_jit.add(o1.numpy()[0][0])
|
||||
without_jit.add(o2.numpy()[0][0])
|
||||
assert len(without_jit) == 10, "All values should be different."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
with_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
with_jit.add(o1.numpy()[0][0])
|
||||
with_jit.add(o2.numpy()[0][0])
|
||||
assert len(with_jit) == 10, "All values should be different."
|
||||
assert with_jit == without_jit, "jit and non-jit should produce the same random values with the same seed"
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
res = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
res.add(o1.numpy()[0][0])
|
||||
res.add(o2.numpy()[0][0])
|
||||
assert len(res) == 10, "All values should be different, rand works in jit."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf2 = TinyJit(f)
|
||||
res2 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf2(a, b)
|
||||
res2.add(o1.numpy()[0][0])
|
||||
res2.add(o2.numpy()[0][0])
|
||||
assert len(res2) == 10, "All values should be different, rand works in jit."
|
||||
assert res == res2, "Jit rand is not reproducible with the same seed"
|
||||
|
||||
Tensor.manual_seed(3421)
|
||||
jf3 = TinyJit(f)
|
||||
res3 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf3(a, b)
|
||||
res3.add(o1.numpy()[0][0])
|
||||
res3.add(o2.numpy()[0][0])
|
||||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_random_after_unrealized_random(self):
|
||||
@TinyJit
|
||||
def f(): return Tensor.rand()
|
||||
Tensor.manual_seed(1234)
|
||||
Tensor.rand()
|
||||
res = [f().numpy() for _ in range(3)]
|
||||
assert res[1] != res[2]
|
||||
|
||||
def test_jit_realization_and_sampling(self):
|
||||
w = Tensor.eye(5)
|
||||
|
||||
@TinyJit
|
||||
def foo (x): return w.dot(x).realize()
|
||||
|
||||
arg = [
|
||||
Tensor([1,2,3,4,5]),
|
||||
Tensor([1,3,3,4,6]),
|
||||
Tensor([1,2,5,4,7]),
|
||||
Tensor([0,2,3,1,0]),
|
||||
]
|
||||
|
||||
Y = [foo(e).numpy() for e in arg]
|
||||
|
||||
foo(Tensor([7,7,7,7,7]))
|
||||
want = [[1., 2., 3., 4., 5.],
|
||||
[1., 3., 3., 4., 6.],
|
||||
[1., 2., 5., 4., 7.],
|
||||
[0., 2., 3., 1., 0.]]
|
||||
np.testing.assert_allclose(want, Y)
|
||||
|
||||
def test_jit_buffer_behavior(self):
|
||||
@TinyJit
|
||||
def foo(x) -> Tensor: return x.sum().realize()
|
||||
|
||||
result_1 = foo(Tensor([1] * 2))
|
||||
result_2 = foo(Tensor([2] * 2))
|
||||
result_3 = foo(Tensor([3] * 2))
|
||||
|
||||
# expect the buffer to share underlying buffer
|
||||
np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_jit_batch_split(self):
|
||||
if Device[Device.DEFAULT].graph is None or JIT >= 2: raise unittest.SkipTest("only test graphs")
|
||||
|
||||
|
|
@ -513,49 +160,6 @@ class TestJit(unittest.TestCase):
|
|||
xc = jf(a)
|
||||
np.testing.assert_allclose((a.numpy().sum(axis=(1,)) + 5).view(np.int32), xc.numpy(), atol=1e-4, rtol=5e-5)
|
||||
|
||||
def test_jit_output_clone(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
|
||||
f(Tensor([0.0]))
|
||||
f(Tensor([0.0]))
|
||||
|
||||
a = f(Tensor([1.0])).clone().realize()
|
||||
b = f(Tensor([2.0]))
|
||||
assert abs((a - b).item()) > 0.5
|
||||
|
||||
def test_jit_init_empty(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
|
||||
f(Tensor.empty(1))
|
||||
f(Tensor.empty(1))
|
||||
# scalar const input is not allowed
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(2.0)).item()
|
||||
# self.assertEqual(f(Tensor([2.0])).item(), 1.0) # TODO: wrong output, should be 3.0. currently depends on empty value
|
||||
|
||||
def test_jit_const_input(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(UOp.const(dtypes.float, 2.0))).item()
|
||||
|
||||
def test_jit_deviceless_compute_input(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(UOp.const(dtypes.float, 2.0) + UOp.const(dtypes.float, 1.0))).item()
|
||||
|
||||
def test_jit_init_empty_alt(self):
|
||||
@TinyJit
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return b.assign(a+1)
|
||||
for i in range(4):
|
||||
a = Tensor([i])
|
||||
b = Tensor.empty_like(a)
|
||||
c = f(a, b)
|
||||
self.assertEqual(c.item(), i+1)
|
||||
|
||||
@unittest.skip("Pending multioutput implementation #3607")
|
||||
class TestMultioutputJit(unittest.TestCase):
|
||||
def _test(self, f):
|
||||
|
|
@ -584,19 +188,6 @@ class TestMultioutputJit(unittest.TestCase):
|
|||
self._test(fxn)
|
||||
assert_jit_cache_len(fxn, 2)
|
||||
|
||||
class TestJitInsideJit(unittest.TestCase):
|
||||
def test_jit_jit_error(self):
|
||||
@TinyJit
|
||||
def f(t): return t + 1
|
||||
|
||||
@TinyJit
|
||||
def g(t): return f(t) * 3
|
||||
|
||||
# NOTE: first does not raise
|
||||
g(Tensor([1])).realize()
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
class TestCopyInsideJit(unittest.TestCase):
|
||||
def test_copy_inside_jit(self):
|
||||
@TinyJit
|
||||
|
|
@ -609,24 +200,6 @@ class TestCopyInsideJit(unittest.TestCase):
|
|||
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||
|
||||
class TestJitPrune(unittest.TestCase):
|
||||
def test_simple_prune(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous() + x
|
||||
w2_noprune = TinyJit(w2)
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert_jit_cache_len(w2_noprune, 2)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert_jit_cache_len(w2_prune, 1)
|
||||
|
||||
def test_prune_w_copy_correct(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
|
||||
|
|
@ -873,20 +446,5 @@ class TestJitGraphSplit(unittest.TestCase):
|
|||
multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
|
||||
hcqgraph=[self.ji_graph(4)])
|
||||
|
||||
class TestJitRandom(unittest.TestCase):
|
||||
def test_jit_rangeify(self):
|
||||
tst = {0:[], 1:[]}
|
||||
for r in [0,1]:
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(JIT=r):
|
||||
_ = Tensor.randint(4, high=3)
|
||||
# this second one makes the behavior different
|
||||
_ = Tensor.randint(4, high=3)
|
||||
@TinyJit
|
||||
def f(): return Tensor.randint(20, high=5)
|
||||
for _ in range(5): tst[r].append(f().tolist())
|
||||
for i, (t0, t1) in enumerate(zip(tst[0], tst[1])):
|
||||
self.assertListEqual(t0, t1, msg=f"mismatch at list {i}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
420
test/unit/test_jit.py
Normal file
420
test/unit/test_jit.py
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
import unittest, numpy as np
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad import Tensor, TinyJit, Context, UOp, dtypes
|
||||
from tinygrad.engine.jit import JitError
|
||||
|
||||
def _simple_test(add, extract=lambda x: x, N=10):
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(N, N)
|
||||
b = Tensor.randn(N, N)
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(extract(c).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add, 1)
|
||||
|
||||
class TestJit(unittest.TestCase):
|
||||
def test_jitbeam_triggers_beam(self):
|
||||
from unittest.mock import patch
|
||||
from tinygrad.helpers import getenv as _getenv
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a, b = Tensor.ones(10, 10).contiguous().realize(), Tensor.ones(10, 10).contiguous().realize()
|
||||
with patch("tinygrad.codegen.opt.search.beam_search", wraps=lambda k,*a,**kw: k) as mock_beam:
|
||||
add(a, b)
|
||||
assert mock_beam.call_count == 0
|
||||
with patch("tinygrad.engine.jit.getenv", side_effect=lambda k, d=0: 1 if k == "JITBEAM" else _getenv(k, d)): add(a, b)
|
||||
assert mock_beam.call_count == 1
|
||||
|
||||
def test_simple_jit_reset(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
_simple_test(add)
|
||||
add.reset()
|
||||
_simple_test(add, N=20)
|
||||
|
||||
def test_simple_jit_norealize(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b)
|
||||
_simple_test(add)
|
||||
|
||||
def test_simple_jit_norealize_list(self):
|
||||
@TinyJit
|
||||
def add(a, b): return [a+b]
|
||||
_simple_test(add, extract=lambda x: x[0])
|
||||
|
||||
def test_simple_jit_norealize_dict(self):
|
||||
@TinyJit
|
||||
def add(a, b): return {"billy": a+b}
|
||||
_simple_test(add, extract=lambda x: x["billy"])
|
||||
|
||||
def test_jit_multiple_outputs(self):
|
||||
@TinyJit
|
||||
def f(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c, d, e = f(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(f, 3)
|
||||
|
||||
def test_nothing_jitted(self):
|
||||
@TinyJit
|
||||
def add(a, b): return None
|
||||
with self.assertRaises(JitError):
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
add(a, b)
|
||||
|
||||
def test_jit_zero_does_not_jit(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
with Context(JIT=0):
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add, 0)
|
||||
|
||||
def test_jit_not_capturing(self):
|
||||
@TinyJit
|
||||
def add(a, b):
|
||||
Tensor.zeros(4, 4).contiguous().realize() # no-op kernel is captured
|
||||
return (a+b).realize()
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add, 2)
|
||||
|
||||
@TinyJit
|
||||
def add2(a, b):
|
||||
with Context(CAPTURING=0): # not captured
|
||||
Tensor.zeros(4, 4).contiguous().realize()
|
||||
return (a+b).realize()
|
||||
for i in range(5):
|
||||
a = Tensor([i])
|
||||
b = Tensor([i])
|
||||
c = add2(a, b)
|
||||
np.testing.assert_allclose(c.numpy(), 2*i)
|
||||
assert_jit_cache_len(add2, 1)
|
||||
|
||||
def test_jit_shape_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
add(a, b)
|
||||
bad = Tensor.randn(20, 20)
|
||||
with self.assertRaises(JitError):
|
||||
add(a, bad)
|
||||
|
||||
def test_jit_shape_views_mismatch(self):
|
||||
@TinyJit
|
||||
def add(a): return (a+1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(1,5):
|
||||
# a has an offset that the kernel doesn't know about
|
||||
a = Tensor.randn(10, 10).realize()[:, i:i+2]
|
||||
add(a)
|
||||
|
||||
def test_jit_duplicate_fail(self):
|
||||
# the jit doesn't support duplicate arguments
|
||||
@TinyJit
|
||||
def add(a, b): return (a+b).realize()
|
||||
a = Tensor.randn(10, 10)
|
||||
with self.assertRaises(JitError):
|
||||
add(a, a)
|
||||
|
||||
def test_kwargs_jit(self):
|
||||
@TinyJit
|
||||
def add_kwargs(first, second): return (first+second).realize()
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
c = add_kwargs(first=a, second=b)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_kwargs, 1)
|
||||
|
||||
def test_array_jit(self):
|
||||
@TinyJit
|
||||
def add_array(a, arr): return (a+arr[0]).realize()
|
||||
for _ in range(5):
|
||||
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
|
||||
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_array, 1)
|
||||
|
||||
def test_method_jit(self):
|
||||
class Fun:
|
||||
def __init__(self):
|
||||
self.a = Tensor.randn(10, 10)
|
||||
@TinyJit
|
||||
def __call__(self, b:Tensor) -> Tensor:
|
||||
return (self.a+b).realize()
|
||||
fun = Fun()
|
||||
for _ in range(5):
|
||||
b = Tensor.randn(10, 10)
|
||||
c = fun(b)
|
||||
np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(fun.__call__.func.__self__, 1)
|
||||
|
||||
def test_jit_size1_input(self):
|
||||
@TinyJit
|
||||
def f(a, b): return (a+b).realize()
|
||||
a = Tensor([1, 2, 3])
|
||||
for i in range(5):
|
||||
np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(f, 1)
|
||||
|
||||
def test_jit_output_non_tensor_fail(self):
|
||||
@TinyJit
|
||||
def f(a, b, i): return (a+b).realize(), i
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(3):
|
||||
f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
|
||||
|
||||
def test_jit_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize()
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
res = set()
|
||||
for _ in range(5):
|
||||
o1 = jf(a, b)
|
||||
res.add(o1.numpy()[0][0])
|
||||
assert len(res) == 5, "All values should be different, rand works in jit."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf2 = TinyJit(f)
|
||||
res2 = set()
|
||||
for _ in range(5):
|
||||
o1 = jf2(a, b)
|
||||
res2.add(o1.numpy()[0][0])
|
||||
assert len(res2) == 5, "All values should be different, rand works in jit."
|
||||
assert res == res2, "Jit rand is not reproducible with the same seed"
|
||||
|
||||
Tensor.manual_seed(3421)
|
||||
jf3 = TinyJit(f)
|
||||
res3 = set()
|
||||
for _ in range(5):
|
||||
o1 = jf3(a, b)
|
||||
res3.add(o1.numpy()[0][0])
|
||||
assert len(res3) == 5, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_v_nojit_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
without_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = f(a, b)
|
||||
without_jit.add(o1.numpy()[0][0])
|
||||
without_jit.add(o2.numpy()[0][0])
|
||||
assert len(without_jit) == 10, "All values should be different."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
with_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
with_jit.add(o1.numpy()[0][0])
|
||||
with_jit.add(o2.numpy()[0][0])
|
||||
assert len(with_jit) == 10, "All values should be different."
|
||||
assert with_jit == without_jit, "jit and non-jit should produce the same random values with the same seed"
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
res = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
res.add(o1.numpy()[0][0])
|
||||
res.add(o2.numpy()[0][0])
|
||||
assert len(res) == 10, "All values should be different, rand works in jit."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf2 = TinyJit(f)
|
||||
res2 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf2(a, b)
|
||||
res2.add(o1.numpy()[0][0])
|
||||
res2.add(o2.numpy()[0][0])
|
||||
assert len(res2) == 10, "All values should be different, rand works in jit."
|
||||
assert res == res2, "Jit rand is not reproducible with the same seed"
|
||||
|
||||
Tensor.manual_seed(3421)
|
||||
jf3 = TinyJit(f)
|
||||
res3 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf3(a, b)
|
||||
res3.add(o1.numpy()[0][0])
|
||||
res3.add(o2.numpy()[0][0])
|
||||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_random_after_unrealized_random(self):
|
||||
@TinyJit
|
||||
def f(): return Tensor.rand()
|
||||
Tensor.manual_seed(1234)
|
||||
Tensor.rand()
|
||||
res = [f().numpy() for _ in range(3)]
|
||||
assert res[1] != res[2]
|
||||
|
||||
def test_jit_realization_and_sampling(self):
|
||||
w = Tensor.eye(5)
|
||||
|
||||
@TinyJit
|
||||
def foo (x): return w.dot(x).realize()
|
||||
|
||||
arg = [
|
||||
Tensor([1,2,3,4,5]),
|
||||
Tensor([1,3,3,4,6]),
|
||||
Tensor([1,2,5,4,7]),
|
||||
Tensor([0,2,3,1,0]),
|
||||
]
|
||||
|
||||
Y = [foo(e).numpy() for e in arg]
|
||||
|
||||
foo(Tensor([7,7,7,7,7]))
|
||||
want = [[1., 2., 3., 4., 5.],
|
||||
[1., 3., 3., 4., 6.],
|
||||
[1., 2., 5., 4., 7.],
|
||||
[0., 2., 3., 1., 0.]]
|
||||
np.testing.assert_allclose(want, Y)
|
||||
|
||||
def test_jit_buffer_behavior(self):
|
||||
@TinyJit
|
||||
def foo(x) -> Tensor: return x.sum().realize()
|
||||
|
||||
result_1 = foo(Tensor([1] * 2))
|
||||
result_2 = foo(Tensor([2] * 2))
|
||||
result_3 = foo(Tensor([3] * 2))
|
||||
|
||||
# expect the buffer to share underlying buffer
|
||||
np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_jit_output_clone(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
|
||||
f(Tensor([0.0]))
|
||||
f(Tensor([0.0]))
|
||||
|
||||
a = f(Tensor([1.0])).clone().realize()
|
||||
b = f(Tensor([2.0]))
|
||||
assert abs((a - b).item()) > 0.5
|
||||
|
||||
def test_jit_init_empty(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
|
||||
f(Tensor.empty(1))
|
||||
f(Tensor.empty(1))
|
||||
# scalar const input is not allowed
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(2.0)).item()
|
||||
# self.assertEqual(f(Tensor([2.0])).item(), 1.0) # TODO: wrong output, should be 3.0. currently depends on empty value
|
||||
|
||||
def test_jit_const_input(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(UOp.const(dtypes.float, 2.0))).item()
|
||||
|
||||
def test_jit_deviceless_compute_input(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(UOp.const(dtypes.float, 2.0) + UOp.const(dtypes.float, 1.0))).item()
|
||||
|
||||
def test_jit_init_empty_alt(self):
|
||||
@TinyJit
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return b.assign(a+1)
|
||||
for i in range(4):
|
||||
a = Tensor([i])
|
||||
b = Tensor.empty_like(a)
|
||||
c = f(a, b)
|
||||
self.assertEqual(c.item(), i+1)
|
||||
|
||||
class TestJitPrune(unittest.TestCase):
|
||||
def test_simple_prune(self):
|
||||
weights = Tensor.rand(16).realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous() + x
|
||||
w2_noprune = TinyJit(w2)
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert_jit_cache_len(w2_noprune, 2)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert_jit_cache_len(w2_prune, 1)
|
||||
|
||||
|
||||
class TestJitInsideJit(unittest.TestCase):
|
||||
def test_jit_jit_error(self):
|
||||
@TinyJit
|
||||
def f(t): return t + 1
|
||||
|
||||
@TinyJit
|
||||
def g(t): return f(t) * 3
|
||||
|
||||
# NOTE: first does not raise
|
||||
g(Tensor([1])).realize()
|
||||
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
|
||||
g(Tensor([1])).realize()
|
||||
|
||||
class TestJitRandom(unittest.TestCase):
|
||||
def test_jit_rangeify(self):
|
||||
tst = {0:[], 1:[]}
|
||||
for r in [0,1]:
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(JIT=r):
|
||||
_ = Tensor.randint(4, high=3)
|
||||
# this second one makes the behavior different
|
||||
_ = Tensor.randint(4, high=3)
|
||||
@TinyJit
|
||||
def f(): return Tensor.randint(20, high=5)
|
||||
for _ in range(5): tst[r].append(f().tolist())
|
||||
for i, (t0, t1) in enumerate(zip(tst[0], tst[1])):
|
||||
self.assertListEqual(t0, t1, msg=f"mismatch at list {i}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue