mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: match flat_llama with model_train (#16269)
This commit is contained in:
parent
da7414d6dc
commit
361553c0a8
2 changed files with 26 additions and 9 deletions
|
|
@ -12,7 +12,7 @@ if __name__ == "__main__":
|
|||
if "ASM_GEMM" not in os.environ:
|
||||
os.environ["ASM_GEMM"] = "1"
|
||||
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
|
||||
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
|
||||
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker, round_up
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
|
||||
from extra.llama_kernels.rmsnorm import rmsnorm
|
||||
|
|
@ -290,20 +290,36 @@ if __name__ == "__main__":
|
|||
config = {}
|
||||
BS = config["BS"] = getenv("BS", 16)
|
||||
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
|
||||
SMALL = config["SMALL"] = getenv("SMALL", 0)
|
||||
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
model_params = MODEL_PARAMS[llama_size:=getenv("LLAMA3_SIZE", "8B")]["args"]
|
||||
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers
|
||||
# vocab_size from mixtral tokenizer
|
||||
if not SMALL: model_params |= {"vocab_size": 32000}
|
||||
real_vocab_size = model_params['vocab_size']
|
||||
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params["n_layers"] = llama_layers
|
||||
|
||||
# pad vocab
|
||||
if (MP := getenv("MP", 1)) > 1: model_params["vocab_size"] = round_up(model_params["vocab_size"], 256 * MP)
|
||||
vocab_mask:Tensor = Tensor.arange(model_params["vocab_size"]).reshape(1, 1, -1) >= real_vocab_size
|
||||
|
||||
model = FlatTransformer(**model_params, max_context=SEQLEN)
|
||||
|
||||
state = nn.state.get_state_dict(model)
|
||||
print("tensor count:", len(state))
|
||||
|
||||
# shard the model
|
||||
from tinygrad import Device
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
|
||||
if (MP := getenv("MP", 1)) > 1:
|
||||
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)
|
||||
is_dp = (DP := getenv("DP", 1)) > 1
|
||||
is_mp = (MP := getenv("MP", 1)) > 1
|
||||
is_sharding = is_dp or is_mp
|
||||
device_count = max(DP, MP)
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(device_count))
|
||||
|
||||
model.shard(device, is_mp)
|
||||
|
||||
if is_dp: vocab_mask.shard_(device, axis=None).realize()
|
||||
if is_mp: vocab_mask.shard_(device, axis=2).realize()
|
||||
|
||||
# preallocate all the grad buffers and zero them out
|
||||
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
|
||||
|
|
@ -320,7 +336,7 @@ if __name__ == "__main__":
|
|||
sz += v.nbytes()
|
||||
print(f"total sz: {sz/1e9:.2f} GB")
|
||||
|
||||
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
|
||||
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=real_vocab_size, dtype=dtypes.int)
|
||||
with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
|
||||
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
|
||||
if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
|
||||
|
|
@ -328,7 +344,9 @@ if __name__ == "__main__":
|
|||
|
||||
@TinyJit
|
||||
def fwd_bwd(tokens:Tensor):
|
||||
with Timing("python forward: "): loss = model(tokens[:, :-1], save=llama_size=="8B").sparse_categorical_crossentropy(tokens[:, 1:])
|
||||
with Timing("python forward: "):
|
||||
logits = model(tokens[:, :-1], save=llama_size=="8B")
|
||||
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
|
||||
with Timing("python backward: "):
|
||||
for t,g in zip(grads, loss.gradient(*grads)):
|
||||
apply_grad(grads[t], g.uop)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ export PYTHONPATH="."
|
|||
export PATH="/opt/rocm-7.1.1/bin:$PATH"
|
||||
export ROCM_PATH="/opt/rocm-7.1.1"
|
||||
export DEV=${DEV:-AMD}
|
||||
export EMULATE="AMD_CDNA4"
|
||||
export CHECK_OOB=0
|
||||
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
|
||||
export DEVICE_IN_FUNCTION_BUG=1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue