mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: realize amax (#16308)
This commit is contained in:
parent
73ea36f4ac
commit
fb718a5e9d
1 changed files with 4 additions and 1 deletions
|
|
@ -328,6 +328,9 @@ if __name__ == "__main__":
|
|||
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
|
||||
grads = {x:_make_grad(x) for x in state.values() if x.requires_grad}
|
||||
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
|
||||
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]
|
||||
|
||||
# print model size
|
||||
sz = 0
|
||||
for k,v in state.items():
|
||||
|
|
@ -349,7 +352,7 @@ if __name__ == "__main__":
|
|||
with Timing("python backward: "):
|
||||
for t,g in zip(grads, loss.gradient(*grads)):
|
||||
apply_grad(grads[t], g.uop)
|
||||
with Timing("run fwd_bwd: "): loss.realize(*grads.values())
|
||||
with Timing("run fwd_bwd: "): loss.realize(*grads.values(), *fp8_amax, *fp8_grad_amax)
|
||||
|
||||
@TinyJit
|
||||
def optim_step():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue