# LitBench-UI / src/train.py
# Andreas99's picture
# Update src/train.py
# b093dc1 verified
import json
import torch
import random
import transformers
import networkx as nx
from tqdm import tqdm
from peft import (LoraConfig, get_peft_model,
prepare_model_for_kbit_training)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class QloraTrainer_CS:
    """Fine-tune a LoRA adapter on an 8-bit quantized causal LM using citation-graph tasks.

    Typical flow: ``load_base_model()`` followed by ``train()``. Training examples are
    generated from a GEXF citation graph (``_process_data_instruction``) for seven
    tasks -- link prediction, abstract->title, paper retrieval, citation sentence,
    abstract completion, title->abstract and introduction->abstract -- and formatted
    with the ``prompt_input`` template from ``configs/alpaca.json`` (or as chat
    messages when the tokenizer defines a chat template).
    """

    def __init__(self, config: dict, use_predefined_graph=False):
        """Store the configuration and load the prompt template.

        Args:
            config: Nested configuration dict. Keys read by this class include
                ``inference.base_model``, ``training.tokenizer.max_length``,
                ``training.qlora.*``, ``training.trainer_args.*``,
                ``model_saving.*`` and either ``training.predefined_graph_path``
                or ``data_downloading.*``.
            use_predefined_graph: If True, read the graph from
                ``datasets/<training.predefined_graph_path>`` instead of the
                downloaded GEXF file.
        """
        self.config = config
        self.use_predefined_graph = use_predefined_graph
        self.tokenizer = None
        self.base_model = None
        self.adapter_model = None
        self.merged_model = None
        self.transformer_trainer = None
        # Held-out citation edges (edges touching a test node); set by _process_data_instruction.
        self.test_data = None
        template_file_path = 'configs/alpaca.json'
        with open(template_file_path, encoding='utf-8') as fp:
            self.template = json.load(fp)

    def load_base_model(self):
        """Load the tokenizer and an 8-bit quantized base model prepared for k-bit training.

        Sets ``self.tokenizer`` and ``self.base_model``. The tokenizer's
        ``model_max_length`` is overridden with ``training.tokenizer.max_length``;
        if the tokenizer has no pad token, the EOS token is reused for padding.
        """
        model_id = self.config['inference']["base_model"]
        print(model_id)
        # Only ``load_in_8bit`` matters for 8-bit loading. BitsAndBytesConfig has no
        # ``bnb_8bit_use_double_quant`` / ``bnb_8bit_quant_type`` / ``bnb_8bit_compute_dtype``
        # options (double quant, quant type and compute dtype exist only as
        # ``bnb_4bit_*`` for 4-bit loading), so passing them only produced an
        # "unused kwargs" warning.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
        if not tokenizer.pad_token:
            # The LM data collator pads batches, so a pad token is required.
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16
        )
        if model.device.type != 'cuda':
            model.to('cuda')
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        self.tokenizer = tokenizer
        self.base_model = model

    def train(self):
        """Attach a fresh LoRA adapter to ``self.base_model``, train it and save it.

        Requires ``load_base_model()`` to have been called first. The adapter is
        saved to ``<model_output_dir>/<model_name>_<index>_adapter_test_graph`` and
        the PEFT-wrapped model is kept in ``self.adapter_model``.
        """
        qlora_cfg = self.config['training']['qlora']
        trainer_cfg = self.config["training"]['trainer_args']
        saving_cfg = self.config['model_saving']

        lora_config = LoraConfig(
            r=qlora_cfg['rank'],
            lora_alpha=qlora_cfg['lora_alpha'],
            target_modules=qlora_cfg['target_modules'],
            lora_dropout=qlora_cfg['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(self.base_model, lora_config)
        self._print_trainable_parameters(model)

        print("Start data preprocessing")
        train_data = self._process_data_instruction()
        print('Length of dataset: ', len(train_data))

        print("Start training")
        self.transformer_trainer = transformers.Trainer(
            model=model,
            train_dataset=train_data,
            args=transformers.TrainingArguments(
                per_device_train_batch_size=trainer_cfg["per_device_train_batch_size"],
                # This used to be read from model_saving.index (the number used to name
                # the saved adapter), which looks like a copy/paste slip. Prefer an
                # explicit trainer_args value; fall back to the old key so existing
                # configs keep their current behaviour.
                gradient_accumulation_steps=trainer_cfg.get(
                    "gradient_accumulation_steps", saving_cfg['index']
                ),
                warmup_steps=trainer_cfg["warmup_steps"],
                num_train_epochs=trainer_cfg["num_train_epochs"],
                learning_rate=trainer_cfg["learning_rate"],
                lr_scheduler_type=trainer_cfg["lr_scheduler_type"],
                fp16=trainer_cfg["fp16"],
                logging_steps=trainer_cfg["logging_steps"],
                output_dir=trainer_cfg["trainer_output_dir"],
                report_to=trainer_cfg.get("report_to", "wandb"),
                save_steps=trainer_cfg["save_steps"],
            ),
            data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
        # The KV cache is not used during training, and transformers warns that it is
        # incompatible with gradient checkpointing (enabled in load_base_model).
        model.config.use_cache = False
        self.transformer_trainer.train()

        model_save_path = (
            f"{saving_cfg['model_output_dir']}/"
            f"{saving_cfg['model_name']}_{saving_cfg['index']}_adapter_test_graph"
        )
        self.transformer_trainer.save_model(model_save_path)
        self.adapter_model = model
        print(f"Training complete, adapter model saved in {model_save_path}")

    def _print_trainable_parameters(self, model):
        """
        Prints the number of trainable parameters in the model.

        Also prints the total parameter count and the trainable percentage
        (reported as 0.0 for a model without parameters).
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        # Avoid ZeroDivisionError for a parameterless model.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
        )

    def _load_graph(self):
        """Read the citation graph from the predefined or the downloaded GEXF file."""
        if self.use_predefined_graph:
            graph_path = 'datasets/' + self.config["training"]["predefined_graph_path"]
        else:
            download_cfg = self.config['data_downloading']
            graph_path = download_cfg['download_directory'] + 'description/' + download_cfg['gexf_file']
        return nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')

    def _process_data_instruction(self):
        """Build and tokenize the multi-task instruction dataset from the citation graph.

        The first 10% of the graph's nodes (in ``nodes()`` order) are held out. Every
        edge touching a held-out node is stored in ``self.test_data`` and excluded
        from training. Each task draws its samples from an independently shuffled
        copy of the remaining training edges.

        Nodes are expected to carry ``title`` and ``abstract`` attributes, and
        optionally ``introduction``. Edges are expected to carry a ``sentence``
        attribute.

        Returns:
            list: One tokenizer output per prompt, in shuffled order.
        """
        context_window = self.tokenizer.model_max_length
        graph_data = self._load_graph()

        nodes = list(graph_data.nodes())
        test_set_size = len(nodes) // 10
        all_test_nodes = set(nodes[:test_set_size])
        train_node_ids = nodes[test_set_size:]
        all_train_nodes = set(train_node_ids)

        raw_id_2_title_abs = {
            paper_id: [graph_data.nodes[paper_id]['title'], graph_data.nodes[paper_id]['abstract']]
            for paper_id in train_node_ids
        }

        # Only training papers with a non-empty introduction feed the intro->abstract task.
        raw_id_2_intro = {}
        for paper_id in train_node_ids:
            intro = graph_data.nodes[paper_id].get('introduction', '')
            if intro:
                raw_id_2_intro[paper_id] = intro

        raw_id_pair_2_sentence = {
            edge: graph_data.edges[edge]['sentence'] for edge in graph_data.edges()
        }

        edge_list = []
        test_data = []
        for edge in graph_data.edges():
            src, tar = edge
            if src in all_test_nodes or tar in all_test_nodes:
                test_data.append(edge)
            else:
                edge_list.append(edge)
        self.test_data = test_data
        train_num = len(edge_list)

        # Link prediction: one positive (real citation) and one random negative per edge.
        neg_candidates = list(raw_id_2_title_abs)  # hoisted: was rebuilt for every edge
        data_LP = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # The negative is any random training paper; it is not checked against the
            # real neighbours, so it can occasionally be a true citation.
            neg_title, neg_abs = raw_id_2_title_abs[random.choice(neg_candidates)]
            data_LP.append({'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs, 'label': 'yes'})
            data_LP.append({'s_title': source_title, 's_abs': source_abs, 't_title': neg_title, 't_abs': neg_abs, 'label': 'no'})

        # Abstract -> title, for both endpoints of every edge.
        data_abstract_2_title = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            for paper_id in (source, target):
                title, abstract = raw_id_2_title_abs[paper_id]
                data_abstract_2_title.append({'title': title, 'abs': abstract})

        # Paper retrieval: pick the cited paper among itself plus up to 5 non-neighbours.
        data_paper_retrieval = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, _ = raw_id_2_title_abs[target]
            neighbors = list(nx.all_neighbors(graph_data, source))
            sample_node_list = list(all_train_nodes - set(neighbors) - {source} - {target})
            # Cap k so tiny graphs do not make random.sample raise ValueError.
            sampled_neg_nodes = random.sample(sample_node_list, min(5, len(sample_node_list))) + [target]
            random.shuffle(sampled_neg_nodes)
            data_paper_retrieval.append({
                'title': source_title,
                'abs': source_abs,
                'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes],
                'right_title': target_title,
            })

        # Citation sentence: stored on the edge, whose key orientation may be (target, source).
        data_citation_sentence = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            pair = (source, target) if (source, target) in raw_id_pair_2_sentence else (target, source)
            data_citation_sentence.append({
                's_title': source_title, 's_abs': source_abs,
                't_title': target_title, 't_abs': target_abs,
                'sentence': raw_id_pair_2_sentence[pair],
            })

        # Abstract completion, for both endpoints.
        data_abs_completion = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            for paper_id in (source, target):
                title, abstract = raw_id_2_title_abs[paper_id]
                data_abs_completion.append({'title': title, 'abs': abstract})

        # Title -> abstract, for both endpoints.
        data_title_2_abs = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            for paper_id in (source, target):
                title, abstract = raw_id_2_title_abs[paper_id]
                data_title_2_abs.append({'title': title, 'right_abs': abstract})

        # Introduction -> abstract, only for endpoints that have an introduction.
        data_intro_2_abs = []
        for source, target in tqdm(random.sample(edge_list, train_num)):
            for paper_id in (source, target):
                if paper_id in raw_id_2_intro:
                    data_intro_2_abs.append({'intro': raw_id_2_intro[paper_id], 'abs': raw_id_2_title_abs[paper_id][1]})

        data_prompt = []
        data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
        data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
        data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
        data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
        data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
        data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs]
        data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]
        print("Total prompts:", len(data_prompt))
        random.shuffle(data_prompt)
        return self._tokenize_prompts(tqdm(data_prompt), context_window)

    def _tokenize_prompts(self, prompts, context_window):
        """Tokenize each prompt, truncating it to ``context_window`` tokens.

        Plain-text prompts (used when the tokenizer has no chat template) are
        tokenized directly. Chat-format prompts (lists of role/content messages) are
        first rendered to text with the chat template and then tokenized.
        ``add_special_tokens=False`` is passed because the rendered template is
        expected to already contain the model's special tokens (TODO: confirm for the
        chosen base model).
        """
        if self.tokenizer.chat_template is None:
            return [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in prompts]
        # apply_chat_template(tokenize=False) alone yields plain strings, which the
        # LM data collator cannot batch, so tokenize the rendered text explicitly.
        return [
            self.tokenizer(
                self.tokenizer.apply_chat_template(sample, tokenize=False),
                max_length=context_window,
                truncation=True,
                add_special_tokens=False,
            )
            for sample in prompts
        ]

    @staticmethod
    def _or_unknown(value):
        """Return ``value``, or the placeholder ``'Unknown'`` when it is None."""
        return value if value is not None else 'Unknown'

    def _build_example(self, instruction, prompt_input, answer):
        """Format one training example from an instruction, its input text and the answer.

        Without a chat template, the answer is appended directly after the filled
        ``prompt_input`` template string. With a chat template, a user/assistant
        message pair is returned instead.
        """
        prompt = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
        if self.tokenizer.chat_template is None:
            return f"{prompt}{answer}"
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ]

    def _pair_input(self, data_point: dict):
        """Render the titles/abstracts of paper A (``s_*``) and paper B (``t_*``), 'Unknown' for None."""
        return (
            "Title of Paper A: " + self._or_unknown(data_point['s_title']) + "\n"
            + "Abstract of Paper A: " + self._or_unknown(data_point['s_abs']) + "\n"
            + "Title of Paper B: " + self._or_unknown(data_point['t_title']) + "\n"
            + "Abstract of Paper B: " + self._or_unknown(data_point['t_abs']) + "\n"
        )

    def _generate_LP_prompt(self, data_point: dict):
        """Link-prediction example; the answer is ``data_point['label']`` ('yes' / 'no')."""
        instruction = "Determine if paper A will cite paper B."
        return self._build_example(instruction, self._pair_input(data_point), data_point['label'])

    def _generate_abstract_2_title_prompt(self, data_point: dict):
        """Example asking for a paper's title given its abstract."""
        instruction = "Please generate the title of paper based on its abstract."
        prompt_input = "Abstract: " + data_point['abs'] + "\n"
        return self._build_example(instruction, prompt_input, data_point['title'])

    def _generate_paper_retrieval_prompt(self, data_point: dict):
        """Example asking which numbered candidate title paper A cites; answer is the right title."""
        instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
        prompt_input = "Title of the Paper A: " + data_point['title'] + "\n"
        prompt_input += "Abstract of the Paper A: " + data_point['abs'] + "\n"
        prompt_input += "candidate papers: " + "\n"
        prompt_input += "".join(
            str(i) + '. ' + title + "\n" for i, title in enumerate(data_point['sample_title'])
        )
        return self._build_example(instruction, prompt_input, data_point['right_title'])

    def _generate_citation_sentence_prompt(self, data_point: dict):
        """Example asking for the sentence in which paper A cites paper B."""
        instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section."
        return self._build_example(instruction, self._pair_input(data_point), data_point['sentence'])

    def _generate_abstract_completion_prompt(self, data_point: dict):
        """Example asking to complete an abstract from its title and first 30% (by characters).

        The title is now always prefixed with "Title: " and followed by a newline,
        with "Unknown" used when it is None. Previously, operator precedence dropped
        the newline for a real title and dropped the label (and the prompt prefix)
        for a None title.
        """
        instruction = "Please complete the abstract of a paper."
        prompt_input = "Title: " + self._or_unknown(data_point['title']) + "\n"
        abstract = data_point['abs']
        split_abs = abstract[: int(0.3 * len(abstract))]
        prompt_input += "Part of abstract: " + split_abs + "\n"
        return self._build_example(instruction, prompt_input, abstract)

    def _generate_title_2_abstract_prompt(self, data_point: dict):
        """Example asking for a paper's abstract given its title."""
        instruction = "Please generate the abstract of paper based on its title."
        prompt_input = "Title: " + data_point['title'] + "\n"
        return self._build_example(instruction, prompt_input, data_point['right_abs'])

    def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
        """Example asking for a paper's abstract given its (character-capped) introduction."""
        instruction = "Please generate the abstract of paper based on its introduction section."
        prompt_input = "Introduction: " + data_point['intro'] + "\n"
        # Heuristic character cap (twice the token context window) so long introductions
        # fit; exact token truncation happens later at tokenization.
        prompt_input = prompt_input[:int(context_window * 2)]
        return self._build_example(instruction, prompt_input, data_point['abs'])