mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
one more opt test
This commit is contained in:
parent
dd543fbc7a
commit
7909786dbf
1 changed files with 15 additions and 0 deletions
|
|
@ -59,6 +59,21 @@ class TestOpt(unittest.TestCase):
|
|||
assert len(CL.CACHE) == 3, "optimizer didn't fold batchnorm"
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
opt = optim.SGD(optim.get_parameters(c1))
|
||||
with CLCache():
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
opt.step()
|
||||
# TODO: this should be 4, but the sum output child stays around
|
||||
# with pushing_permutes it can be 3
|
||||
assert len(CL.CACHE) == 5, "optimizer didn't fold conv-backward SGD"
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue