# Benchmark-v0 / finetune.py
import argparse
import os
import sys
from typing import List
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig
from utils import *
from collator import Collator


def train(args):
    # Reproducibility and output directory setup
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Detect distributed (DDP) vs. single-process execution
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))
    if ddp:
        device_map = {"": local_rank}

    # Load the base model config and tokenizer
    config = AutoConfig.from_pretrained(args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    gradient_checkpointing = True

    # Load datasets and register any new tokens they introduce
    train_data, valid_data = load_datasets(args)
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("add {} new tokens.".format(add_num))
        print("data num:", len(train_data))

    # Persist the (possibly extended) tokenizer and config alongside the checkpoints
    tokenizer.save_pretrained(args.output_dir)
    config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)
    # Load the base model and resize embeddings to match the extended tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        # torch_dtype=torch.float16,
        device_map=device_map,
    )
    model.resize_token_embeddings(len(tokenizer))

    # Naive model parallelism when multiple GPUs are available without DDP
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=gradient_checkpointing,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=20,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # Disable the KV cache during training (it is incompatible with gradient checkpointing)
    model.config.use_cache = False

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    # Save trainer state and the final model to the output directory
    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)
    args = parser.parse_args()

    train(args)
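
# Example launch (illustrative only): the actual flag names and defaults are
# defined by parse_global_args / parse_train_args / parse_dataset_args in
# utils, so the values below are assumptions and should be adjusted to match
# those parsers. The option names mirror the args.* fields used above.
#
#   torchrun --nproc_per_node=4 finetune.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --output_dir ./ckpt \
#       --per_device_batch_size 4 \
#       --gradient_accumulation_steps 8 \
#       --epochs 3 \
#       --learning_rate 1e-4 \
#       --bf16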