Other
English
minecraft
action prediction
Kqte committed on
Commit
8bcfbaa
·
verified ·
1 Parent(s): ae471e2

Upload 2 files

Browse files
Files changed (2) hide show
  1. model/parse_gold.py +44 -0
  2. model/parse_incremental.py +1 -1
model/parse_gold.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
from tqdm import tqdm

# --- Model -----------------------------------------------------------------
# "/path/to/..." values below are placeholders; point them at real locations
# before running the script.
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llamipa/adapter",
    return_dict=True,
    torch_dtype=torch.float16,  # weights are loaded in half precision
    device_map=device_map,
)

# --- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/meta-llama3-8b/",
    add_eos_token=True,
)

# Padding uses the id immediately after EOS rather than EOS itself.
# NOTE(review): presumably chosen so padding stays distinguishable from
# end-of-sequence; confirm this id is a valid, otherwise-unused token in the
# vocabulary of the tokenizer being loaded.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
tokenizer.padding_side = "right"

# --- Generation pipeline ---------------------------------------------------
# Each call generates at most 100 new tokens.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=100,
)

# --- Test data -------------------------------------------------------------
# The JSON-lines file is registered under the "test" split, which is then
# selected directly.
test_dataset = load_dataset(
    "json",
    data_files={'test': '/path/to/parser_test_15_gold.jsonl'},
)["test"]
def formatting_prompts_func(example):
    """Build one generation prompt per entry of ``example['sample']``.

    Args:
        example: A mapping-like object whose ``'sample'`` entry is an iterable
            of excerpts (the script passes the loaded test split).

    Returns:
        A list of prompt strings, in the same order as ``example['sample']``.
        Each prompt wraps one excerpt in the fixed instruction header and
        ends with the ``### DS:`` marker that the model is expected to
        continue from. An empty ``'sample'`` yields an empty list.
    """
    # Look the column up once and iterate it directly, instead of indexing
    # ``example['sample']`` again on every loop iteration.
    samples = example['sample']
    return [
        f"<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n {sample}\n ### DS:"
        for sample in samples
    ]
# Turn every test sample into a prompt up front so the total is known.
test_texts = formatting_prompts_func(test_dataset)

print("Test Length:", len(test_texts))

# The context manager guarantees the output file is flushed and closed even
# if generation raises part-way through the run. The explicit encoding keeps
# the written text independent of the platform's default locale.
with open("/path/to/test-output-file.txt", "w", encoding="utf-8") as f:
    for text in tqdm(test_texts):
        # Echo the prompt to stdout; only the generated text goes to the file.
        print(text)
        print(pipe(text)[0]["generated_text"], file=f)
model/parse_incremental.py CHANGED
@@ -89,7 +89,7 @@ def format_gen(preds):
89
 
90
 
91
  def formatting_prompts_func(example):
92
- output_text = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n' + example + '\n ### DS:'
93
  return output_text
94
 
95
  f = open("/path/to/test-output-file.txt","w")
 
89
 
90
 
91
  def formatting_prompts_func(example):
92
+ output_text = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n ' + example + '\n ### DS:'
93
  return output_text
94
 
95
  f = open("/path/to/test-output-file.txt","w")