mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: clean (#15224)
This commit is contained in:
parent
05d6d9120a
commit
4fab320abe
1 changed files with 0 additions and 48 deletions
|
|
@ -1345,58 +1345,10 @@ def train_llama3():
|
|||
|
||||
model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
|
||||
|
||||
# megatron-style weight initialization (normal(0, 0.02) for most, scaled init for residual projections)
|
||||
# if SMALL:
|
||||
# init_std = 0.02
|
||||
# n_layers = model_params.get("n_layers", len(model.layers))
|
||||
# scaled_init_std = init_std / math.sqrt(2 * n_layers)
|
||||
# model.tok_embeddings.weight.assign(Tensor.normal(*model.tok_embeddings.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# model.output.weight.assign(Tensor.normal(*model.output.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# for layer in model.layers:
|
||||
# if hasattr(layer.attention, 'wqkv'):
|
||||
# layer.attention.wqkv.weight.assign(Tensor.normal(*layer.attention.wqkv.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# else:
|
||||
# for w in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
|
||||
# w.weight.assign(Tensor.normal(*w.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# layer.attention.wo.weight.assign(Tensor.normal(*layer.attention.wo.weight.shape, std=scaled_init_std).cast(dtypes.bfloat16))
|
||||
# layer.feed_forward.w1.weight.assign(Tensor.normal(*layer.feed_forward.w1.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# layer.feed_forward.w2.weight.assign(Tensor.normal(*layer.feed_forward.w2.weight.shape, std=scaled_init_std).cast(dtypes.bfloat16))
|
||||
# layer.feed_forward.w3.weight.assign(Tensor.normal(*layer.feed_forward.w3.weight.shape, std=init_std).cast(dtypes.bfloat16))
|
||||
# Tensor.realize(*get_parameters(model))
|
||||
|
||||
params = get_parameters(model)
|
||||
# weights are all bfloat16 for now
|
||||
assert params and all(p.dtype == dtypes.bfloat16 for p in params)
|
||||
|
||||
# ** load pretrained weights **
|
||||
if init_ckpt := getenv("INIT_CKPT", ""):
|
||||
from examples.llama3 import load
|
||||
from extra.models.llama import convert_from_huggingface
|
||||
model_path = Path(init_ckpt)
|
||||
print(f"loading pretrained weights from {model_path}")
|
||||
weights = load(str(model_path / "model.safetensors.index.json"))
|
||||
weights = convert_from_huggingface(weights, model_params['n_layers'], model_params['n_heads'], model_params['n_kv_heads'])
|
||||
# combine wq/wk/wv into wqkv if WQKV mode
|
||||
if getenv("WQKV"):
|
||||
n_kv_heads, n_heads = model_params['n_kv_heads'], model_params['n_heads']
|
||||
head_dim, n_rep = model_params['dim'] // n_heads, n_heads // n_kv_heads
|
||||
for l in range(model_params['n_layers']):
|
||||
wq, wk, wv = (weights.pop(f"layers.{l}.attention.w{x}.weight") for x in ["q", "k", "v"])
|
||||
weights[f"layers.{l}.attention.wqkv.weight"] = \
|
||||
wq.reshape(n_kv_heads, n_rep, head_dim, -1).cat(wk.reshape(n_kv_heads, 1, head_dim, -1),
|
||||
wv.reshape(n_kv_heads, 1, head_dim, -1), dim=1).reshape(-1, model_params['dim'])
|
||||
# handle vocab size mismatch (pretrained vocab != training vocab)
|
||||
for emb_key in ["tok_embeddings.weight", "output.weight"]:
|
||||
if emb_key in weights and weights[emb_key].shape[0] != model_params['vocab_size']:
|
||||
pretrained_vocab = weights[emb_key].shape[0]
|
||||
target_vocab = model_params['vocab_size']
|
||||
print(f"{emb_key}: pretrained vocab {pretrained_vocab} -> training vocab {target_vocab}")
|
||||
if pretrained_vocab > target_vocab: weights[emb_key] = weights[emb_key][:target_vocab]
|
||||
else: weights[emb_key] = weights[emb_key].pad(((0, target_vocab - pretrained_vocab), (0, 0)))
|
||||
# cast to bf16 if needed
|
||||
weights = {k: v.cast(dtypes.bfloat16) if v.dtype != dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
load_state_dict(model, weights, strict=False, consume=True)
|
||||
|
||||
if getenv("FAKEDATA"):
|
||||
for v in get_parameters(model):
|
||||
v = v.assign(Tensor.empty(v.shape))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue