mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: more mp mem fixes (#16701)
* llama: more mp mem fixes * clean: unused * fix: batch
This commit is contained in:
parent
267af9c601
commit
fe9b19b12d
2 changed files with 23 additions and 15 deletions
|
|
@@ -126,10 +126,8 @@ class FlatTransformer:
|
|||
|
||||
# FeedForward
|
||||
if SPLIT_W13:
|
||||
if getenv("ZEROS"): w13_raw = Tensor.zeros(2, self.n_layers, hidden_dim, dim)
|
||||
else: w13_raw = Tensor.normal(2, self.n_layers, hidden_dim, dim, mean=0.0, std=0.02)
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[0])
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim, w=w13_raw[1])
|
||||
self.w1, s_1 = self.lin_per_layer(dim, hidden_dim)
|
||||
self.w3, s_3 = self.lin_per_layer(dim, hidden_dim)
|
||||
else:
|
||||
self.w13, s_13 = self.lin_per_layer(dim, hidden_dim * 2)
|
||||
self.w2, s_2 = self.lin_per_layer(hidden_dim, dim, std=scaled_std)
|
||||
|
|
@@ -160,7 +158,7 @@ class FlatTransformer:
|
|||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02, w:Tensor|None=None):
|
||||
if w is None:
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std).realize()
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||
|
|
@@ -247,20 +245,30 @@ class FlatTransformer:
|
|||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
def _shard_fp8(name:str, axis:int, std:float=0.02):
|
||||
w = getattr(self, name)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_bf16 = Tensor.empty(self.n_layers, w.shape[1], w.shape[2], dtype=dtypes.bfloat16).shard(device, axis=axis).randn_like() * std
|
||||
w_q, w_e8, _ = quantize_mxfp8(w_bf16)
|
||||
w.replace(w_q)
|
||||
self._fp8_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
self._fp8_next_inv_scale[name].replace(w_e8.contiguous()).is_param_(False)
|
||||
else:
|
||||
w.shard_(device, axis=axis)
|
||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(w, self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
sstd = 0.02 / math.sqrt(2 * self.n_layers)
|
||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||
_shard_fp8("wo", 2, sstd) # (n_layers, dim, in) shard in
|
||||
if SPLIT_W13:
|
||||
_shard_fp8("w1", 1)
|
||||
_shard_fp8("w3", 1)
|
||||
else:
|
||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||
_shard_fp8("w2", 2, sstd) # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
|
|
|
|||
|
|
@@ -2675,8 +2675,8 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
*batch, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
scale_K = K // 32
|
||||
amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
||||
x_scaled = x.float() * qscale
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue