llama: realize amax (#16308)

This commit is contained in:
wozeparrot 2026-05-21 17:00:48 -04:00 committed by GitHub
commit fb718a5e9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -328,6 +328,9 @@ if __name__ == "__main__":
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
grads = {x:_make_grad(x) for x in state.values() if x.requires_grad}
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]
# print model size
sz = 0
for k,v in state.items():
@@ -349,7 +352,7 @@ if __name__ == "__main__":
with Timing("python backward: "):
for t,g in zip(grads, loss.gradient(*grads)):
apply_grad(grads[t], g.uop)
with Timing("run fwd_bwd: "): loss.realize(*grads.values())
with Timing("run fwd_bwd: "): loss.realize(*grads.values(), *fp8_amax, *fp8_grad_amax)
@TinyJit
def optim_step():