mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
* switch beautiful_mnist to use new optimizer [pr] * fix abstractions3 + docs * fix OptimizerGroup with schedule_step api
90 lines
3.6 KiB
Python
90 lines
3.6 KiB
Python
from typing import Callable
|
|
import unittest, math
|
|
import numpy as np
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from tinygrad import Tensor
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.ops import UOp
|
|
from tinygrad.gradient import gradient
|
|
from tinygrad.nn.optim import SGD, Adam
|
|
|
|
class TestGradient(unittest.TestCase):
|
|
def _cmp_nan_okay(self, x, y):
|
|
if math.isnan(x) and math.isnan(y): return
|
|
self.assertAlmostEqual(x, y, places=5)
|
|
|
|
def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
|
|
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
|
|
gx = gradient(f(x), [x])[0]
|
|
gf = jax.grad(f if jf is None else jf)
|
|
|
|
for val in [-5., -2.0, 0.0, 2.0, 5.]:
|
|
tg_out, jax_out = gx.substitute({x: x.const_like(val)}).ssimplify(), gf(val).item()
|
|
self._cmp_nan_okay(tg_out, jax_out)
|
|
|
|
def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
|
|
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
|
|
y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
|
|
gx, gy = gradient(f(x, y), [x, y])
|
|
gf = jax.grad(f if jf is None else jf, argnums=(0, 1))
|
|
|
|
for valx in [-5., -2.0, 0.0, 2.0, 5.]:
|
|
for valy in [-5., -2.0, 0.0, 2.0, 5.]:
|
|
# Substitute the values into the gradient expressions
|
|
substitutions = {x: x.const_like(valx), y: y.const_like(valy)}
|
|
tg_out_x = gx.substitute(substitutions).ssimplify()
|
|
tg_out_y = gy.substitute(substitutions).ssimplify()
|
|
jax_out_x, jax_out_y = [x.item() for x in gf(valx, valy)]
|
|
|
|
self._cmp_nan_okay(tg_out_x, jax_out_x)
|
|
self._cmp_nan_okay(tg_out_y, jax_out_y)
|
|
|
|
# unary ops unit
|
|
def test_recip(self): self._test_one_input_function(lambda x: 1.0/x)
|
|
def test_sin(self): self._test_one_input_function(lambda x: x.sin(), lambda x: jnp.sin(x))
|
|
def test_sqrt(self): self._test_one_input_function(lambda x: x.sqrt(), lambda x: jnp.sqrt(x))
|
|
def test_log2(self): self._test_one_input_function(lambda x: x.log2(), lambda x: jnp.log2(x))
|
|
def test_exp2(self): self._test_one_input_function(lambda x: x.exp2(), lambda x: jnp.exp2(x))
|
|
|
|
# binary ops unit
|
|
def test_add(self): self._test_two_input_function(lambda x,y: x+y)
|
|
def test_mul(self): self._test_two_input_function(lambda x,y: x*y)
|
|
|
|
# chain rule
|
|
def test_chain(self): self._test_one_input_function(lambda x: x.sin().sqrt(), lambda x: jnp.sqrt(jnp.sin(x)))
|
|
def test_chain_binop(self): self._test_two_input_function(lambda x,y: (x*y)+x*y)
|
|
def test_big_add_sin(self): self._test_two_input_function(lambda x,y: x.sin()+3.0/y, lambda x,y: jnp.sin(x)+3.0/y)
|
|
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
|
|
|
|
# TODO: this isn't working
|
|
#def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: jnp.where(x<y, x, y))
|
|
|
|
class TestTensorGradient(unittest.TestCase):
|
|
def test_example(self):
|
|
x = Tensor.eye(3)
|
|
y = Tensor([[2.0,0,-2.0]])
|
|
z = y.matmul(x).sum()
|
|
dx, dy = z.gradient(x, y)
|
|
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
|
|
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
|
|
|
|
def test_raises(self):
|
|
x = Tensor([1.0, 2.0, 3.0])
|
|
w = Tensor.randn((3,))
|
|
with self.assertRaises(RuntimeError): x.sum().gradient(w)
|
|
|
|
def test_optim(self):
|
|
with Tensor.train():
|
|
w = Tensor([1.0, 2.0, 3.0])
|
|
SGD([w], lr=0.1).step(w.sum())
|
|
np.testing.assert_almost_equal(w.tolist(), [0.9, 1.9, 2.9])
|
|
|
|
def test_optim_rng(self):
|
|
with Tensor.train():
|
|
x = Tensor([1.0, 2.0, 3.0])
|
|
w = Tensor.randn((3,))
|
|
Adam([w], lr=0.1).step((x*w).sum())
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|