llama: clean (#15224)

This commit is contained in:
wozeparrot 2026-03-12 04:33:59 +08:00 committed by GitHub
commit 4fab320abe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1345,58 +1345,10 @@ def train_llama3():
model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
# megatron-style weight initialization (normal(0, 0.02) for most, scaled init for residual projections)
# if SMALL:
# init_std = 0.02
# n_layers = model_params.get("n_layers", len(model.layers))
# scaled_init_std = init_std / math.sqrt(2 * n_layers)
# model.tok_embeddings.weight.assign(Tensor.normal(*model.tok_embeddings.weight.shape, std=init_std).cast(dtypes.bfloat16))
# model.output.weight.assign(Tensor.normal(*model.output.weight.shape, std=init_std).cast(dtypes.bfloat16))
# for layer in model.layers:
# if hasattr(layer.attention, 'wqkv'):
# layer.attention.wqkv.weight.assign(Tensor.normal(*layer.attention.wqkv.weight.shape, std=init_std).cast(dtypes.bfloat16))
# else:
# for w in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
# w.weight.assign(Tensor.normal(*w.weight.shape, std=init_std).cast(dtypes.bfloat16))
# layer.attention.wo.weight.assign(Tensor.normal(*layer.attention.wo.weight.shape, std=scaled_init_std).cast(dtypes.bfloat16))
# layer.feed_forward.w1.weight.assign(Tensor.normal(*layer.feed_forward.w1.weight.shape, std=init_std).cast(dtypes.bfloat16))
# layer.feed_forward.w2.weight.assign(Tensor.normal(*layer.feed_forward.w2.weight.shape, std=scaled_init_std).cast(dtypes.bfloat16))
# layer.feed_forward.w3.weight.assign(Tensor.normal(*layer.feed_forward.w3.weight.shape, std=init_std).cast(dtypes.bfloat16))
# Tensor.realize(*get_parameters(model))
params = get_parameters(model)
# weights are all bfloat16 for now
assert params and all(p.dtype == dtypes.bfloat16 for p in params)
# ** load pretrained weights **
if init_ckpt := getenv("INIT_CKPT", ""):
from examples.llama3 import load
from extra.models.llama import convert_from_huggingface
model_path = Path(init_ckpt)
print(f"loading pretrained weights from {model_path}")
weights = load(str(model_path / "model.safetensors.index.json"))
weights = convert_from_huggingface(weights, model_params['n_layers'], model_params['n_heads'], model_params['n_kv_heads'])
# combine wq/wk/wv into wqkv if WQKV mode
if getenv("WQKV"):
n_kv_heads, n_heads = model_params['n_kv_heads'], model_params['n_heads']
head_dim, n_rep = model_params['dim'] // n_heads, n_heads // n_kv_heads
for l in range(model_params['n_layers']):
wq, wk, wv = (weights.pop(f"layers.{l}.attention.w{x}.weight") for x in ["q", "k", "v"])
weights[f"layers.{l}.attention.wqkv.weight"] = \
wq.reshape(n_kv_heads, n_rep, head_dim, -1).cat(wk.reshape(n_kv_heads, 1, head_dim, -1),
wv.reshape(n_kv_heads, 1, head_dim, -1), dim=1).reshape(-1, model_params['dim'])
# handle vocab size mismatch (pretrained vocab != training vocab)
for emb_key in ["tok_embeddings.weight", "output.weight"]:
if emb_key in weights and weights[emb_key].shape[0] != model_params['vocab_size']:
pretrained_vocab = weights[emb_key].shape[0]
target_vocab = model_params['vocab_size']
print(f"{emb_key}: pretrained vocab {pretrained_vocab} -> training vocab {target_vocab}")
if pretrained_vocab > target_vocab: weights[emb_key] = weights[emb_key][:target_vocab]
else: weights[emb_key] = weights[emb_key].pad(((0, target_vocab - pretrained_vocab), (0, 0)))
# cast to bf16 if needed
weights = {k: v.cast(dtypes.bfloat16) if v.dtype != dtypes.bfloat16 else v for k, v in weights.items()}
load_state_dict(model, weights, strict=False, consume=True)
if getenv("FAKEDATA"):
for v in get_parameters(model):
v = v.assign(Tensor.empty(v.shape))