llama: match flat_llama with model_train (#16269)

This commit is contained in:
wozeparrot 2026-05-19 20:25:56 -04:00 committed by GitHub
commit 361553c0a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 9 deletions

View file

@ -12,7 +12,7 @@ if __name__ == "__main__":
if "ASM_GEMM" not in os.environ:
os.environ["ASM_GEMM"] = "1"
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker, round_up
from tinygrad.uop.ops import Ops, UOp
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
from extra.llama_kernels.rmsnorm import rmsnorm
@ -290,20 +290,36 @@ if __name__ == "__main__":
config = {}
BS = config["BS"] = getenv("BS", 16)
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
SMALL = config["SMALL"] = getenv("SMALL", 0)
from examples.llama3 import MODEL_PARAMS
model_params = MODEL_PARAMS[llama_size:=getenv("LLAMA3_SIZE", "8B")]["args"]
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers
# vocab_size from mixtral tokenizer
if not SMALL: model_params |= {"vocab_size": 32000}
real_vocab_size = model_params['vocab_size']
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params["n_layers"] = llama_layers
# pad vocab
if (MP := getenv("MP", 1)) > 1: model_params["vocab_size"] = round_up(model_params["vocab_size"], 256 * MP)
vocab_mask:Tensor = Tensor.arange(model_params["vocab_size"]).reshape(1, 1, -1) >= real_vocab_size
model = FlatTransformer(**model_params, max_context=SEQLEN)
state = nn.state.get_state_dict(model)
print("tensor count:", len(state))
# shard the model
from tinygrad import Device
if (DP := getenv("DP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
if (MP := getenv("MP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)
is_dp = (DP := getenv("DP", 1)) > 1
is_mp = (MP := getenv("MP", 1)) > 1
is_sharding = is_dp or is_mp
device_count = max(DP, MP)
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(device_count))
model.shard(device, is_mp)
if is_dp: vocab_mask.shard_(device, axis=None).realize()
if is_mp: vocab_mask.shard_(device, axis=2).realize()
# preallocate all the grad buffers and zero them out
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
@ -320,7 +336,7 @@ if __name__ == "__main__":
sz += v.nbytes()
print(f"total sz: {sz/1e9:.2f} GB")
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=real_vocab_size, dtype=dtypes.int)
with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
@ -328,7 +344,9 @@ if __name__ == "__main__":
@TinyJit
def fwd_bwd(tokens:Tensor):
with Timing("python forward: "): loss = model(tokens[:, :-1], save=llama_size=="8B").sparse_categorical_crossentropy(tokens[:, 1:])
with Timing("python forward: "):
logits = model(tokens[:, :-1], save=llama_size=="8B")
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
with Timing("python backward: "):
for t,g in zip(grads, loss.gradient(*grads)):
apply_grad(grads[t], g.uop)

View file

@ -4,7 +4,6 @@ export PYTHONPATH="."
export PATH="/opt/rocm-7.1.1/bin:$PATH"
export ROCM_PATH="/opt/rocm-7.1.1"
export DEV=${DEV:-AMD}
export EMULATE="AMD_CDNA4"
export CHECK_OOB=0
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
export DEVICE_IN_FUNCTION_BUG=1