mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
llama_trai
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66676a1ac0 |
||
|
|
ee3ed9e646 | ||
|
|
62f98ef817 | ||
|
|
f6dcb9a777 | ||
|
|
dfaaeb0720 |
2 changed files with 131 additions and 28 deletions
93
examples/mlperf/llama_train.py
Normal file
93
examples/mlperf/llama_train.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import math
|
||||
from pathlib import Path
|
||||
from tinygrad import Device, nn, Tensor, TinyJit
|
||||
from tinygrad.helpers import getenv, profile_marker
|
||||
from extra.models.llama import Transformer
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
|
||||
# Run configuration. Every setting is read with getenv (second argument is the default)
# and bound to a module-level name.
BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
BS = getenv("BS", 16)
grad_acc = getenv("GRADIENT_ACC_STEPS", 1)
# global batch size: BS times the number of gradient accumulation steps
GBS = BS * grad_acc
SEED = getenv("SEED", 5760)
SEQLEN = getenv("SEQLEN", 8192)
TRAIN_ON_VAL = getenv("TRAIN_ON_VAL", 0)
SMALL = getenv("SMALL", 0)
# the default sample count depends on TRAIN_ON_VAL
SAMPLES = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
EVAL_FREQ = getenv("EVAL_FREQ", 46080)
EVAL_BS = getenv("EVAL_BS", 16)
EVAL_TARGET = getenv("EVAL_TARGET", 5.6)

# the same values gathered into one dict, in the order they were read
config = {
  "BASEDIR": BASEDIR,
  "BS": BS,
  "GRADIENT_ACC_STEPS": grad_acc,
  "GLOBAL_BATCH_SIZE": GBS,
  "SEED": SEED,
  "SEQLEN": SEQLEN,
  "TRAIN_ON_VAL": TRAIN_ON_VAL,
  "SMALL": SMALL,
  "SAMPLES": SAMPLES,
  "EVAL_FREQ": EVAL_FREQ,
  "EVAL_BS": EVAL_BS,
  "EVAL_TARGET": EVAL_TARGET,
}
|
||||
|
||||
# AdamW hyperparameters
opt_adamw_beta_1, opt_adamw_beta_2 = 0.9, 0.95
opt_adamw_epsilon = 1e-5
opt_adamw_weight_decay = 0.1

# gradient clip norm (going by its name; not referenced in the visible part of this file)
opt_gradient_clip_norm = 1.0

