mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove noop requires_grad_ calls (#16213)
This commit is contained in:
parent
c6cf9e8f0c
commit
07a172dbbb
7 changed files with 21 additions and 26 deletions
|
|
@ -3,7 +3,6 @@ os.environ["WQKV"] = "1"
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, dtypes
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.device import is_dtype_supported, Device
|
||||
from examples.mlperf.models.llama import Transformer
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer
|
||||
|
|
@ -45,8 +44,6 @@ class TestFlatLlama(unittest.TestCase):
|
|||
flat = FlatTransformer(**params)
|
||||
copy_weights(flat, ref)
|
||||
|
||||
for p in get_parameters(ref): p.requires_grad_(True)
|
||||
for p in get_parameters(flat): p.requires_grad_(True)
|
||||
Tensor.realize(*nn.state.get_state_dict(flat).values())
|
||||
|
||||
tokens = Tensor([[1, 50, 100, 999, 2, 10]])
|
||||
|
|
|
|||
|
|
@ -109,9 +109,9 @@ def fa():
|
|||
def fa_bw():
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(DEBUG=0):
|
||||
q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
|
||||
q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
|
||||
attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
|
||||
attn_output.weight.requires_grad_().realize()
|
||||
attn_output.weight.realize()
|
||||
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
|
||||
|
||||
GlobalCounters.reset()
|
||||
|
|
|
|||
|
|
@ -190,7 +190,6 @@ class TestSoftmaxFusion(unittest.TestCase):
|
|||
|
||||
def test_softmax_bw(self):
|
||||
print("*** softmax bw ***")
|
||||
self.test.requires_grad_()
|
||||
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
|
||||
self.test.softmax(-1).sum().backward()
|
||||
sg = self.test.grad.realize()
|
||||
|
|
|
|||
|
|
@ -141,7 +141,6 @@ class TestTiny(unittest.TestCase):
|
|||
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
|
||||
|
||||
# realize gradients
|
||||
for x in nn.state.get_parameters(layers): x.requires_grad_()
|
||||
Tensor.empty(4, 1, 14, 14).sequential(layers).sum().backward()
|
||||
Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])
|
||||
|
||||
|
|
|
|||
|
|
@ -142,9 +142,9 @@ class TestFA(unittest.TestCase):
|
|||
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
q = base_q.clone().shard(GPUS, axis=0)
|
||||
k = base_k.clone().shard(GPUS, axis=0)
|
||||
v = base_v.clone().shard(GPUS, axis=0)
|
||||
Tensor.realize(q, k, v)
|
||||
|
||||
do = base_do.clone().shard(GPUS, axis=0)
|
||||
|
|
@ -157,9 +157,9 @@ class TestFA(unittest.TestCase):
|
|||
Tensor.realize(q.grad, k.grad, v.grad)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q_ref = base_q.clone().requires_grad_(True)
|
||||
k_ref = base_k.clone().requires_grad_(True)
|
||||
v_ref = base_v.clone().requires_grad_(True)
|
||||
q_ref = base_q.clone()
|
||||
k_ref = base_k.clone()
|
||||
v_ref = base_v.clone()
|
||||
Tensor.realize(q_ref, k_ref, v_ref)
|
||||
|
||||
do_ref = base_do.clone()
|
||||
|
|
@ -189,9 +189,9 @@ class TestFA(unittest.TestCase):
|
|||
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=2)
|
||||
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=2)
|
||||
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=2)
|
||||
q = base_q.clone().shard(GPUS, axis=2)
|
||||
k = base_k.clone().shard(GPUS, axis=2)
|
||||
v = base_v.clone().shard(GPUS, axis=2)
|
||||
Tensor.realize(q, k, v)
|
||||
|
||||
do = base_do.clone().shard(GPUS, axis=2)
|
||||
|
|
@ -204,9 +204,9 @@ class TestFA(unittest.TestCase):
|
|||
Tensor.realize(q.grad, k.grad, v.grad)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q_ref = base_q.clone().requires_grad_(True)
|
||||
k_ref = base_k.clone().requires_grad_(True)
|
||||
v_ref = base_v.clone().requires_grad_(True)
|
||||
q_ref = base_q.clone()
|
||||
k_ref = base_k.clone()
|
||||
v_ref = base_v.clone()
|
||||
Tensor.realize(q_ref, k_ref, v_ref)
|
||||
|
||||
do_ref = base_do.clone()
|
||||
|
|
|
|||
|
|
@ -951,9 +951,9 @@ class TestTK(unittest.TestCase):
|
|||
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=0)
|
||||
q = base_q.clone().shard(GPUS, axis=0)
|
||||
k = base_k.clone().shard(GPUS, axis=0)
|
||||
v = base_v.clone().shard(GPUS, axis=0)
|
||||
Tensor.realize(q, k, v)
|
||||
|
||||
do = base_do.clone().shard(GPUS, axis=0)
|
||||
|
|
@ -966,9 +966,9 @@ class TestTK(unittest.TestCase):
|
|||
Tensor.realize(q.grad, k.grad, v.grad)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q_ref = base_q.clone().requires_grad_(True)
|
||||
k_ref = base_k.clone().requires_grad_(True)
|
||||
v_ref = base_v.clone().requires_grad_(True)
|
||||
q_ref = base_q.clone()
|
||||
k_ref = base_k.clone()
|
||||
v_ref = base_v.clone()
|
||||
Tensor.realize(q_ref, k_ref, v_ref)
|
||||
|
||||
do_ref = base_do.clone()
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class TestTensorGradient(unittest.TestCase):
|
|||
def test_gradient_through_clone_from_grad_src(self):
|
||||
# unlike torch, tinygrad accumulates grad on all requires_grad tensors, including non-leaf x
|
||||
src = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
|
||||
x = src.clone().requires_grad_(True)
|
||||
x = src.clone()
|
||||
(x * 2.0).sum().backward()
|
||||
np.testing.assert_allclose(src.grad.numpy(), [2.0, 2.0, 2.0, 2.0])
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0, 2.0, 2.0, 2.0])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue