mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update tiny torch backend hook (#12575)
* update the backend to fix torch deprecation warning * use param_hook to avoid full backward hook needlessly firing on inputs which do not require gradients * fix indentation --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
db5ae846aa
commit
d65bd669f8
1 changed files with 5 additions and 4 deletions
|
|
@ -642,10 +642,11 @@ def get_real_tinygrad_buffers():
|
|||
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
|
||||
|
||||
from torch.nn.modules import Module
|
||||
def backward_hook(model:Module, _grad_input, _grad_out):
|
||||
grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None]
|
||||
if len(grads_to_realize): Tensor.realize(*grads_to_realize)
|
||||
def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook)
|
||||
def param_hook(_grad):
|
||||
if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad))
|
||||
def module_hook(module:Module, _name, _submodule):
|
||||
for param in _submodule.parameters(recurse=False):
|
||||
if param.requires_grad: param.register_hook(param_hook)
|
||||
torch.nn.modules.module.register_module_module_registration_hook(module_hook)
|
||||
|
||||
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue