import argparse
import os
import sys
from typing import List

import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig

from utils import *
from collator import Collator


def train(args):
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    # Detect distributed training from the environment variables set by the launcher.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))

    # Under DDP, pin the whole model to this process's local GPU.
    if ddp:
        device_map = {"": local_rank}

    config = AutoConfig.from_pretrained(args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    # Llama-style tokenizers ship without a pad token, so reuse EOS for padding.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    gradient_checkpointing = True

    train_data, valid_data = load_datasets(args)
    # Register the new tokens introduced by the training data and keep the
    # config's vocab size in sync with the enlarged tokenizer.
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("added {} new tokens.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
    )
    # Grow the embedding matrix to cover the newly added tokens.
    model.resize_token_embeddings(len(tokenizer))

    # Multiple GPUs without DDP: flag the model as model-parallel so the
    # Trainer does not wrap it in DataParallel.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

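    # Note: load_best_model_at_end requires the evaluation and save strategies
    # (and steps) to match, which is why both are driven by the same
    # save_and_eval arguments below.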
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=gradient_checkpointing,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=20,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
            dataloader_num_workers=args.dataloader_num_workers,
            dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # The KV cache is not needed for training and conflicts with gradient checkpointing.
    model.config.use_cache = False

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)

    args = parser.parse_args()

    train(args)
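
# Example launch (sketch only): the exact flag names are defined by
# parse_global_args / parse_train_args / parse_dataset_args in utils.py, which
# are not shown here, so the script name, model path, and flags below are
# assumptions based on the attributes used above.
#
#   torchrun --nproc_per_node=4 train.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --output_dir ./ckpt \
#       --per_device_batch_size 4 \
#       --gradient_accumulation_steps 8 \
#       --epochs 3 \
#       --learning_rate 1e-4 \
#       --bf16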