# Learning-rate schedule. The step-count and LR defaults are scaled by GBS against 1152
# (presumably the reference global batch size -- TODO confirm).
opt_learning_rate_warmup_steps = getenv("WARMUP_STEPS", math.ceil(8000 * 1152 / GBS))
# the decay length is MAX_STEPS minus the warmup steps
opt_learning_rate_decay_steps = (
  getenv("MAX_STEPS", math.ceil(1_200_000 * 1152 / GBS)) - opt_learning_rate_warmup_steps
)
opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152)  # NOTE: cannot change for benchmark
opt_end_learning_rate = getenv("END_LR", 8e-7)
|
||||
|
||||
# TODO: confirm weights are in bf16
# Model constructor arguments for the preset selected by LLAMA3_SIZE (default "8B").
# The preset is copied so the overrides below never mutate the shared MODEL_PARAMS table
# (previously, with SMALL set, the LLAMA_LAYERS override wrote into MODEL_PARAMS itself).
params = dict(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"])
# vocab_size from the mixtral tokenizer; with SMALL set the preset's own vocab_size is kept
if not SMALL: params["vocab_size"] = 32000
# a non-zero LLAMA_LAYERS overrides the preset's layer count
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
|
||||
|
||||
# Script entry point: builds the model, optionally shards it across devices, creates the
# optimizer and LR scheduler, realizes the parameters, and defines the jitted microbatch step.
# NOTE(review): leading indentation is missing in this rendering, so the nesting of the
# statements below has to be inferred from their syntax -- confirm against the real file.
if __name__ == "__main__":
|
||||
profile_marker("create model")
|
||||
|
||||
# the model is built from `params` with max_context=SEQLEN; jit=False and
# disable_kv_cache=True are passed to the constructor
model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
|
||||
|
||||
# shard the model, either data parallel (DP) or model parallel (MP)
|
||||
|
||||
# DP > 1: every parameter is shard_-ed over DP devices with axis=None
# (presumably a full copy per device -- TODO confirm against Tensor.shard_)
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
for v in nn.state.get_parameters(model):
|
||||
v.shard_(device, axis=None)
|
||||
|
||||
# MP > 1: each state-dict tensor is shard_-ed over MP devices, with the axis chosen by
# substring match on its key: wq/wk/wv, w1/w3, tok_embeddings and output use axis=0;
# wo and w2 use axis=1; 'scale' keys and everything else use axis=None
if (MP := getenv("MP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
|
||||
for k,v in nn.state.get_state_dict(model).items():
|
||||
if 'scale' in k: v.shard_(device, axis=None) # from quantized
|
||||
elif '.attention.wq' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wk' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wv' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wo' in k: v.shard_(device, axis=1)
|
||||
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
|
||||
elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
|
||||
elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
|
||||
elif 'output.weight' in k: v.shard_(device, axis=0)
|
||||
else:
|
||||
# attention_norm, ffn_norm, norm
|
||||
v.shard_(device, axis=None)
|
||||
# prevents memory spike on device 0
|
||||
# NOTE(review): presumably this realize runs once per tensor at loop level, not only in
# the else branch -- the lost indentation makes this ambiguous, confirm
v.realize()
|
||||
|
||||
profile_marker("create optim")
|
||||
# AdamW starts at lr=0.0 with fused=True; the scheduler receives the optimizer, the
# base/end learning rates and the warmup/decay step counts
# (presumably it drives optim.lr from there -- TODO confirm in lr_schedulers)
optim = nn.optim.AdamW(nn.state.get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, fused=True)
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate,
|
||||
opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
profile_marker("init params")
|
||||
# each parameter is replaced by its contiguous version, and all of them are realized
# together with optim.lr
optim.lr.realize(*[p.replace(p.contiguous()) for p in optim.params])
|
||||
|
||||
# TODO: make this work with multigpu
|
||||
# all optimizer params flattened and concatenated into one tensor, plus a zero tensor
# of the same shape for the gradients
cat_params = Tensor.cat(*[t.flatten() for t in optim.params], dim=0)
|
||||
cat_grads = Tensor.zeros_like(cat_params)
|
||||
|
||||
# microbatch: runs the model on batch[:, :-1], takes the sparse categorical cross-entropy
# against batch[:, 1:] (targets shifted by one position), calls backward, and returns the
# loss realized together with cat_grads.
# NOTE(review): nothing visible here writes the computed gradients into cat_grads --
# confirm whether accumulation is still to be added.
# NOTE(review): profile_marker is called as a plain function elsewhere in this block but
# used as a decorator factory here -- confirm it returns a callable.
@profile_marker("microbatch")
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def microbatch(batch:Tensor):
|
||||
# temperature=math.nan presumably makes the model return logits rather than sampled
# tokens -- TODO confirm in Transformer
logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
|
||||
loss = logits.sparse_categorical_crossentropy(batch[:, 1:]).backward()
|
||||
return loss.realize(cat_grads)
|
||||
|
||||
|
||||
|
||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
||||
|
||||
|
|
@ -1331,10 +1331,6 @@ def train_llama3():
|
|||
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
|
||||
model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
|
||||
|
||||
if getenv("FAKEDATA"):
|
||||
for v in get_parameters(model):
|
||||
v = v.assign(Tensor.empty(v.shape))
|
||||
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
for v in get_parameters(model):
|
||||
|
|
@ -1360,9 +1356,13 @@ def train_llama3():
|
|||
v.realize()
|
||||
|
||||
optim = AdamW(get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, fused=True)
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
# init tensors
|
||||
profile_marker("init tensors")
|
||||
optim.lr.realize(*[p.replace(p.contiguous()) for p in optim.params])
|
||||
|
||||
if resume_ckpt := getenv("RESUME_CKPT"):
|
||||
fn = f"./ckpts/llama3_{resume_ckpt}.safe"
|
||||
print(f"loading initial checkpoint from {fn}")
|
||||
|
|
@ -1374,10 +1374,15 @@ def train_llama3():
|
|||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(model, tokens:Tensor, grad_acc:int):
|
||||
def train_step(tokens:Tensor, grad_acc:int):
|
||||
optim.zero_grad()
|
||||
# grad acc
|
||||
# grad acc. NOTE: this has to become multidevice aware, this cat should be per device
|
||||
cat_params = Tensor.cat(*[t.flatten() for t in optim.params], dim=0)
|
||||
cat_grads = Tensor.zeros_like(cat_params)
|
||||
|
||||
total_loss = Tensor(0, dtype=dtypes.float)
|
||||
for batch in tokens.split(tokens.shape[0]//grad_acc):
|
||||
profile_marker("grads")
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
batch = batch.shard(device, 0)
|
||||
|
|
@ -1387,28 +1392,32 @@ def train_llama3():
|
|||
logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
|
||||
loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
|
||||
loss.backward()
|
||||
Tensor.realize(*[p.grad for p in optim.params])
|
||||
total_loss += loss/grad_acc
|
||||
cat_grads += Tensor.cat(*[t.grad.flatten() for t in optim.params], dim=0)
|
||||
total_loss.realize(cat_grads)
|
||||
|
||||
# L2 norm grad clip
|
||||
# https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
|
||||
# https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
|
||||
profile_marker("optimizer")
|
||||
|
||||
if not getenv("DISABLE_GRAD_CLIP_NORM"):
|
||||
total_norm = Tensor(0.0, dtype=dtypes.float32, device=optim.params[0].device)
|
||||
for p in optim.params:
|
||||
total_norm += p.grad.float().square().sum()
|
||||
total_norm = total_norm.sqrt().contiguous()
|
||||
for p in optim.params:
|
||||
p.grad = p.grad * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)
|
||||
total_norm = cat_grads.float().square().sum().sqrt().contiguous()
|
||||
cat_grads = cat_grads * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)
|
||||
|
||||
optim.step()
|
||||
scheduler.step()
|
||||
# run the optimizer
|
||||
# NOTE: this is copied from _schedule_step
|
||||
out, extra = optim._step([cat_params], [cat_grads]) # this will go on CPU
|
||||
|
||||
lr = optim.lr
|
||||
loss.realize(lr)
|
||||
return loss, lr
|
||||
# update the parameters
|
||||
updated_params = [out[0][optim.pos_params[i]:optim.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(optim.params)]
|
||||
for i, tt in enumerate(optim.params): tt.assign(updated_params[i])
|
||||
Tensor.realize(*optim.params, *extra, *optim.buffers, *scheduler.schedule_step())
|
||||
return total_loss
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train(False)
|
||||
def eval_step(model, tokens:Tensor):
|
||||
def eval_step(tokens:Tensor):
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
tokens = tokens.shard(device, 0)
|
||||
|
|
@ -1449,18 +1458,19 @@ def train_llama3():
|
|||
iter = get_train_iter()
|
||||
i, sequences_seen = resume_ckpt, 0
|
||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||
profile_marker(f"train step {i}")
|
||||
t = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss, lr = train_step(model, tokens, grad_acc)
|
||||
loss = loss.float().item()
|
||||
loss = train_step(tokens, grad_acc)
|
||||
loss, lr = loss.float().item(), optim.lr.item()
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
|
||||
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
|
||||
tqdm.write(f"{loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
f.write(f"{i} {loss:.4f} {lr:.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
|
|
@ -1481,7 +1491,7 @@ def train_llama3():
|
|||
tqdm.write(f"evaluating {5760//EVAL_BS} batches of {EVAL_BS} sequences")
|
||||
|
||||
for tokens in tqdm(eval_iter, total=5760//EVAL_BS):
|
||||
eval_losses += eval_step(model, tokens).tolist()
|
||||
eval_losses += eval_step(tokens).tolist()
|
||||
log_perplexity = Tensor(eval_losses).mean().float().item()
|
||||
|
||||
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
|
||||
|
|
@ -1564,7 +1574,7 @@ def train_stable_diffusion():
|
|||
loss, out_lr = loss.detach().to("CPU"), optimizer.lr.to("CPU")
|
||||
Tensor.realize(loss, out_lr)
|
||||
return loss, out_lr
|
||||
|
||||
|
||||
# checkpointing takes ~9 minutes without this, and ~1 minute with this
|
||||
@TinyJit
|
||||
def ckpt_to_cpu():
|
||||
|
|
@ -1603,7 +1613,7 @@ def train_stable_diffusion():
|
|||
if i == 3:
|
||||
for _ in range(3): ckpt_to_cpu() # do this at the beginning of run to prevent OOM surprises when checkpointing
|
||||
print("BEAM COMPLETE", flush=True) # allows wrapper script to detect BEAM search completion and retry if it failed
|
||||
|
||||
|
||||
total_train_time = time.perf_counter() - train_start_time
|
||||
if WANDB:
|
||||
wandb.log({"train/loss": loss_item, "train/lr": lr_item, "train/loop_time_prev": loop_time, "train/dl_time": dl_time, "train/step": i,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue