mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes * clean: unused * fix: batch
This commit is contained in:
parent
267af9c601
commit
fe9b19b12d
2 changed files with 23 additions and 15 deletions
|
|
@ -2675,8 +2675,8 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
*batch, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
scale_K = K // 32
|
||||
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
||||
x_scaled = x.float() * qscale
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue