llama: more mp mem fixes (#16701)

* llama: more mp mem fixes

* clean: unused

* fix: batch
This commit is contained in:
wozeparrot 2026-06-22 10:54:35 -04:00 committed by GitHub
commit fe9b19b12d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 15 deletions

View file

@ -2675,8 +2675,8 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
# 1x32 block scaling along the last axis
*batch, K = x.shape
scale_K, k_iters = K // 32, K // 128
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
scale_K = K // 32
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
x_scaled = x.float() * qscale