llama: immediate scaling on flag (#16494)

This commit is contained in:
wozeparrot 2026-06-04 13:30:00 -04:00 committed by GitHub
commit f11f63007d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@ from tinygrad.uop.ops import UOp, Ops
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
def stochastic_round_bf16(x:Tensor) -> Tensor:
bits = x.bitcast(dtypes.uint32)
@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
return out.shard_like(t) if offloaded else out
if t.dtype in dtypes.fp8s:
from examples.mlperf.models.flat_llama import FP8_MAX
if IMMEDIATE_SCALE:
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))
new_inv = ((new_w.float().abs().max(axis=amax_axis).detach() + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
t._inv_scale.assign(new_inv.shard_like(t._inv_scale) if offloaded else new_inv)
scale = new_inv.reciprocal().reshape(*new_inv.shape, *([1]*(new_w.ndim-new_inv.ndim)))
ret = (new_w * scale).clamp(-FP8_MAX, FP8_MAX).cast(t.dtype)
return ret.shard_like(t) if offloaded else ret
# delayed scaling: reuse previous step's inv_scale
t._inv_scale.assign(t._next_inv_scale)
inv_scale = t._inv_scale.to(new_w.device) if offloaded else t._inv_scale