add finetune script.

This commit is contained in:
Qi Luo 2024-10-17 21:25:49 +08:00
commit 52871dd1a7
3 changed files with 275 additions and 0 deletions

View file

@ -0,0 +1,40 @@
{
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 20,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

225
train/finetune.py Normal file
View file

@ -0,0 +1,225 @@
import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import torch
import torch.distributed
import transformers
from transformers import Trainer
from datasets import load_dataset
IGNORE_INDEX = -100
def build_instruction_prompt(instruction: str):
return """# This is the assembly code:
{}
# What is the source code?
""".format(
instruction
)
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(
default="deepseek-ai/deepseek-coder-1.3b-base"
)
use_flash_attention: bool = field(
default=False, metadata={"help": "Whether to use flash attention."}
)
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
def _tokenize_fn(
strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[instance[key] for instance in instances] for key in ("input_ids", "labels")
)
input_ids = [torch.tensor(x) for x in input_ids]
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = [torch.tensor(x) for x in labels]
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def train_tokenize_function(examples, tokenizer):
sources = [
build_instruction_prompt(instruction) for instruction in examples["instruction"]
]
eos_token = tokenizer.eos_token
targets = [f"{output}\n{eos_token}" for output in examples["output"]]
data_dict = preprocess(sources, targets, tokenizer)
return data_dict
def train():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if training_args.local_rank == 0:
print("=" * 100)
print(training_args)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=True,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
print("BOS Token", tokenizer.bos_token, tokenizer.bos_token_id)
print("EOS Token", tokenizer.eos_token, tokenizer.eos_token_id)
if training_args.local_rank == 0:
print("Load tokenizer from {} over.".format(model_args.model_name_or_path))
model_kwargs = {}
if model_args.use_flash_attention:
model_kwargs["attn_implementation"] = "flash_attention_2"
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, **model_kwargs
)
if training_args.local_rank == 0:
print("Load model from {} over.".format(model_args.model_name_or_path))
raw_train_datasets = load_dataset(
"json",
data_files=data_args.data_path,
split="train",
cache_dir=training_args.cache_dir,
)
if training_args.local_rank > 0:
torch.distributed.barrier()
train_dataset = raw_train_datasets.map(
train_tokenize_function,
batched=True,
batch_size=3000,
num_proc=32,
remove_columns=raw_train_datasets.column_names,
load_from_cache_file=True, # not args.overwrite_cache
desc="Running Encoding",
fn_kwargs={"tokenizer": tokenizer},
)
if training_args.local_rank == 0:
torch.distributed.barrier()
if training_args.local_rank == 0:
print("Training dataset samples:", len(train_dataset))
for index in random.sample(range(len(train_dataset)), 3):
print(
f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}."
)
print(
f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}."
)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
data_module = dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args, **data_module
)
trainer.train()
trainer.save_model()
trainer.save_state()
if __name__ == "__main__":
train()

10
train/requirements.txt Normal file
View file

@ -0,0 +1,10 @@
torch
tokenizers
transformers
accelerate
sympy
pebble
timeout-decorator
attrdict
deepspeed
tensorboard