mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move gradient.py to mixin/ [PR] (#16583)
This commit is contained in:
parent
a2cec397f3
commit
762f50bd52
4 changed files with 3 additions and 3 deletions
|
|
@ -126,8 +126,8 @@ do_not_mutate = [
|
|||
"tinygrad/viz/*",
|
||||
"tinygrad/device.py",
|
||||
"tinygrad/dtype.py",
|
||||
"tinygrad/gradient.py",
|
||||
"tinygrad/helpers.py",
|
||||
"tinygrad/mixin/gradient.py",
|
||||
"tinygrad/tensor.py",
|
||||
]
|
||||
tests_dir = ["test/test_tiny.py", "test/backend/test_ops.py"]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin.gradient import compute_gradient
|
||||
|
||||
class TestGradient(unittest.TestCase):
|
||||
def _cmp_nan_okay(self, x, y):
|
||||
|
|
|
|||
|
|
@ -460,7 +460,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
"""
|
||||
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
|
||||
if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin.gradient import compute_gradient
|
||||
if gradient is None: gradient = self.const_like(1.0)
|
||||
target_uops = [t._uop for t in targets]
|
||||
grads = compute_gradient(self._uop, gradient._uop, set(target_uops))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue