Delta-Vector committed on
Commit 7d4dd10 · verified · 1 Parent(s): 664f5f7

Upload dan-chat-apertus.py with huggingface_hub

Files changed (1)
  1. dan-chat-apertus.py +237 -0
dan-chat-apertus.py ADDED
@@ -0,0 +1,237 @@
+"""Module containing the DanApertusPromptTokenizingStrategy and DanApertusPrompter class"""
+
+import copy
+import logging
+from typing import Generator, Tuple
+
+from axolotl.prompt_tokenizers import (
+    PromptTokenizingStrategy,
+    # Assuming these are available from the base library
+    # parse_tokenized_to_result,
+    # tokenize_prompt_default,
+)
+
+LOG = logging.getLogger("axolotl")
+
+IGNORE_TOKEN_ID = -100
+
+# New special tokens and prefixes based on the Jinja2 template
+turn_separator = ""  # The Jinja2 template relies on the tokens themselves for separation
+
+system_prefix = "<|system_start|>"
+system_suffix = "<|system_end|>"
+developer_prefix = "<|developer_start|>"
+developer_suffix = "<|developer_end|>"
+user_prefix = "<|user_start|>"
+user_suffix = "<|user_end|>"
+assistant_prefix = "<|assistant_start|>"
+assistant_suffix = "<|assistant_end|>"
+tool_calls_prefix = "<|tools_prefix|>"
+tool_calls_suffix = "<|tools_suffix|>"
+
+# This is a simplification and would need a full Jinja2 renderer to be 100% accurate.
+# We'll use a placeholder for the developer block.
+DEVELOPER_BLOCK_PLACEHOLDER = f"{developer_prefix}Deliberation: disabled\nTool Capabilities: disabled{developer_suffix}"
+DEFAULT_SYSTEM_PROMPT = "You are Dolphin, a helpful, unbiased, and uncensored AI assistant."
+
+class DanApertusPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def __init__(self, prompter, tokenizer, train_on_inputs, sequence_len, *args, **kwargs):
+        super().__init__(prompter, tokenizer, *args, **kwargs)
+
+        # Tokenize the assistant prefix for use in calculating labels
+        res = self._tokenize(assistant_prefix, add_eos_token=False, strip_bos_token=True)
+        self.bot_prefix_token_ids = res["input_ids"]
+
+        # The new format doesn't have a simple turn separator token like "\n"
+        self.turn_separator_token_ids = []
+
+        self.train_on_inputs = train_on_inputs
+        self.sequence_len = sequence_len
+
+    def tokenize_prompt(self, prompt):
+        # 1. Build prompt parts, which now include the system and developer context.
+        # A virtual 'context' part is prepended for the system/developer block.
+        prompt_parts = list(self.prompter.build_prompt(prompt["conversations"]))
+        tokenized_parts = []
+        total_length = 0
+        not_first_turn = False  # Still useful for generic separator logic, but unused by this specific format
+
+        # 2. Add the initial system/developer block (simplified).
+        # The first message in the conversation is treated as the system message if present.
+        initial_context_message = ""
+        initial_context_labels = []
+
+        # Check for a system message in the first turn
+        if prompt_parts and prompt_parts[0][0] == "system":
+            _, system_msg, _, _ = prompt_parts.pop(0)  # Pop off the explicit system message
+        else:
+            system_msg = DEFAULT_SYSTEM_PROMPT  # Use the default if not present
+
+        full_context = f"{system_prefix}{system_msg}{system_suffix}{DEVELOPER_BLOCK_PLACEHOLDER}"
+
+        res_context = self._tokenize(full_context, add_eos_token=False, strip_bos_token=False)
+        initial_context_labels = [IGNORE_TOKEN_ID] * len(res_context["input_ids"])
+
+        tokenized_parts.append({
+            "input_ids": res_context["input_ids"],
+            "attention_mask": res_context["attention_mask"],
+            "labels": initial_context_labels,
+            "role": "context",
+            "loss": False
+        })
+        total_length += len(res_context["input_ids"])
+
+
+        # 3. Process conversation turns
+        for role, message, loss, prefix in prompt_parts:
+            if total_length >= self.sequence_len:
+                break
+
+            # If prefix is not defined, set it to an empty string
+            if prefix is None:
+                prefix = ""
+
+            # Resolve the wrapping tokens for this role
+            role_prefix = ""
+            role_suffix = ""
+
+            if role in ["system", "user", "human"]:
+                role_prefix = user_prefix  # system/human/user turns inside the conversation all use the user tokens
+                role_suffix = user_suffix
+
+                # The message content is what we want to wrap
+                full_text = role_prefix + prefix + message + role_suffix
+                res = self._tokenize(full_text, add_eos_token=False, strip_bos_token=True)
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+
+            elif role in ["model", "gpt"]:
+                role_prefix = assistant_prefix
+                role_suffix = assistant_suffix
+
+                # In this format, the assistant turn contains the full response
+                # (including any tool calls/thoughts/responses from the Jinja template logic).
+                # 'message' is assumed to be the full, pre-formatted assistant block.
+
+                # Tokenize the full block with its prefix/suffix
+                full_text = role_prefix + prefix + message + role_suffix
+                res = self._tokenize(full_text, add_eos_token=True, strip_bos_token=True)
+
+                # Labels for the assistant (model) turn
+                if not loss:
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                else:
+                    # Treat the entire assistant block as ground truth when loss=True.
+                    # We tokenize the *full* text but only train on the response part.
+                    # This is an approximation; a more accurate approach would train
+                    # only on the message *content* tokens, excluding prefix/suffix/tool tokens.
+
+                    # Approximate the masked span as the tokenized length of role_prefix.
+                    # strip_bos_token=True above, so only role_prefix needs to be accounted for.
+                    res_prefix = self._tokenize(role_prefix, add_eos_token=False, strip_bos_token=True)
+                    prefix_len = len(res_prefix["input_ids"])
+
+                    # Labels: IGNORE for the prefix, real token ids for the rest
+                    labels = [IGNORE_TOKEN_ID] * prefix_len + copy.deepcopy(res["input_ids"])[prefix_len:]
+
+            elif role == "tool":
+                # Tool messages are tricky in this format, as the template nests them inside the assistant turn.
+                # The prompter should probably not yield a separate 'tool' role at all;
+                # for compatibility we wrap it minimally, which may not match the template exactly.
+                role_prefix = "["
+                role_suffix = "]"
+
+                full_text = role_prefix + prefix + message + role_suffix
+                res = self._tokenize(full_text, add_eos_token=False, strip_bos_token=True)
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])  # Tool output is usually not trained on
+
+            else:
+                LOG.warning(f"unknown role in conversation: {role}")
+                continue
+
+            part_length = len(res["input_ids"])
+            if total_length + part_length > self.sequence_len:
+                break
+
+            tokenized_parts.append({
+                "input_ids": res["input_ids"],
+                "attention_mask": res["attention_mask"],
+                "labels": labels,
+                "role": role,
+                "loss": loss
+            })
+            total_length += part_length
+            not_first_turn = True
+
+        result = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": []
+        }
+
+        # Drop trailing turns that are user/system/tool or have loss disabled
+        while tokenized_parts and (tokenized_parts[-1]["role"] in ["human", "user", "system", "tool"] or not tokenized_parts[-1]["loss"]):
+            tokenized_parts.pop()
+
+        # Ensure we still have a full conversation (a user turn and a model turn)
+        if not any(part["role"] in ["human", "user", "system"] for part in tokenized_parts):
+            return result
+        if not any(part["role"] in ["model", "gpt"] for part in tokenized_parts):
+            return result
+
+        # Concatenate the final result
+        for part in tokenized_parts:
+            result["input_ids"] += part["input_ids"]
+            result["attention_mask"] += part["attention_mask"]
+            result["labels"] += part["labels"]
+
+        return result
+
+    # Helper functions can remain similar, but _tokenize_with_turn is less relevant
+    # given the new explicit role prefix/suffix tokens.
+    def _tokenize_with_turn(self, role_prefix, message, not_first_turn, add_eos_token=True):
+        # Largely redundant under the new structure, but kept for compatibility
+        # with the base class in case other methods call it.
+        # Simplified to ignore the turn separator and rely on the prefixes.
+        full_message = role_prefix + message.strip()
+        return self._tokenize(full_message, add_eos_token=add_eos_token, strip_bos_token=True)
+
+    def _get_labels(self, res, loss, not_first_turn):
+        # Redefined to work with the assistant_prefix length
+        if not loss:
+            return [IGNORE_TOKEN_ID] * len(res["input_ids"])
+
+        # Mask out the pre-tokenized assistant prefix, train on the rest
+        prefix_len = len(self.bot_prefix_token_ids)
+        return [IGNORE_TOKEN_ID] * prefix_len + copy.deepcopy(res["input_ids"])[prefix_len:]
+
+
+class DanApertusPrompter:
+    """
+    Prompter for the DanApertus format.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def build_prompt(self, source, *args, **kwargs) -> Generator[Tuple[str, str, bool, str], None, None]:
+        # This part stays mostly the same, yielding (role, message, loss, prefix) tuples;
+        # the format-specific wrapping is handled by the tokenizing strategy.
+        for msg in source:
+            from_value = msg["from"]
+            # 'value' in the input data is the text content of the message
+            message_value = msg["value"]
+
+            # Set loss based on the message source: train only on model/gpt turns by default
+            loss = msg.get("loss")
+            if loss is None:
+                loss = from_value in ["gpt", "model"]
+
+            # Set prefix, defaulting to an empty string if not present
+            prefix = msg.get("prefix", "")
+
+            yield from_value, message_value, loss, prefix
+
+
+def load(tokenizer, cfg):
+    # This remains the entry point
+    return DanApertusPromptTokenizingStrategy(DanApertusPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
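
For reference, a single user/assistant exchange run through this strategy should render roughly as the sketch below. This is illustrative only: the example text is made up, the developer block is the hard-coded placeholder defined in the file, and exact token boundaries depend on the tokenizer.

# Rough sketch of the rendered sequence for one exchange. Labels are IGNORE_TOKEN_ID
# for everything up to and including <|assistant_start|>, and real token ids for the
# assistant content, <|assistant_end|>, and the trailing EOS token.
example = (
    "<|system_start|>You are Dolphin, a helpful, unbiased, and uncensored AI assistant.<|system_end|>"
    "<|developer_start|>Deliberation: disabled\nTool Capabilities: disabled<|developer_end|>"
    "<|user_start|>Hello!<|user_end|>"
    "<|assistant_start|>Hi there!<|assistant_end|>"
)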
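
A minimal sketch of exercising the load() entry point outside a full training run, assuming an axolotl install and a tokenizer that already has these special tokens registered; the tokenizer path and the sample conversation are hypothetical, used only for illustration.

from types import SimpleNamespace
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/apertus-tokenizer")  # hypothetical path

cfg = SimpleNamespace(train_on_inputs=False, sequence_len=4096)
strategy = load(tokenizer, cfg)

sample = {
    "conversations": [
        {"from": "human", "value": "Hello!"},
        {"from": "gpt", "value": "Hi there!"},
    ]
}
tokenized = strategy.tokenize_prompt(sample)

# Only the assistant content (plus <|assistant_end|> and EOS) should carry real labels.
print(len(tokenized["input_ids"]), sum(l != IGNORE_TOKEN_ID for l in tokenized["labels"]))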