remove noop requires_grad_ calls (#16213)

This commit is contained in:
chenyu 2026-05-15 13:31:10 -04:00 committed by GitHub
commit 07a172dbbb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 21 additions and 26 deletions

View file

@ -3,7 +3,6 @@ os.environ["WQKV"] = "1"
import unittest
import numpy as np
from tinygrad import Tensor, nn, dtypes
from tinygrad.nn.state import get_parameters
from tinygrad.device import is_dtype_supported, Device
from examples.mlperf.models.llama import Transformer
from examples.mlperf.models.flat_llama import FlatTransformer
@ -45,8 +44,6 @@ class TestFlatLlama(unittest.TestCase):
flat = FlatTransformer(**params)
copy_weights(flat, ref)
for p in get_parameters(ref): p.requires_grad_(True)
for p in get_parameters(flat): p.requires_grad_(True)
Tensor.realize(*nn.state.get_state_dict(flat).values())
tokens = Tensor([[1, 50, 100, 999, 2, 10]])

View file

@ -109,9 +109,9 @@ def fa():
def fa_bw():
Tensor.manual_seed(1337)
with Context(DEBUG=0):
q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
attn_output.weight.requires_grad_().realize()
attn_output.weight.realize()
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
GlobalCounters.reset()

View file

@ -190,7 +190,6 @@ class TestSoftmaxFusion(unittest.TestCase):
def test_softmax_bw(self):
print("*** softmax bw ***")
self.test.requires_grad_()
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
self.test.softmax(-1).sum().backward()
sg = self.test.grad.realize()

View file

@ -141,7 +141,6 @@ class TestTiny(unittest.TestCase):
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
# realize gradients
for x in nn.state.get_parameters(layers): x.requires_grad_()
Tensor.empty(4, 1, 14, 14).sequential(layers).sum().backward()
Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])

View file

@ -142,9 +142,9 @@ class TestFA(unittest.TestCase):
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
with Context(DEBUG=0):
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=0)
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=0)
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=0)
q = base_q.clone().shard(GPUS, axis=0)
k = base_k.clone().shard(GPUS, axis=0)
v = base_v.clone().shard(GPUS, axis=0)
Tensor.realize(q, k, v)
do = base_do.clone().shard(GPUS, axis=0)
@ -157,9 +157,9 @@ class TestFA(unittest.TestCase):
Tensor.realize(q.grad, k.grad, v.grad)
with Context(DEBUG=0):
q_ref = base_q.clone().requires_grad_(True)
k_ref = base_k.clone().requires_grad_(True)
v_ref = base_v.clone().requires_grad_(True)
q_ref = base_q.clone()
k_ref = base_k.clone()
v_ref = base_v.clone()
Tensor.realize(q_ref, k_ref, v_ref)
do_ref = base_do.clone()
@ -189,9 +189,9 @@ class TestFA(unittest.TestCase):
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
with Context(DEBUG=0):
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=2)
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=2)
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=2)
q = base_q.clone().shard(GPUS, axis=2)
k = base_k.clone().shard(GPUS, axis=2)
v = base_v.clone().shard(GPUS, axis=2)
Tensor.realize(q, k, v)
do = base_do.clone().shard(GPUS, axis=2)
@ -204,9 +204,9 @@ class TestFA(unittest.TestCase):
Tensor.realize(q.grad, k.grad, v.grad)
with Context(DEBUG=0):
q_ref = base_q.clone().requires_grad_(True)
k_ref = base_k.clone().requires_grad_(True)
v_ref = base_v.clone().requires_grad_(True)
q_ref = base_q.clone()
k_ref = base_k.clone()
v_ref = base_v.clone()
Tensor.realize(q_ref, k_ref, v_ref)
do_ref = base_do.clone()

View file

@ -951,9 +951,9 @@ class TestTK(unittest.TestCase):
base_do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
with Context(DEBUG=0):
q = base_q.clone().requires_grad_(True).shard(GPUS, axis=0)
k = base_k.clone().requires_grad_(True).shard(GPUS, axis=0)
v = base_v.clone().requires_grad_(True).shard(GPUS, axis=0)
q = base_q.clone().shard(GPUS, axis=0)
k = base_k.clone().shard(GPUS, axis=0)
v = base_v.clone().shard(GPUS, axis=0)
Tensor.realize(q, k, v)
do = base_do.clone().shard(GPUS, axis=0)
@ -966,9 +966,9 @@ class TestTK(unittest.TestCase):
Tensor.realize(q.grad, k.grad, v.grad)
with Context(DEBUG=0):
q_ref = base_q.clone().requires_grad_(True)
k_ref = base_k.clone().requires_grad_(True)
v_ref = base_v.clone().requires_grad_(True)
q_ref = base_q.clone()
k_ref = base_k.clone()
v_ref = base_v.clone()
Tensor.realize(q_ref, k_ref, v_ref)
do_ref = base_do.clone()

View file

@ -79,7 +79,7 @@ class TestTensorGradient(unittest.TestCase):
def test_gradient_through_clone_from_grad_src(self):
# unlike torch, tinygrad accumulates grad on all requires_grad tensors, including non-leaf x
src = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
x = src.clone().requires_grad_(True)
x = src.clone()
(x * 2.0).sum().backward()
np.testing.assert_allclose(src.grad.numpy(), [2.0, 2.0, 2.0, 2.0])
np.testing.assert_allclose(x.grad.numpy(), [2.0, 2.0, 2.0, 2.0])