Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
66676a1ac0
Merge branch 'master' into llama_trainer 2025-11-27 09:37:44 -08:00
George Hotz
ee3ed9e646 wip junk 2025-11-24 20:03:12 -08:00
George Hotz
62f98ef817 fused optim 2025-11-24 19:51:07 -08:00
George Hotz
f6dcb9a777 why did fakedata have fakeweights? 2025-11-24 19:29:42 -08:00
George Hotz
dfaaeb0720 improve llama trainer 2025-11-24 19:11:14 -08:00
2 changed files with 131 additions and 28 deletions

View file

@ -0,0 +1,93 @@
import math
from pathlib import Path
from tinygrad import Device, nn, Tensor, TinyJit
from tinygrad.helpers import getenv, profile_marker
from extra.models.llama import Transformer
from examples.llama3 import MODEL_PARAMS
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
config = {}
BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
BS = config["BS"] = getenv("BS", 16)
grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
SEED = config["SEED"] = getenv("SEED", 5760)
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
SMALL = config["SMALL"] = getenv("SMALL", 0)
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
opt_adamw_beta_1 = 0.9
opt_adamw_beta_2 = 0.95
opt_adamw_epsilon = 1e-5
opt_adamw_weight_decay = 0.1
opt_gradient_clip_norm = 1.0
opt_learning_rate_warmup_steps = getenv("WARMUP_STEPS", math.ceil(8000 * 1152 / GBS))
opt_learning_rate_decay_steps = getenv("MAX_STEPS", math.ceil(1_200_000 * 1152 / GBS)) - opt_learning_rate_warmup_steps
opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
opt_end_learning_rate = getenv("END_LR", 8e-7)
# TODO: confirm weights are in bf16
# vocab_size from the mixtral tokenizer
params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
params = params | {"vocab_size": 32000} if not SMALL else params
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
if __name__ == "__main__":
profile_marker("create model")
model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
# shard the model, either data parallel (DP) or model parallel (MP)
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
for v in nn.state.get_parameters(model):
v.shard_(device, axis=None)
if (MP := getenv("MP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
for k,v in nn.state.get_state_dict(model).items():
if 'scale' in k: v.shard_(device, axis=None) # from quantized
elif '.attention.wq' in k: v.shard_(device, axis=0)
elif '.attention.wk' in k: v.shard_(device, axis=0)
elif '.attention.wv' in k: v.shard_(device, axis=0)
elif '.attention.wo' in k: v.shard_(device, axis=1)
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
elif 'output.weight' in k: v.shard_(device, axis=0)
else:
# attention_norm, ffn_norm, norm
v.shard_(device, axis=None)
# prevents memory spike on device 0
v.realize()
profile_marker("create optim")
optim = nn.optim.AdamW(nn.state.get_parameters(model), lr=0.0,
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, fused=True)
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate,
opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
profile_marker("init params")
optim.lr.realize(*[p.replace(p.contiguous()) for p in optim.params])
# TODO: make this work with multigpu
cat_params = Tensor.cat(*[t.flatten() for t in optim.params], dim=0)
cat_grads = Tensor.zeros_like(cat_params)
@profile_marker("microbatch")
@TinyJit
@Tensor.train()
def microbatch(batch:Tensor):
logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
loss = logits.sparse_categorical_crossentropy(batch[:, 1:]).backward()
return loss.realize(cat_grads)

View file

@ -3,7 +3,7 @@ from pathlib import Path
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, Profiling, profile_marker
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
@ -1331,10 +1331,6 @@ def train_llama3():
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
if getenv("FAKEDATA"):
for v in get_parameters(model):
v = v.assign(Tensor.empty(v.shape))
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
for v in get_parameters(model):
@ -1360,9 +1356,13 @@ def train_llama3():
v.realize()
optim = AdamW(get_parameters(model), lr=0.0,
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, fused=True)
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
# init tensors
profile_marker("init tensors")
optim.lr.realize(*[p.replace(p.contiguous()) for p in optim.params])
if resume_ckpt := getenv("RESUME_CKPT"):
fn = f"./ckpts/llama3_{resume_ckpt}.safe"
print(f"loading initial checkpoint from {fn}")
@ -1374,10 +1374,15 @@ def train_llama3():
@TinyJit
@Tensor.train()
def train_step(model, tokens:Tensor, grad_acc:int):
def train_step(tokens:Tensor, grad_acc:int):
optim.zero_grad()
# grad acc
# grad acc. NOTE: this has to become multidevice aware, this cat should be per device
cat_params = Tensor.cat(*[t.flatten() for t in optim.params], dim=0)
cat_grads = Tensor.zeros_like(cat_params)
total_loss = Tensor(0, dtype=dtypes.float)
for batch in tokens.split(tokens.shape[0]//grad_acc):
profile_marker("grads")
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
batch = batch.shard(device, 0)
@ -1387,28 +1392,32 @@ def train_llama3():
logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
loss.backward()
Tensor.realize(*[p.grad for p in optim.params])
total_loss += loss/grad_acc
cat_grads += Tensor.cat(*[t.grad.flatten() for t in optim.params], dim=0)
total_loss.realize(cat_grads)
# L2 norm grad clip
# https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
# https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
profile_marker("optimizer")
if not getenv("DISABLE_GRAD_CLIP_NORM"):
total_norm = Tensor(0.0, dtype=dtypes.float32, device=optim.params[0].device)
for p in optim.params:
total_norm += p.grad.float().square().sum()
total_norm = total_norm.sqrt().contiguous()
for p in optim.params:
p.grad = p.grad * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)
total_norm = cat_grads.float().square().sum().sqrt().contiguous()
cat_grads = cat_grads * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)
optim.step()
scheduler.step()
# run the optimizer
# NOTE: this is copied from _schedule_step
out, extra = optim._step([cat_params], [cat_grads]) # this will go on CPU
lr = optim.lr
loss.realize(lr)
return loss, lr
# update the parameters
updated_params = [out[0][optim.pos_params[i]:optim.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(optim.params)]
for i, tt in enumerate(optim.params): tt.assign(updated_params[i])
Tensor.realize(*optim.params, *extra, *optim.buffers, *scheduler.schedule_step())
return total_loss
@TinyJit
@Tensor.train(False)
def eval_step(model, tokens:Tensor):
def eval_step(tokens:Tensor):
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
tokens = tokens.shard(device, 0)
@ -1449,18 +1458,19 @@ def train_llama3():
iter = get_train_iter()
i, sequences_seen = resume_ckpt, 0
for tokens in tqdm(iter, total=SAMPLES//GBS):
profile_marker(f"train step {i}")
t = time.perf_counter()
GlobalCounters.reset()
loss, lr = train_step(model, tokens, grad_acc)
loss = loss.float().item()
loss = train_step(tokens, grad_acc)
loss, lr = loss.float().item(), optim.lr.item()
i += 1
sequences_seen += tokens.shape[0]
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
tqdm.write(f"{loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
f.write(f"{i} {loss:.4f} {lr:.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
@ -1481,7 +1491,7 @@ def train_llama3():
tqdm.write(f"evaluating {5760//EVAL_BS} batches of {EVAL_BS} sequences")
for tokens in tqdm(eval_iter, total=5760//EVAL_BS):
eval_losses += eval_step(model, tokens).tolist()
eval_losses += eval_step(tokens).tolist()
log_perplexity = Tensor(eval_losses).mean().float().item()
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
@ -1564,7 +1574,7 @@ def train_stable_diffusion():
loss, out_lr = loss.detach().to("CPU"), optimizer.lr.to("CPU")
Tensor.realize(loss, out_lr)
return loss, out_lr
# checkpointing takes ~9 minutes without this, and ~1 minute with this
@TinyJit
def ckpt_to_cpu():
@ -1603,7 +1613,7 @@ def train_stable_diffusion():
if i == 3:
for _ in range(3): ckpt_to_cpu() # do this at the beginning of run to prevent OOM surprises when checkpointing
print("BEAM COMPLETE", flush=True) # allows wrapper script to detect BEAM search completion and retry if it failed
total_train_time = time.perf_counter() - train_start_time
if WANDB:
wandb.log({"train/loss": loss_item, "train/lr": lr_item, "train/loop_time_prev": loop_time, "train/dl_time": dl_time, "train/step": i,