import json
import random

import torch
import transformers
import networkx as nx
from tqdm import tqdm
from peft import (LoraConfig, get_peft_model,
                  prepare_model_for_kbit_training)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class QloraTrainer_CS:
    def __init__(self, config: dict, use_predefined_graph=False):
        self.config = config
        self.use_predefined_graph = use_predefined_graph
        self.tokenizer = None
        self.base_model = None
        self.adapter_model = None
        self.merged_model = None
        self.transformer_trainer = None
        self.test_data = None
        template_file_path = 'configs/alpaca.json'
        with open(template_file_path) as fp:
            self.template = json.load(fp)
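
    # The template file is assumed to follow the standard Alpaca format: a
    # JSON object whose "prompt_input" entry is a format string with
    # {instruction} and {input} slots, roughly
    #   "Below is an instruction that describes a task, paired with an input
    #    that provides further context. Write a response that appropriately
    #    completes the request.\n\n### Instruction:\n{instruction}\n\n
    #    ### Input:\n{input}\n\n### Response:\n"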

    def load_base_model(self):
        model_id = self.config['inference']["base_model"]
        print(model_id)
        # QLoRA setup: load the base model quantized to 4-bit NF4 with double
        # quantization, computing in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        # device_map places the quantized weights on the GPU at load time;
        # calling .to('cuda') on a bitsandbytes-quantized model is unsupported.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config,
            torch_dtype=torch.bfloat16, device_map="auto")
        model.gradient_checkpointing_enable()
        # Freeze the quantized weights, cast norm layers to fp32, and enable
        # input gradients so checkpointing works with a frozen base model.
        model = prepare_model_for_kbit_training(model)
        self.tokenizer = tokenizer
        self.base_model = model

    def train(self):
        # Set up LoRA config or load pre-trained adapter
        lora_config = LoraConfig(
            r=self.config['training']['qlora']['rank'],
            lora_alpha=self.config['training']['qlora']['lora_alpha'],
            target_modules=self.config['training']['qlora']['target_modules'],
            lora_dropout=self.config['training']['qlora']['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(self.base_model, lora_config)
        self._print_trainable_parameters(model)

        print("Start data preprocessing")
        train_data = self._process_data_instruction()
        print('Length of dataset: ', len(train_data))

        print("Start training")
        self.transformer_trainer = transformers.Trainer(
            model=model,
            train_dataset=train_data,
            args=transformers.TrainingArguments(
                per_device_train_batch_size=self.config["training"]['trainer_args']["per_device_train_batch_size"],
                gradient_accumulation_steps=self.config["training"]['trainer_args']["gradient_accumulation_steps"],
                warmup_steps=self.config["training"]['trainer_args']["warmup_steps"],
                num_train_epochs=self.config["training"]['trainer_args']["num_train_epochs"],
                learning_rate=self.config["training"]['trainer_args']["learning_rate"],
                lr_scheduler_type=self.config["training"]['trainer_args']["lr_scheduler_type"],
                fp16=self.config["training"]['trainer_args']["fp16"],
                logging_steps=self.config["training"]['trainer_args']["logging_steps"],
                output_dir=self.config["training"]['trainer_args']["trainer_output_dir"],
                report_to="wandb",
                save_steps=self.config["training"]['trainer_args']["save_steps"],
            ),
            # Causal-LM collation: pad batches and derive labels from input_ids.
            data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
        # The KV cache is incompatible with gradient checkpointing.
        model.config.use_cache = False
        self.transformer_trainer.train()

        model_save_path = f"{self.config['model_saving']['model_output_dir']}/{self.config['model_saving']['model_name']}_{self.config['model_saving']['index']}_adapter_test_graph"
        self.transformer_trainer.save_model(model_save_path)
        self.adapter_model = model
        print(f"Training complete, adapter model saved in {model_save_path}")

    def _print_trainable_parameters(self, model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
        )
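
    # Illustrative output for a 7B base model with a rank-16 adapter on two
    # target modules (numbers hypothetical):
    #   trainable params: 8388608 || all params: 6746804224 || trainable%: 0.124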

    def _process_data_instruction(self):
        context_window = self.tokenizer.model_max_length
        if self.use_predefined_graph:
            graph_data = nx.read_gexf('datasets/' + self.config["training"]["predefined_graph_path"],
                                      node_type=None, relabel=False, version='1.2draft')
        else:
            graph_path = self.config['data_downloading']['download_directory'] + 'description/' + self.config['data_downloading']['gexf_file']
            graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')
        raw_graph = graph_data

        # Hold out the first 10% of nodes as the test set.
        test_set_size = len(graph_data.nodes()) // 10
        all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
        all_train_nodes = set(list(graph_data.nodes())[test_set_size:])

        # Per-node lookups for the training nodes.
        raw_id_2_title_abs = dict()
        for paper_id in list(graph_data.nodes())[test_set_size:]:
            title = graph_data.nodes()[paper_id]['title']
            abstract = graph_data.nodes()[paper_id]['abstract']
            raw_id_2_title_abs[paper_id] = [title, abstract]

        raw_id_2_intro = dict()
        for paper_id in list(graph_data.nodes())[test_set_size:]:
            if graph_data.nodes[paper_id]['introduction'] != '':
                raw_id_2_intro[paper_id] = graph_data.nodes[paper_id]['introduction']

        raw_id_pair_2_sentence = dict()
        for edge in list(graph_data.edges()):
            raw_id_pair_2_sentence[edge] = graph_data.edges()[edge]['sentence']

        # Keep only edges between training nodes; any edge touching a test
        # node becomes test data.
        test_data = []
        edge_list = []
        for edge in list(raw_graph.edges()):
            src, tar = edge
            if src not in all_test_nodes and tar not in all_test_nodes:
                edge_list.append(edge)
            else:
                test_data.append(edge)
        # Sampling len(edge_list) items without replacement just shuffles.
        train_num = len(edge_list)

        data_LP = []
        data_abstract_2_title = []
        data_paper_retrieval = []
        data_citation_sentence = []
        data_abs_completion = []
        data_title_2_abs = []
        data_intro_2_abs = []
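
        # Seven graph-derived instruction-tuning tasks are built below: link
        # prediction, abstract -> title, paper retrieval, citation-sentence
        # generation, abstract completion, title -> abstract, and
        # introduction -> abstract.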
        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # LP prompt: one true pair and one randomly drawn negative pair
            rand_ind = random.choice(list(raw_id_2_title_abs.keys()))
            neg_title, neg_abs = raw_id_2_title_abs[rand_ind]
            data_LP.append({'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs, 'label': 'yes'})
            data_LP.append({'s_title': source_title, 's_abs': source_abs, 't_title': neg_title, 't_abs': neg_abs, 'label': 'no'})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # abs_2_title prompt
            data_abstract_2_title.append({'title': source_title, 'abs': source_abs})
            data_abstract_2_title.append({'title': target_title, 'abs': target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # paper_retrieval prompt: the cited paper hidden among five
            # non-neighbor distractors
            neighbors = list(nx.all_neighbors(raw_graph, source))
            sample_node_list = list(all_train_nodes - set(neighbors) - {source} - {target})
            sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
            random.shuffle(sampled_neg_nodes)
            data_paper_retrieval.append({'title': source_title, 'abs': source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title': target_title})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # citation_sentence prompt (edges are stored in one direction only)
            citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence else raw_id_pair_2_sentence[(target, source)]
            data_citation_sentence.append({'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs, 'sentence': citation_sentence})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # abs_complete prompt
            data_abs_completion.append({'title': source_title, 'abs': source_abs})
            data_abs_completion.append({'title': target_title, 'abs': target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # title_2_abs prompt
            data_title_2_abs.append({'title': source_title, 'right_abs': source_abs})
            data_title_2_abs.append({'title': target_title, 'right_abs': target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            # intro_2_abs prompt (only for papers with a non-empty introduction)
            if source in raw_id_2_intro:
                source_intro = raw_id_2_intro[source]
                _, source_abs = raw_id_2_title_abs[source]
                data_intro_2_abs.append({'intro': source_intro, 'abs': source_abs})
            if target in raw_id_2_intro:
                target_intro = raw_id_2_intro[target]
                _, target_abs = raw_id_2_title_abs[target]
                data_intro_2_abs.append({'intro': target_intro, 'abs': target_abs})

        data_prompt = []
        data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
        data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
        data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
        data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
        data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
        data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs]
        data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]
        print("Total prompts:", len(data_prompt))
        random.shuffle(data_prompt)
        if self.tokenizer.chat_template is None:
            # Plain Alpaca-style prompt strings: tokenize directly.
            data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True)
                              for sample in tqdm(data_prompt)]
        else:
            # Chat-style message lists: render with the chat template and
            # tokenize in one step; return_dict=True yields the input_ids and
            # attention_mask the data collator expects.
            data_tokenized = [self.tokenizer.apply_chat_template(sample, max_length=context_window, truncation=True,
                                                                 tokenize=True, return_dict=True)
                              for sample in tqdm(data_prompt)]
        return data_tokenized

    def _generate_LP_prompt(self, data_point: dict):
        instruction = "Determine if paper A will cite paper B."
        prompt_input = ""
        prompt_input += "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
        prompt_input += "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
        prompt_input += "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
        prompt_input += "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['label']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['label']}
            ]
        return res

    def _generate_abstract_2_title_prompt(self, data_point: dict):
        instruction = "Please generate the title of a paper based on its abstract."
        prompt_input = "Abstract: " + data_point['abs'] + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['title']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['title']}
            ]
        return res

    def _generate_paper_retrieval_prompt(self, data_point: dict):
        instruction = "Please select the paper that is most likely to be cited by paper A from the candidate papers."
        prompt_input = ""
        prompt_input += "Title of Paper A: " + data_point['title'] + "\n"
        prompt_input += "Abstract of Paper A: " + data_point['abs'] + "\n"
        prompt_input += "Candidate papers:" + "\n"
        for i, title in enumerate(data_point['sample_title']):
            prompt_input += str(i) + '. ' + title + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['right_title']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['right_title']}
            ]
        return res

    def _generate_citation_sentence_prompt(self, data_point: dict):
        instruction = "Please generate the sentence with which Paper A cites Paper B in its related work section."
        prompt_input = ""
        prompt_input += "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
        prompt_input += "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
        prompt_input += "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
        prompt_input += "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['sentence']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['sentence']}
            ]
        return res

    def _generate_abstract_completion_prompt(self, data_point: dict):
        instruction = "Please complete the abstract of a paper."
        prompt_input = "Title: " + (data_point['title'] if data_point['title'] is not None else 'Unknown') + "\n"
        # The first 30% of the abstract is given as the part to complete.
        split_abs = data_point['abs'][: int(0.3 * len(data_point['abs']))]
        prompt_input += "Part of abstract: " + split_abs + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['abs']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['abs']}
            ]
        return res

    def _generate_title_2_abstract_prompt(self, data_point: dict):
        instruction = "Please generate the abstract of a paper based on its title."
        prompt_input = "Title: " + data_point['title'] + "\n"
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['right_abs']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['right_abs']}
            ]
        return res

    def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
        instruction = "Please generate the abstract of a paper based on its introduction section."
        prompt_input = "Introduction: " + data_point['intro'] + "\n"
        # Truncate by characters (roughly two characters per token) so the
        # long introduction still fits in the context window.
        prompt_input = prompt_input[:int(context_window * 2)]
        if self.tokenizer.chat_template is None:
            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
            res = f"{res}{data_point['abs']}"
        else:
            res = [
                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
                {"role": "assistant", "content": data_point['abs']}
            ]
        return res
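

if __name__ == "__main__":
    # Minimal usage sketch. The keys mirror the ones this class reads above;
    # every value here is a hypothetical placeholder, not the project's real
    # configuration.
    example_config = {
        'inference': {'base_model': 'meta-llama/Llama-2-7b-hf'},
        'training': {
            'tokenizer': {'max_length': 2048},
            'qlora': {'rank': 16, 'lora_alpha': 32, 'lora_dropout': 0.05,
                      'target_modules': ['q_proj', 'v_proj']},
            'trainer_args': {'per_device_train_batch_size': 1,
                             'gradient_accumulation_steps': 8,
                             'warmup_steps': 100, 'num_train_epochs': 1,
                             'learning_rate': 2e-4, 'lr_scheduler_type': 'cosine',
                             'fp16': False, 'logging_steps': 10,
                             'trainer_output_dir': 'outputs', 'save_steps': 500},
            'predefined_graph_path': 'example_graph.gexf',
        },
        'model_saving': {'model_output_dir': 'models', 'model_name': 'qlora_cs',
                         'index': 0},
    }
    trainer = QloraTrainer_CS(example_config, use_predefined_graph=True)
    trainer.load_base_model()
    trainer.train()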