amazingvince commited on
Commit
5098f25
·
1 Parent(s): b42d79c

Upload 15 files

Browse files
added_tokens.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "<|assistant|>": 32770,
3
+ "<|end|>": 32771,
4
+ "<|system|>": 32768,
5
+ "<|user|>": 32769
6
+ }
attention.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Attention layers."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from torch import nn
13
+
14
+ from .low_precision_layernorm import LPLayerNorm
15
+
16
+
17
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
18
+ original_is_causal: bool):
19
+ if original_is_causal and num_query_tokens != num_key_tokens:
20
+ if num_query_tokens != 1:
21
+ raise NotImplementedError(
22
+ 'ReplitLM does not support query and key with different number of tokens, unless number of query tokens is 1.'
23
+ )
24
+ else:
25
+ return False
26
+ return original_is_causal
27
+
28
+
29
+ def scaled_multihead_dot_product_attention(
30
+ query,
31
+ key,
32
+ value,
33
+ n_heads,
34
+ softmax_scale=None,
35
+ attn_bias=None,
36
+ key_padding_mask=None,
37
+ is_causal=False,
38
+ dropout_p=0.0,
39
+ training=False,
40
+ needs_weights=False,
41
+ ):
42
+
43
+ q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
44
+ k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads) # includes key.t()
45
+ v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
46
+
47
+ min_val = torch.finfo(q.dtype).min
48
+
49
+ b, _, s_q, d = q.shape
50
+ s_k = k.size(-1)
51
+
52
+ if softmax_scale is None:
53
+ softmax_scale = 1 / math.sqrt(d)
54
+
55
+ attn_weight = q.matmul(k) * softmax_scale
56
+
57
+ if attn_bias is not None:
58
+ if (attn_bias.size(-1) != 1 and
59
+ attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
60
+ attn_bias.size(-2) != s_q):
61
+ raise RuntimeError(
62
+ f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
63
+ )
64
+ attn_weight = attn_weight + attn_bias
65
+
66
+ if key_padding_mask is not None:
67
+ if attn_bias is not None:
68
+ warnings.warn(
69
+ 'Propogating key_padding_mask to the attention module ' +
70
+ 'and applying it within the attention module can cause ' +
71
+ 'unneccessary computation/memory usage. Consider integrating ' +
72
+ 'into attn_bias once and passing that to each attention ' +
73
+ 'module instead.'
74
+ )
75
+ attn_weight = attn_weight.masked_fill(
76
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
77
+
78
+ if is_causal:
79
+ s = max(s_q, s_k)
80
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
81
+ causal_mask = causal_mask.tril()
82
+ causal_mask = causal_mask.to(torch.bool)
83
+ causal_mask = ~causal_mask
84
+ causal_mask = causal_mask[-s_q:, -s_k:]
85
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
86
+ min_val)
87
+
88
+ attn_weight = torch.softmax(attn_weight, dim=-1)
89
+
90
+ if dropout_p:
91
+ attn_weight = torch.nn.functional.dropout(attn_weight,
92
+ p=dropout_p,
93
+ training=training,
94
+ inplace=True)
95
+
96
+ out = attn_weight.matmul(v)
97
+ out = rearrange(out, 'b h s d -> b s (h d)')
98
+
99
+ if needs_weights:
100
+ return out, attn_weight
101
+ return out, None
102
+
103
+
104
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
105
+ for tensor in tensors:
106
+ if tensor.dtype not in valid_dtypes:
107
+ raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
108
+ if not tensor.is_cuda:
109
+ raise TypeError(
110
+ f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
111
+
112
+
113
+ def flash_attn_fn(
114
+ query,
115
+ key,
116
+ value,
117
+ n_heads,
118
+ softmax_scale=None,
119
+ attn_bias=None,
120
+ key_padding_mask=None,
121
+ is_causal=False,
122
+ dropout_p=0.0,
123
+ training=False,
124
+ needs_weights=False,
125
+ ):
126
+ try:
127
+ from flash_attn import bert_padding, flash_attn_interface
128
+ except:
129
+ raise RuntimeError('Please install flash_attn==0.2.8')
130
+
131
+ check_valid_inputs(query, key, value)
132
+
133
+ if attn_bias is not None:
134
+ raise NotImplementedError(f'attn_bias not implemented for flash attn.')
135
+
136
+ batch_size, seqlen = query.shape[:2]
137
+
138
+ if key_padding_mask is None:
139
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
140
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
141
+
142
+ query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
143
+ query, query_padding_mask)
144
+ query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
145
+
146
+ key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
147
+ key, key_padding_mask)
148
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
149
+
150
+ value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
151
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
152
+
153
+ dropout_p = dropout_p if training else 0.0
154
+
155
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
156
+
157
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
158
+ query_unpad,
159
+ key_unpad,
160
+ value_unpad,
161
+ cu_seqlens_q,
162
+ cu_seqlens_k,
163
+ max_seqlen_q,
164
+ max_seqlen_k,
165
+ dropout_p,
166
+ softmax_scale=softmax_scale,
167
+ causal=reset_is_causal,
168
+ return_attn_probs=needs_weights)
169
+
170
+ output = bert_padding.pad_input(
171
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
172
+ seqlen)
173
+ return output, None
174
+
175
+
176
+ def triton_flash_attn_fn(
177
+ query,
178
+ key,
179
+ value,
180
+ n_heads,
181
+ softmax_scale=None,
182
+ attn_bias=None,
183
+ key_padding_mask=None,
184
+ is_causal=False,
185
+ dropout_p=0.0,
186
+ training=False,
187
+ needs_weights=False,
188
+ ):
189
+ try:
190
+ from flash_attn import flash_attn_triton # type: ignore
191
+ except:
192
+ raise RuntimeError(
193
+ 'Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.')
194
+
195
+ check_valid_inputs(query, key, value)
196
+
197
+ if dropout_p:
198
+ raise NotImplementedError(
199
+ f'Dropout not implemented for attn_impl: triton.')
200
+
201
+ if needs_weights:
202
+ raise NotImplementedError(
203
+ f'attn_impl: triton cannot return attn weights.')
204
+
205
+ if key_padding_mask is not None:
206
+ warnings.warn(
207
+ 'Propagating key_padding_mask to the attention module ' +
208
+ 'and applying it within the attention module can cause ' +
209
+ 'unnecessary computation/memory usage. Consider integrating ' +
210
+ 'into attn_bias once and passing that to each attention ' +
211
+ 'module instead.'
212
+ )
213
+ b_size, s_k = key_padding_mask.shape[:2]
214
+
215
+ if attn_bias is None:
216
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
217
+
218
+ attn_bias = attn_bias.masked_fill(
219
+ ~key_padding_mask.view((b_size, 1, 1, s_k)),
220
+ torch.finfo(query.dtype).min)
221
+
222
+ query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
223
+ key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
224
+ value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)
225
+
226
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
227
+ attn_output = flash_attn_triton.flash_attn_func(query, key, value,
228
+ attn_bias, reset_is_causal,
229
+ softmax_scale)
230
+
231
+ output = attn_output.view(*attn_output.shape[:2], -1)
232
+
233
+ return output, None
234
+
235
+
236
+ class MultiheadAttention(nn.Module):
237
+ """Multi-head self attention.
238
+
239
+ Using torch or triton attention implemetation enables user to also use
240
+ additive bias.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ d_model: int,
246
+ n_heads: int,
247
+ attn_impl: str = 'triton',
248
+ attn_clip_qkv: Optional[float] = None,
249
+ attn_qk_ln: bool = False,
250
+ softmax_scale: Optional[float] = None,
251
+ attn_pdrop: float = 0.0,
252
+ low_precision_layernorm: bool = False,
253
+ device: Optional[str] = None,
254
+ ):
255
+ super().__init__()
256
+
257
+ self.attn_impl = attn_impl
258
+ self.clip_qkv = attn_clip_qkv
259
+ self.attn_qk_ln = attn_qk_ln
260
+
261
+ self.d_model = d_model
262
+ self.n_heads = n_heads
263
+ self.softmax_scale = softmax_scale
264
+ if self.softmax_scale is None:
265
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
266
+ self.attn_dropout_p = attn_pdrop
267
+
268
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
269
+ # for param init fn; enables shape based init of fused layers
270
+ fuse_splits = (d_model, 2 * d_model)
271
+ self.Wqkv._fused = (0, fuse_splits) # type: ignore
272
+
273
+ if self.attn_qk_ln:
274
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
275
+ self.q_ln = layernorm_class(self.d_model, device=device)
276
+ self.k_ln = layernorm_class(self.d_model, device=device)
277
+
278
+ if self.attn_impl == 'flash':
279
+ self.attn_fn = flash_attn_fn
280
+ elif self.attn_impl == 'triton':
281
+ self.attn_fn = triton_flash_attn_fn
282
+ warnings.warn(
283
+ 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +
284
+ 'it uses more memory. When training larger models this can trigger ' +
285
+ 'alloc retries which hurts performance. If encountered, we recommend ' +
286
+ 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
287
+ elif self.attn_impl == 'torch':
288
+ self.attn_fn = scaled_multihead_dot_product_attention
289
+ if torch.cuda.is_available():
290
+ warnings.warn(
291
+ 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +
292
+ '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +
293
+ 'we recommend using `attn_impl: triton`.'
294
+ )
295
+ else:
296
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
297
+
298
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
299
+ self.out_proj._is_residual = True # type: ignore
300
+
301
+ def forward(self,
302
+ x,
303
+ past_key_value=None,
304
+ attn_bias=None,
305
+ attention_mask=None,
306
+ is_causal=True,
307
+ needs_weights=False):
308
+ qkv = self.Wqkv(x)
309
+
310
+ if self.clip_qkv:
311
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
312
+
313
+ query, key, value = qkv.chunk(3, dim=2)
314
+
315
+ key_padding_mask = attention_mask
316
+
317
+ if self.attn_qk_ln:
318
+ # Applying layernorm to qk
319
+ dtype = query.dtype
320
+ query = self.q_ln(query).to(dtype)
321
+ key = self.k_ln(key).to(dtype)
322
+
323
+ if past_key_value is not None:
324
+ if len(past_key_value) != 0:
325
+ key = torch.cat([past_key_value[0], key], dim=1)
326
+ value = torch.cat([past_key_value[1], value], dim=1)
327
+
328
+ past_key_value = (key, value)
329
+
330
+ if attn_bias is not None:
331
+ attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
332
+
333
+ context, attn_weights = self.attn_fn(
334
+ query,
335
+ key,
336
+ value,
337
+ self.n_heads,
338
+ softmax_scale=self.softmax_scale,
339
+ attn_bias=attn_bias,
340
+ key_padding_mask=key_padding_mask,
341
+ is_causal=is_causal,
342
+ dropout_p=self.attn_dropout_p,
343
+ training=self.training,
344
+ needs_weights=needs_weights,
345
+ )
346
+
347
+ return self.out_proj(context), attn_weights, past_key_value
348
+
349
+
350
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
351
+ use_sequence_id):
352
+ if attn_impl == 'flash':
353
+ return None
354
+ elif attn_impl in ['torch', 'triton']:
355
+ if alibi:
356
+ if (prefix_lm or not causal) or use_sequence_id:
357
+ return (1, n_heads, seq_len, seq_len)
358
+ return (1, n_heads, 1, seq_len)
359
+ elif prefix_lm or use_sequence_id:
360
+ return (1, 1, seq_len, seq_len)
361
+ return None
362
+ else:
363
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
364
+
365
+
366
+ def attn_bias(attn_impl,
367
+ attn_bias,
368
+ n_heads,
369
+ seq_len,
370
+ causal=False,
371
+ alibi=False,
372
+ alibi_bias_max=8):
373
+ if attn_impl == 'flash':
374
+ return None
375
+ elif attn_impl in ['torch', 'triton']:
376
+ if alibi:
377
+ # in place add alibi to attn bias
378
+ device, dtype = attn_bias.device, attn_bias.dtype
379
+ attn_bias = attn_bias.add(
380
+ alibi_bias(n_heads,
381
+ seq_len,
382
+ full=not causal,
383
+ alibi_bias_max=alibi_bias_max,
384
+ device=device,
385
+ dtype=dtype))
386
+ return attn_bias
387
+ else:
388
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
389
+
390
+
391
+ def alibi_bias(n_heads,
392
+ seq_len,
393
+ full=False,
394
+ alibi_bias_max=8,
395
+ device=None,
396
+ dtype=None):
397
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=dtype,
398
+ device=device).view(1, 1, 1, seq_len)
399
+ if full:
400
+ # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
401
+ # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
402
+ alibi_bias = alibi_bias - torch.arange(
403
+ 1 - seq_len, 1, dtype=dtype, device=device).view(1, 1, seq_len, 1)
404
+ alibi_bias = alibi_bias.abs().mul(-1)
405
+
406
+ m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
407
+ m = m.mul(alibi_bias_max / n_heads)
408
+ alibi_bias = alibi_bias * (1. / (2**m.view(1, n_heads, 1, 1)))
409
+ return alibi_bias
config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "replit/replit-code-v1-3b",
3
+ "alibi": true,
4
+ "alibi_bias_max": 8,
5
+ "architectures": [
6
+ "ReplitLM"
7
+ ],
8
+ "attn_clip_qkv": null,
9
+ "attn_impl": "torch",
10
+ "attn_pdrop": 0,
11
+ "attn_qk_ln": false,
12
+ "attn_uses_sequence_id": false,
13
+ "auto_map": {
14
+ "AutoConfig": "replit/replit-code-v1-3b--configuration_replit_lm.ReplitLMConfig",
15
+ "AutoModelForCausalLM": "replit/replit-code-v1-3b--replit_lm.ReplitLM"
16
+ },
17
+ "d_model": 2560,
18
+ "emb_init_std": null,
19
+ "emb_init_uniform_lim": null,
20
+ "emb_pdrop": 0,
21
+ "embedding_fraction": 1.0,
22
+ "fan_mode": "fan_in",
23
+ "init_device": "cpu",
24
+ "init_div_is_residual": true,
25
+ "init_gain": 0,
26
+ "init_nonlinearity": "relu",
27
+ "init_std": 0.02,
28
+ "logit_scale": null,
29
+ "low_precision_layernorm": true,
30
+ "max_seq_len": 2048,
31
+ "mlp_ratio": 4,
32
+ "model_type": "replit_lm",
33
+ "n_heads": 32,
34
+ "n_layers": 32,
35
+ "no_bias": true,
36
+ "param_init_fn": "kaiming_normal_",
37
+ "prefix_lm": false,
38
+ "resid_pdrop": 0,
39
+ "softmax_scale": null,
40
+ "tokenizer_name": "replit/replit-code-v1-3b",
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.29.2",
43
+ "use_cache": true,
44
+ "verbose": 0,
45
+ "vocab_size": 32772
46
+ }
configuration_replit_lm.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Forked for ReplitLM"""
5
+
6
+ """A HuggingFace-style model configuration."""
7
+
8
+
9
+ from typing import Optional, Tuple, Union
10
+ from transformers import PretrainedConfig
11
+ class ReplitLMConfig(PretrainedConfig):
12
+ model_type = 'replit_lm'
13
+
14
+ def __init__(
15
+ self,
16
+ d_model: int = 2048,
17
+ n_heads: int = 16,
18
+ n_layers: int = 24,
19
+ mlp_ratio: int = 4,
20
+ max_seq_len: int = 2048,
21
+ vocab_size: int = 50368,
22
+ attn_pdrop: float = 0.0,
23
+ resid_pdrop: float = 0.0,
24
+ emb_pdrop: float = 0.0,
25
+ attn_impl: str = 'triton',
26
+ attn_qk_ln: bool = False,
27
+ attn_clip_qkv: Optional[float] = None,
28
+ softmax_scale: Optional[float] = None,
29
+ prefix_lm: Optional[bool] = False,
30
+ attn_uses_sequence_id: Optional[bool] = False,
31
+ alibi: bool = False,
32
+ alibi_bias_max: int = 8,
33
+ init_device: str = 'cpu',
34
+ logit_scale: Optional[Union[float, str]] = None,
35
+ no_bias: bool = False,
36
+ verbose: int = 0,
37
+ param_init_fn: str = 'kaiming_normal_',
38
+ init_div_is_residual: Union[int, float, str, bool] = True,
39
+ init_std: float = 0.02,
40
+ emb_init_std: Optional[float] = None,
41
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float],
42
+ float]] = None,
43
+ init_gain: float = 0,
44
+ fan_mode: str = 'fan_in',
45
+ init_nonlinearity: str = 'relu',
46
+ embedding_fraction: float = 1.0,
47
+ low_precision_layernorm: bool = True,
48
+ use_cache: bool = False,
49
+ **kwargs,
50
+ ):
51
+ """The ReplitLM configuration class.
52
+
53
+ Args:
54
+ d_model (int): The size of the embedding dimension of the model.
55
+ n_heads (int): The number of attention heads.
56
+ n_layers (int): The number of layers in the model.
57
+ mlp_ratio (int): The ratio of the up/down scale in the MLP.
58
+ max_seq_len (int): The maximum sequence length of the model.
59
+ vocab_size (int): The size of the vocabulary.
60
+ attn_pdrop (float): The dropout probability for the attention layers.
61
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
62
+ emb_pdrop (float): The dropout probability for the embedding layer.
63
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
64
+ attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
65
+ attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
66
+ this value.
67
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
68
+ use the default scale of ``1/sqrt(d_keys)``.
69
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
70
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
71
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
72
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
73
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
74
+ which sub-sequence each token belongs to.
75
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
76
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
77
+ alibi_bias_max (int): The maximum value of the alibi bias.
78
+ init_device (str): The device to use for parameter initialization.
79
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
80
+ no_bias (bool): Whether to use bias in all layers.
81
+ verbose (int): The verbosity level. 0 is silent.
82
+ param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
83
+ 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
84
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
85
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
86
+ if using the baseline_ parameter initialization scheme.
87
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
88
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
89
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
90
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
91
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
92
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
93
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
94
+ low_precision_layernorm (bool): Whether to use low precision layer normalization.
95
+ use_cache (bool): Whether or not the model should return the last key/values attentions
96
+ """
97
+ self.d_model = d_model
98
+ self.n_heads = n_heads
99
+ self.n_layers = n_layers
100
+ self.mlp_ratio = mlp_ratio
101
+ self.max_seq_len = max_seq_len
102
+ self.vocab_size = vocab_size
103
+ self.attn_pdrop = attn_pdrop
104
+ self.resid_pdrop = resid_pdrop
105
+ self.emb_pdrop = emb_pdrop
106
+ self.attn_impl = attn_impl
107
+ self.attn_qk_ln = attn_qk_ln
108
+ self.attn_clip_qkv = attn_clip_qkv
109
+ self.softmax_scale = softmax_scale
110
+ self.prefix_lm = prefix_lm
111
+ self.attn_uses_sequence_id = attn_uses_sequence_id
112
+ self.alibi = alibi
113
+ self.alibi_bias_max = alibi_bias_max
114
+ self.init_device = init_device
115
+ self.logit_scale = logit_scale
116
+ self.no_bias = no_bias
117
+ self.verbose = verbose
118
+ self.param_init_fn = param_init_fn
119
+ self.init_div_is_residual = init_div_is_residual
120
+ self.init_std = init_std
121
+ self.emb_init_std = emb_init_std
122
+ self.emb_init_uniform_lim = emb_init_uniform_lim
123
+ self.init_std = init_std
124
+ self.init_gain = init_gain
125
+ self.fan_mode = fan_mode
126
+ self.init_nonlinearity = init_nonlinearity
127
+ self.embedding_fraction = embedding_fraction
128
+ self.low_precision_layernorm = low_precision_layernorm
129
+ self.use_cache = use_cache
130
+ if 'name' in kwargs:
131
+ del kwargs['name']
132
+ if 'loss_fn' in kwargs:
133
+ del kwargs['loss_fn']
134
+ super().__init__(**kwargs)
135
+
136
+ self._validate_config()
137
+
138
+ def _validate_config(self):
139
+ if self.d_model % self.n_heads != 0:
140
+ raise ValueError('d_model must be divisible by n_heads')
141
+ if any(prob < 0 or prob > 1
142
+ for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
143
+ raise ValueError(
144
+ 'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
145
+ )
146
+ if self.attn_impl not in ['torch', 'flash', 'triton']:
147
+ raise ValueError(f'Unknown attn_impl={self.attn_impl}')
148
+ if self.prefix_lm and self.attn_impl not in ['torch', 'triton']:
149
+ raise NotImplementedError(
150
+ 'prefix_lm only implemented with torch and triton attention.')
151
+ if self.alibi and self.attn_impl not in ['torch', 'triton']:
152
+ raise NotImplementedError(
153
+ 'alibi only implemented with torch and triton attention.')
154
+ if self.attn_uses_sequence_id and self.attn_impl not in [
155
+ 'torch', 'triton'
156
+ ]:
157
+ raise NotImplementedError(
158
+ 'attn_uses_sequence_id only implemented with torch and triton attention.'
159
+ )
160
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
161
+ raise ValueError(
162
+ 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
163
+ )
164
+ if isinstance(self.logit_scale,
165
+ str) and self.logit_scale != 'inv_sqrt_d_model':
166
+ raise ValueError(
167
+ f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
168
+ )
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.29.2",
4
+ "use_cache": false
5
+ }
gpt_blocks.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPT Blocks used for the GPT Model."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .attention import MultiheadAttention
12
+ from .low_precision_layernorm import LPLayerNorm
13
+
14
+
15
+ class GPTMLP(nn.Module):
16
+
17
+ def __init__(self,
18
+ d_model: int,
19
+ mlp_ratio: int,
20
+ device: Optional[str] = None):
21
+ super().__init__()
22
+ self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
23
+ self.mlp_act = nn.GELU(approximate='none')
24
+ self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
25
+ self.mlp_down._is_residual = True # type: ignore
26
+
27
+ def forward(self, x):
28
+ return self.mlp_down(self.mlp_act(self.mlp_up(x)))
29
+
30
+
31
+ class GPTBlock(nn.Module):
32
+
33
+ def __init__(self,
34
+ attn_impl: str,
35
+ d_model: int,
36
+ n_heads: int,
37
+ mlp_ratio: int,
38
+ attn_clip_qkv: Optional[float] = None,
39
+ attn_qk_ln: bool = False,
40
+ softmax_scale: Optional[float] = None,
41
+ attn_pdrop: float = 0.0,
42
+ alibi: bool = False,
43
+ resid_pdrop: float = 0.0,
44
+ low_precision_layernorm: bool = False,
45
+ device: Optional[str] = None,
46
+ **kwargs):
47
+ del kwargs # unused, just to capture any extra args from the config
48
+ super().__init__()
49
+
50
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
51
+
52
+ self.ln_1 = layernorm_class(d_model, device=device)
53
+ self.attn = MultiheadAttention(
54
+ attn_impl=attn_impl,
55
+ attn_clip_qkv=attn_clip_qkv,
56
+ attn_qk_ln=attn_qk_ln,
57
+ softmax_scale=softmax_scale,
58
+ attn_pdrop=attn_pdrop,
59
+ d_model=d_model,
60
+ n_heads=n_heads,
61
+ device=device,
62
+ )
63
+ self.ln_2 = layernorm_class(d_model, device=device)
64
+ self.mlp = GPTMLP(
65
+ d_model=d_model,
66
+ mlp_ratio=mlp_ratio,
67
+ device=device,
68
+ )
69
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
70
+ self.resid_mlp_dropout = nn.Dropout(resid_pdrop)
71
+
72
+ def forward(
73
+ self,
74
+ x: torch.Tensor,
75
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
+ attn_bias: Optional[torch.Tensor] = None,
77
+ attention_mask: Optional[torch.ByteTensor] = None,
78
+ is_causal: bool = True,
79
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
80
+ a = self.ln_1(x)
81
+ b, _, past_key_value = self.attn(a,
82
+ past_key_value=past_key_value,
83
+ attn_bias=attn_bias,
84
+ attention_mask=attention_mask,
85
+ is_causal=is_causal)
86
+ x = x + self.resid_attn_dropout(b)
87
+ m = self.ln_2(x)
88
+ n = self.mlp(m)
89
+ x = x + self.resid_mlp_dropout(n)
90
+ return x, past_key_value
low_precision_layernorm.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class LPLayerNorm(torch.nn.LayerNorm):
6
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
7
+ super().__init__(
8
+ normalized_shape=normalized_shape,
9
+ eps=eps,
10
+ elementwise_affine=elementwise_affine,
11
+ device=device,
12
+ dtype=dtype,
13
+ )
14
+
15
+ def forward(self, x):
16
+ module_device = x.device
17
+ downcast_x = _cast_if_autocast_enabled(x)
18
+ downcast_weight = _cast_if_autocast_enabled(
19
+ self.weight) if self.weight is not None else self.weight
20
+ downcast_bias = _cast_if_autocast_enabled(
21
+ self.bias) if self.bias is not None else self.bias
22
+ with torch.autocast(enabled=False, device_type=module_device.type):
23
+ return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
24
+
25
+
26
+ def _cast_if_autocast_enabled(tensor):
27
+ if torch.is_autocast_enabled():
28
+ if tensor.device.type == 'cuda':
29
+ dtype = torch.get_autocast_gpu_dtype()
30
+ elif tensor.device.type == 'cpu':
31
+ dtype = torch.get_autocast_cpu_dtype()
32
+ else:
33
+ raise NotImplementedError()
34
+ return tensor.to(dtype=dtype)
35
+ return tensor
modeling_replit_chat.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple, flexible implementation of a GPT model.
2
+
3
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
+ """
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutputWithPast,
14
+ CausalLMOutputWithPast,
15
+ )
16
+ from .attention import attn_bias_shape, build_attn_bias
17
+ from .blocks import MPTBlock
18
+ from .norm import NORM_CLASS_REGISTRY
19
+ from .configuration_mpt import MPTConfig
20
+ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
21
+ from .hf_prefixlm_converter import (
22
+ add_bidirectional_mask_if_missing,
23
+ convert_hf_causal_lm_to_prefix_lm,
24
+ )
25
+ from .meta_init_context import init_empty_weights
26
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
27
+
28
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
29
+
30
+
31
+ class MPTPreTrainedModel(PreTrainedModel):
32
+ config_class = MPTConfig
33
+ base_model_prefix = "model"
34
+
35
+
36
+ class MPTModel(MPTPreTrainedModel):
37
+ def __init__(self, config: MPTConfig):
38
+ config._validate_config()
39
+ super().__init__(config)
40
+ self.attn_impl = config.attn_config["attn_impl"]
41
+ self.prefix_lm = config.attn_config["prefix_lm"]
42
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
43
+ self.alibi = config.attn_config["alibi"]
44
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
45
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
46
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
47
+ raise NotImplementedError(
48
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
49
+ )
50
+ norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
51
+ self.embedding_fraction = config.embedding_fraction
52
+ self.wte = nn.Embedding(
53
+ config.vocab_size, config.d_model, device=config.init_device
54
+ )
55
+ if not self.alibi:
56
+ self.wpe = nn.Embedding(
57
+ config.max_seq_len, config.d_model, device=config.init_device
58
+ )
59
+ self.emb_drop = nn.Dropout(config.emb_pdrop)
60
+ self.blocks = nn.ModuleList(
61
+ [
62
+ MPTBlock(device=config.init_device, **config.to_dict())
63
+ for _ in range(config.n_layers)
64
+ ]
65
+ )
66
+ self.norm_f = norm_class(config.d_model, device=config.init_device)
67
+ if config.init_device != "meta":
68
+ self.apply(self.param_init_fn)
69
+ self.is_causal = not self.prefix_lm
70
+ self._attn_bias_initialized = False
71
+ self.attn_bias = None
72
+ self.attn_bias_shape = attn_bias_shape(
73
+ self.attn_impl,
74
+ config.n_heads,
75
+ config.max_seq_len,
76
+ self.alibi,
77
+ prefix_lm=self.prefix_lm,
78
+ causal=self.is_causal,
79
+ use_sequence_id=self.attn_uses_sequence_id,
80
+ )
81
+ if config.no_bias:
82
+ for module in self.modules():
83
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
84
+ if config.verbose:
85
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
86
+ module.register_parameter("bias", None)
87
+ if config.verbose and config.verbose > 2:
88
+ print(self)
89
+ if "verbose" not in self.config.init_config:
90
+ self.config.init_config["verbose"] = self.config.verbose
91
+ if self.config.init_config["verbose"] > 1:
92
+ init_fn_name = self.config.init_config["name"]
93
+ warnings.warn(f"Using {init_fn_name} initialization.")
94
+
95
+ def get_input_embeddings(self):
96
+ return self.wte
97
+
98
+ def set_input_embeddings(self, value):
99
+ self.wte = value
100
+
101
+ @torch.no_grad()
102
+ def _attn_bias(
103
+ self,
104
+ device,
105
+ dtype,
106
+ attention_mask: Optional[torch.ByteTensor] = None,
107
+ prefix_mask: Optional[torch.ByteTensor] = None,
108
+ sequence_id: Optional[torch.LongTensor] = None,
109
+ ):
110
+ if not self._attn_bias_initialized:
111
+ if self.attn_bias_shape:
112
+ self.attn_bias = torch.zeros(
113
+ self.attn_bias_shape, device=device, dtype=dtype
114
+ )
115
+ self.attn_bias = build_attn_bias(
116
+ self.attn_impl,
117
+ self.attn_bias,
118
+ self.config.n_heads,
119
+ self.config.max_seq_len,
120
+ causal=self.is_causal,
121
+ alibi=self.alibi,
122
+ alibi_bias_max=self.alibi_bias_max,
123
+ )
124
+ self._attn_bias_initialized = True
125
+ if self.attn_impl == "flash":
126
+ return (self.attn_bias, attention_mask)
127
+ if self.attn_bias is not None:
128
+ self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
129
+ attn_bias = self.attn_bias
130
+ if self.prefix_lm:
131
+ assert isinstance(attn_bias, torch.Tensor)
132
+ assert isinstance(prefix_mask, torch.Tensor)
133
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
134
+ if self.attn_uses_sequence_id and sequence_id is not None:
135
+ assert isinstance(attn_bias, torch.Tensor)
136
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
137
+ if attention_mask is not None:
138
+ s_k = attention_mask.shape[-1]
139
+ if attn_bias is None:
140
+ attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
141
+ else:
142
+ attn_bias = attn_bias[:, :, :, -s_k:]
143
+ if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
144
+ raise ValueError(
145
+ f"attention_mask shape={attention_mask.shape} "
146
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
147
+ )
148
+ min_val = torch.finfo(attn_bias.dtype).min
149
+ attn_bias = attn_bias.masked_fill(
150
+ ~attention_mask.view(-1, 1, 1, s_k), min_val
151
+ )
152
+ return (attn_bias, None)
153
+
154
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
155
+ (s_k, s_q) = attn_bias.shape[-2:]
156
+ if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
157
+ raise ValueError(
158
+ "attn_bias does not match the expected shape. "
159
+ + f"The last two dimensions should both be {self.config.max_length} "
160
+ + f"but are {s_k} and {s_q}."
161
+ )
162
+ seq_len = prefix_mask.shape[-1]
163
+ if seq_len > self.config.max_seq_len:
164
+ raise ValueError(
165
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
166
+ )
167
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
168
+ causal = torch.tril(
169
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
170
+ ).view(1, 1, seq_len, seq_len)
171
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
172
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
173
+ min_val = torch.finfo(attn_bias.dtype).min
174
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
175
+ return attn_bias
176
+
177
+ def _apply_sequence_id(
178
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
179
+ ):
180
+ seq_len = sequence_id.shape[-1]
181
+ if seq_len > self.config.max_seq_len:
182
+ raise ValueError(
183
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
184
+ )
185
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
186
+ cannot_attend = torch.logical_not(
187
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
188
+ ).unsqueeze(1)
189
+ min_val = torch.finfo(attn_bias.dtype).min
190
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
191
+ return attn_bias
192
+
193
+ def forward(
194
+ self,
195
+ input_ids: torch.LongTensor,
196
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
197
+ attention_mask: Optional[torch.ByteTensor] = None,
198
+ prefix_mask: Optional[torch.ByteTensor] = None,
199
+ sequence_id: Optional[torch.LongTensor] = None,
200
+ return_dict: Optional[bool] = None,
201
+ output_attentions: Optional[bool] = None,
202
+ output_hidden_states: Optional[bool] = None,
203
+ use_cache: Optional[bool] = None,
204
+ ):
205
+ return_dict = (
206
+ return_dict if return_dict is not None else self.config.return_dict
207
+ )
208
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
209
+ if attention_mask is not None:
210
+ attention_mask = attention_mask.bool()
211
+ if prefix_mask is not None:
212
+ prefix_mask = prefix_mask.bool()
213
+ if not return_dict:
214
+ raise NotImplementedError(
215
+ "return_dict False is not implemented yet for MPT"
216
+ )
217
+ if output_attentions:
218
+ raise NotImplementedError(
219
+ "output_attentions is not implemented yet for MPT"
220
+ )
221
+ if (
222
+ attention_mask is not None
223
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
224
+ and self.training
225
+ ):
226
+ raise NotImplementedError(
227
+ "MPT does not support training with left padding."
228
+ )
229
+ if self.prefix_lm and prefix_mask is None:
230
+ raise ValueError(
231
+ "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
232
+ )
233
+ if self.training:
234
+ if self.attn_uses_sequence_id and sequence_id is None:
235
+ raise ValueError(
236
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
237
+ + "and the model is in train mode."
238
+ )
239
+ elif self.attn_uses_sequence_id is False and sequence_id is not None:
240
+ warnings.warn(
241
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
242
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
243
+ )
244
+ S = input_ids.size(1)
245
+ assert (
246
+ S <= self.config.max_seq_len
247
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
248
+ tok_emb = self.wte(input_ids)
249
+ if self.alibi:
250
+ x = tok_emb
251
+ else:
252
+ past_position = 0
253
+ if past_key_values is not None:
254
+ if len(past_key_values) != self.config.n_layers:
255
+ raise ValueError(
256
+ f"past_key_values must provide a past_key_value for each attention "
257
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
258
+ )
259
+ past_position = past_key_values[0][0].size(1)
260
+ if S + past_position > self.config.max_seq_len:
261
+ raise ValueError(
262
+ f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
263
+ )
264
+ pos = torch.arange(
265
+ past_position,
266
+ S + past_position,
267
+ dtype=torch.long,
268
+ device=input_ids.device,
269
+ ).unsqueeze(0)
270
+ if attention_mask is not None:
271
+ pos = torch.clamp(
272
+ pos
273
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
274
+ :, past_position:
275
+ ],
276
+ min=0,
277
+ )
278
+ pos_emb = self.wpe(pos)
279
+ x = tok_emb + pos_emb
280
+ if self.embedding_fraction == 1:
281
+ x = self.emb_drop(x)
282
+ else:
283
+ x_shrunk = x * self.embedding_fraction + x.detach() * (
284
+ 1 - self.embedding_fraction
285
+ )
286
+ assert isinstance(self.emb_drop, nn.Module)
287
+ x = self.emb_drop(x_shrunk)
288
+ (attn_bias, attention_mask) = self._attn_bias(
289
+ device=x.device,
290
+ dtype=x.dtype,
291
+ attention_mask=attention_mask,
292
+ prefix_mask=prefix_mask,
293
+ sequence_id=sequence_id,
294
+ )
295
+ if use_cache and past_key_values is None:
296
+ past_key_values = [() for _ in range(self.config.n_layers)]
297
+ all_hidden_states = () if output_hidden_states else None
298
+ for b_idx, block in enumerate(self.blocks):
299
+ if output_hidden_states:
300
+ assert all_hidden_states is not None
301
+ all_hidden_states = all_hidden_states + (x,)
302
+ past_key_value = (
303
+ past_key_values[b_idx] if past_key_values is not None else None
304
+ )
305
+ (x, past_key_value) = block(
306
+ x,
307
+ past_key_value=past_key_value,
308
+ attn_bias=attn_bias,
309
+ attention_mask=attention_mask,
310
+ is_causal=self.is_causal,
311
+ )
312
+ if past_key_values is not None:
313
+ past_key_values[b_idx] = past_key_value
314
+ x = self.norm_f(x)
315
+ return BaseModelOutputWithPast(
316
+ last_hidden_state=x,
317
+ past_key_values=past_key_values,
318
+ hidden_states=all_hidden_states,
319
+ )
320
+
321
+ def param_init_fn(self, module):
322
+ init_fn_name = self.config.init_config["name"]
323
+ MODEL_INIT_REGISTRY[init_fn_name](
324
+ module=module,
325
+ n_layers=self.config.n_layers,
326
+ d_model=self.config.d_model,
327
+ **self.config.init_config,
328
+ )
329
+
330
+ def fsdp_wrap_fn(self, module):
331
+ return isinstance(module, MPTBlock)
332
+
333
+ def activation_checkpointing_fn(self, module):
334
+ return isinstance(module, MPTBlock)
335
+
336
+
337
+ class MPTForCausalLM(MPTPreTrainedModel):
338
+ def __init__(self, config: MPTConfig):
339
+ super().__init__(config)
340
+ if not config.tie_word_embeddings:
341
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
342
+ self.transformer = MPTModel(config)
343
+ self.logit_scale = None
344
+ if config.logit_scale is not None:
345
+ logit_scale = config.logit_scale
346
+ if isinstance(logit_scale, str):
347
+ if logit_scale == "inv_sqrt_d_model":
348
+ logit_scale = 1 / math.sqrt(config.d_model)
349
+ else:
350
+ raise ValueError(
351
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
352
+ )
353
+ self.logit_scale = logit_scale
354
+
355
+ def get_input_embeddings(self):
356
+ return self.transformer.wte
357
+
358
+ def set_input_embeddings(self, value):
359
+ self.transformer.wte = value
360
+
361
+ def get_output_embeddings(self):
362
+ return self.transformer.wte
363
+
364
+ def set_output_embeddings(self, new_embeddings):
365
+ self.transformer.wte = new_embeddings
366
+
367
+ def set_decoder(self, decoder):
368
+ self.transformer = decoder
369
+
370
+ def get_decoder(self):
371
+ return self.transformer
372
+
373
+ def forward(
374
+ self,
375
+ input_ids: torch.LongTensor,
376
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
377
+ attention_mask: Optional[torch.ByteTensor] = None,
378
+ prefix_mask: Optional[torch.ByteTensor] = None,
379
+ sequence_id: Optional[torch.LongTensor] = None,
380
+ labels: Optional[torch.LongTensor] = None,
381
+ return_dict: Optional[bool] = None,
382
+ output_attentions: Optional[bool] = None,
383
+ output_hidden_states: Optional[bool] = None,
384
+ use_cache: Optional[bool] = None,
385
+ ):
386
+ return_dict = (
387
+ return_dict if return_dict is not None else self.config.return_dict
388
+ )
389
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
390
+ outputs = self.transformer(
391
+ input_ids=input_ids,
392
+ past_key_values=past_key_values,
393
+ attention_mask=attention_mask,
394
+ prefix_mask=prefix_mask,
395
+ sequence_id=sequence_id,
396
+ return_dict=return_dict,
397
+ output_attentions=output_attentions,
398
+ output_hidden_states=output_hidden_states,
399
+ use_cache=use_cache,
400
+ )
401
+ logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
402
+ if self.logit_scale is not None:
403
+ if self.logit_scale == 0:
404
+ warnings.warn(
405
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
406
+ )
407
+ logits *= self.logit_scale
408
+ loss = None
409
+ if labels is not None:
410
+ labels = torch.roll(labels, shifts=-1)
411
+ labels[:, -1] = -100
412
+ loss = F.cross_entropy(
413
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
414
+ )
415
+ return CausalLMOutputWithPast(
416
+ loss=loss,
417
+ logits=logits,
418
+ past_key_values=outputs.past_key_values,
419
+ hidden_states=outputs.hidden_states,
420
+ )
421
+
422
+ def param_init_fn(self, module):
423
+ init_fn_name = self.config.init_config["name"]
424
+ MODEL_INIT_REGISTRY[init_fn_name](
425
+ module=module,
426
+ n_layers=self.config.n_layers,
427
+ d_model=self.config.d_model,
428
+ **self.config.init_config,
429
+ )
430
+
431
+ def fsdp_wrap_fn(self, module):
432
+ return isinstance(module, MPTBlock)
433
+
434
+ def activation_checkpointing_fn(self, module):
435
+ return isinstance(module, MPTBlock)
436
+
437
+ def prepare_inputs_for_generation(
438
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
439
+ ):
440
+ if inputs_embeds is not None:
441
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
442
+ attention_mask = kwargs["attention_mask"].bool()
443
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
444
+ raise NotImplementedError(
445
+ "MPT does not support generation with right padding."
446
+ )
447
+ if self.transformer.attn_uses_sequence_id and self.training:
448
+ sequence_id = torch.zeros_like(input_ids[:1])
449
+ else:
450
+ sequence_id = None
451
+ if past_key_values is not None:
452
+ input_ids = input_ids[:, -1].unsqueeze(-1)
453
+ if self.transformer.prefix_lm:
454
+ prefix_mask = torch.ones_like(attention_mask)
455
+ if kwargs.get("use_cache") == False:
456
+ raise NotImplementedError(
457
+ "MPT with prefix_lm=True does not support use_cache=False."
458
+ )
459
+ else:
460
+ prefix_mask = None
461
+ return {
462
+ "input_ids": input_ids,
463
+ "attention_mask": attention_mask,
464
+ "prefix_mask": prefix_mask,
465
+ "sequence_id": sequence_id,
466
+ "past_key_values": past_key_values,
467
+ "use_cache": kwargs.get("use_cache", True),
468
+ }
469
+
470
+ @staticmethod
471
+ def _reorder_cache(past_key_values, beam_idx):
472
+ """Used by HuggingFace generate when using beam search with kv-caching.
473
+
474
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
475
+ for an example in transformers.
476
+ """
477
+ reordered_past = []
478
+ for layer_past in past_key_values:
479
+ reordered_past += [
480
+ tuple(
481
+ (past_state.index_select(0, beam_idx) for past_state in layer_past)
482
+ )
483
+ ]
484
+ return reordered_past
param_init_fns.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import math
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ def torch_default_param_init_fn_(
14
+ module: nn.Module,
15
+ verbose: int = 0,
16
+ **kwargs,
17
+ ):
18
+ del kwargs # unused, just to capture any extra args from the config
19
+ if verbose > 1:
20
+ warnings.warn(
21
+ f"Initializing network using module's reset_parameters attribute")
22
+
23
+ if hasattr(module, 'reset_parameters'):
24
+ module.reset_parameters() # type: ignore
25
+
26
+
27
+ def fused_init_helper_(module: nn.Module, init_fn_):
28
+ # parameter initialization is often based on the parameters shape.
29
+ # If a layer is fused, initialization should be based on the shapes
30
+ # of the original tensor instead of the shape of the fused tensor.
31
+ # Layers which are fused should have the _fused attibute defined.
32
+ # The first element of _fused is the dimension along which the tensor is fused.
33
+ # This is followed by an iterable of split indices."
34
+
35
+ _fused = getattr(module, '_fused', None)
36
+
37
+ if _fused is None:
38
+ raise RuntimeError(f'Internal logic error')
39
+
40
+ dim, splits = _fused
41
+ splits = (0, *splits, module.weight.size(dim)) # type: ignore
42
+ for s, e in zip(splits[:-1], splits[1:]):
43
+ slice_indices = [slice(None)] * module.weight.ndim # type: ignore
44
+ slice_indices[dim] = slice(s, e)
45
+ init_fn_(module.weight[slice_indices]) # type: ignore
46
+
47
+
48
+ def generic_param_init_fn_(
49
+ module: nn.Module,
50
+ init_fn_,
51
+ n_layers: int,
52
+ d_model: Optional[int] = None,
53
+ init_div_is_residual: Union[int, float, str, bool] = True,
54
+ emb_init_std: Optional[float] = None,
55
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
56
+ verbose: int = 0,
57
+ **kwargs,
58
+ ):
59
+ del kwargs # unused, just to capture any extra args from the config
60
+ if verbose > 1:
61
+ warnings.warn(
62
+ f'If model has bias parameters they are initialized to 0.')
63
+
64
+ # enable user to divide _is_residual weights by
65
+ # a value which defaults to math.sqrt(2 * cfg.n_layers)
66
+ init_div_is_residual = init_div_is_residual
67
+
68
+ if init_div_is_residual is False:
69
+ # not used, for pyright
70
+ div_is_residual = 1.0
71
+ elif init_div_is_residual is True:
72
+ div_is_residual = math.sqrt(2 * n_layers)
73
+ elif isinstance(init_div_is_residual, float) or isinstance(
74
+ init_div_is_residual, int):
75
+ div_is_residual = init_div_is_residual
76
+ elif isinstance(init_div_is_residual,
77
+ str) and init_div_is_residual.isnumeric():
78
+ # do not trust YAML parsing to always convert numbers to numbers
79
+ div_is_residual = float(init_div_is_residual)
80
+ else:
81
+ # not used, for pyright
82
+ div_is_residual = 1.0
83
+ raise ValueError(
84
+ f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
85
+ )
86
+
87
+ if init_div_is_residual is not False:
88
+ if verbose > 1:
89
+ warnings.warn(
90
+ f'Initializing _is_residual layers then dividing them by {div_is_residual}.' +
91
+ f'set `init_div_is_residual: false` in model config to disable this.'
92
+ )
93
+
94
+ if isinstance(module, nn.Linear):
95
+ # Linear
96
+ if hasattr(module, '_fused'):
97
+ fused_init_helper_(module, init_fn_)
98
+ else:
99
+ init_fn_(module.weight)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+
103
+ if init_div_is_residual is not False and getattr(
104
+ module, '_is_residual', False):
105
+ with torch.no_grad():
106
+ module.weight.div_(div_is_residual)
107
+
108
+ elif isinstance(module, nn.Embedding):
109
+ # Embedding
110
+ if emb_init_std is not None:
111
+ std = emb_init_std
112
+ if std == 0:
113
+ warnings.warn(f'Embedding layer initialized to 0.')
114
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
115
+ if verbose > 1:
116
+ warnings.warn(
117
+ f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
118
+ )
119
+ elif emb_init_uniform_lim is not None:
120
+ lim = emb_init_uniform_lim
121
+ if isinstance(lim, Sequence):
122
+ if len(lim) > 2:
123
+ raise ValueError(
124
+ f'Uniform init requires a min and a max limit. User input: {lim}.'
125
+ )
126
+ if lim[0] == lim[1]:
127
+ warnings.warn(f'Embedding layer initialized to {lim[0]}.')
128
+ else:
129
+ if lim == 0:
130
+ warnings.warn(f'Embedding layer initialized to 0.')
131
+ lim = [-lim, lim]
132
+ a, b = lim
133
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
134
+ if verbose > 1:
135
+ warnings.warn(
136
+ f'Embedding layer initialized using uniform distribution in range {lim}.'
137
+ )
138
+ else:
139
+ emb_init_fn_ = init_fn_
140
+
141
+ emb_init_fn_(module.weight)
142
+
143
+ elif isinstance(module, nn.LayerNorm):
144
+ # LayerNorm
145
+ if verbose > 1:
146
+ warnings.warn(
147
+ f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
148
+ )
149
+ torch.nn.init.ones_(module.weight)
150
+ if module.bias is not None:
151
+ torch.nn.init.zeros_(module.bias)
152
+
153
+ elif isinstance(module, nn.MultiheadAttention):
154
+ # torch's MultiheadAttention
155
+ if module._qkv_same_embed_dim:
156
+ assert module.in_proj_weight is not None
157
+ assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
158
+ assert d_model is not None
159
+ # in_proj_weight is actually 3 layers and should be split up for width based init
160
+ _d = d_model
161
+ splits = (0, _d, 2 * _d, 3 * _d)
162
+ for s, e in zip(splits[:-1], splits[1:]):
163
+ init_fn_(module.in_proj_weight[s:e])
164
+ else:
165
+ assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
166
+ assert module.in_proj_weight is None
167
+ init_fn_(module.q_proj_weight)
168
+ init_fn_(module.k_proj_weight)
169
+ init_fn_(module.v_proj_weight)
170
+
171
+ # bias
172
+ if module.in_proj_bias is not None:
173
+ torch.nn.init.zeros_(module.in_proj_bias)
174
+ if module.bias_k is not None:
175
+ torch.nn.init.zeros_(module.bias_k)
176
+ if module.bias_v is not None:
177
+ torch.nn.init.zeros_(module.bias_v)
178
+
179
+ # out proj
180
+ init_fn_(module.out_proj.weight)
181
+ if init_div_is_residual is not False and getattr(
182
+ module.out_proj, '_is_residual', False):
183
+ with torch.no_grad():
184
+ module.out_proj.weight.div_(div_is_residual)
185
+ if module.out_proj.bias is not None:
186
+ torch.nn.init.zeros_(module.out_proj.bias)
187
+
188
+ else:
189
+ for _ in module.parameters(recurse=False):
190
+ # raise error if uninitialized module has any parameters
191
+ raise NotImplementedError(
192
+ f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
193
+ )
194
+
195
+
196
+ def _normal_init_(std, mean=0.0):
197
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
198
+
199
+
200
+ def _normal_param_init_fn_(
201
+ module: nn.Module,
202
+ std: float,
203
+ n_layers: int,
204
+ d_model: Optional[int] = None,
205
+ init_div_is_residual: Union[int, float, str, bool] = True,
206
+ emb_init_std: Optional[float] = None,
207
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
208
+ verbose: int = 0,
209
+ **kwargs,
210
+ ):
211
+ del kwargs # unused, just to capture any extra args from the config
212
+ init_fn_ = _normal_init_(std=std)
213
+
214
+ if verbose > 1:
215
+ warnings.warn(
216
+ f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
217
+
218
+ generic_param_init_fn_(
219
+ module=module,
220
+ init_fn_=init_fn_,
221
+ d_model=d_model,
222
+ n_layers=n_layers,
223
+ init_div_is_residual=init_div_is_residual,
224
+ emb_init_std=emb_init_std,
225
+ emb_init_uniform_lim=emb_init_uniform_lim,
226
+ verbose=verbose,
227
+ )
228
+
229
+
230
+ def baseline_param_init_fn_(
231
+ module: nn.Module,
232
+ init_std: float,
233
+ n_layers: int,
234
+ d_model: Optional[int] = None,
235
+ init_div_is_residual: Union[int, float, str, bool] = True,
236
+ emb_init_std: Optional[float] = None,
237
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
238
+ verbose: int = 0,
239
+ **kwargs,
240
+ ):
241
+ del kwargs # unused, just to capture any extra args from the config
242
+ if init_std is None:
243
+ raise ValueError(
244
+ 'You must set model.init_std to a float value to use the default initialization scheme.'
245
+ )
246
+ _normal_param_init_fn_(
247
+ module=module,
248
+ std=init_std,
249
+ d_model=d_model,
250
+ n_layers=n_layers,
251
+ init_div_is_residual=init_div_is_residual,
252
+ emb_init_std=emb_init_std,
253
+ emb_init_uniform_lim=emb_init_uniform_lim,
254
+ verbose=verbose,
255
+ )
256
+
257
+
258
+ def small_param_init_fn_(
259
+ module: nn.Module,
260
+ n_layers: int,
261
+ d_model: int,
262
+ init_div_is_residual: Union[int, float, str, bool] = True,
263
+ emb_init_std: Optional[float] = None,
264
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
265
+ verbose: int = 0,
266
+ **kwargs,
267
+ ):
268
+ del kwargs # unused, just to capture any extra args from the config
269
+ # very close to kaiming normal
270
+ # from Transformers without Tears (2019) - Nguyen & Salazar
271
+ std = math.sqrt(2 / (5 * d_model))
272
+ _normal_param_init_fn_(
273
+ module=module,
274
+ std=std,
275
+ d_model=d_model,
276
+ n_layers=n_layers,
277
+ init_div_is_residual=init_div_is_residual,
278
+ emb_init_std=emb_init_std,
279
+ emb_init_uniform_lim=emb_init_uniform_lim,
280
+ verbose=verbose,
281
+ )
282
+
283
+
284
+ def neox_param_init_fn_(
285
+ module: nn.Module,
286
+ n_layers: int,
287
+ d_model: int,
288
+ emb_init_std: Optional[float] = None,
289
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
290
+ verbose: int = 0,
291
+ **kwargs,
292
+ ):
293
+ """From section 2.3.1 of GPT-NeoX-20B:
294
+
295
+ An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
296
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
297
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
298
+ """
299
+ del kwargs # unused, just to capture any extra args from the config
300
+ residual_div = n_layers / math.sqrt(10) # small std / wang std
301
+
302
+ if verbose > 1:
303
+ warnings.warn(f'setting init_div_is_residual to {residual_div}')
304
+
305
+ small_param_init_fn_(
306
+ module=module,
307
+ d_model=d_model,
308
+ n_layers=n_layers,
309
+ init_div_is_residual=residual_div,
310
+ emb_init_std=emb_init_std,
311
+ emb_init_uniform_lim=emb_init_uniform_lim,
312
+ verbose=verbose,
313
+ )
314
+
315
+
316
+ def kaiming_uniform_param_init_fn_(
317
+ module: nn.Module,
318
+ n_layers: int,
319
+ d_model: Optional[int] = None,
320
+ init_div_is_residual: Union[int, float, str, bool] = True,
321
+ emb_init_std: Optional[float] = None,
322
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
323
+ init_gain: float = 0,
324
+ fan_mode: str = 'fan_in',
325
+ init_nonlinearity: str = 'leaky_relu',
326
+ verbose: int = 0,
327
+ **kwargs,
328
+ ):
329
+ del kwargs # unused, just to capture any extra args from the config
330
+
331
+ if verbose > 1:
332
+ warnings.warn(
333
+ f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +
334
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
335
+ )
336
+
337
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_,
338
+ a=init_gain,
339
+ mode=fan_mode,
340
+ nonlinearity=init_nonlinearity)
341
+
342
+ generic_param_init_fn_(
343
+ module=module,
344
+ init_fn_=kaiming_uniform_,
345
+ d_model=d_model,
346
+ n_layers=n_layers,
347
+ init_div_is_residual=init_div_is_residual,
348
+ emb_init_std=emb_init_std,
349
+ emb_init_uniform_lim=emb_init_uniform_lim,
350
+ verbose=verbose,
351
+ )
352
+
353
+
354
+ def kaiming_normal_param_init_fn_(
355
+ module: nn.Module,
356
+ n_layers: int,
357
+ d_model: Optional[int] = None,
358
+ init_div_is_residual: Union[int, float, str, bool] = True,
359
+ emb_init_std: Optional[float] = None,
360
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
361
+ init_gain: float = 0,
362
+ fan_mode: str = 'fan_in',
363
+ init_nonlinearity: str = 'leaky_relu',
364
+ verbose: int = 0,
365
+ **kwargs,
366
+ ):
367
+ del kwargs # unused, just to capture any extra args from the config
368
+
369
+ if verbose > 1:
370
+ warnings.warn(
371
+ f'Using nn.init.kaiming_normal_ init fn with parameters: ' +
372
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
373
+ )
374
+
375
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_,
376
+ a=init_gain,
377
+ mode=fan_mode,
378
+ nonlinearity=init_nonlinearity)
379
+
380
+ generic_param_init_fn_(
381
+ module=module,
382
+ init_fn_=kaiming_normal_,
383
+ d_model=d_model,
384
+ n_layers=n_layers,
385
+ init_div_is_residual=init_div_is_residual,
386
+ emb_init_std=emb_init_std,
387
+ emb_init_uniform_lim=emb_init_uniform_lim,
388
+ verbose=verbose,
389
+ )
390
+
391
+
392
+ def xavier_uniform_param_init_fn_(
393
+ module: nn.Module,
394
+ n_layers: int,
395
+ d_model: Optional[int] = None,
396
+ init_div_is_residual: Union[int, float, str, bool] = True,
397
+ emb_init_std: Optional[float] = None,
398
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
399
+ init_gain: float = 0,
400
+ verbose: int = 0,
401
+ **kwargs,
402
+ ):
403
+ del kwargs # unused, just to capture any extra args from the config
404
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
405
+
406
+ if verbose > 1:
407
+ warnings.warn(
408
+ f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +
409
+ f'gain={init_gain}'
410
+ )
411
+
412
+ generic_param_init_fn_(
413
+ module=module,
414
+ init_fn_=xavier_uniform_,
415
+ d_model=d_model,
416
+ n_layers=n_layers,
417
+ init_div_is_residual=init_div_is_residual,
418
+ emb_init_std=emb_init_std,
419
+ emb_init_uniform_lim=emb_init_uniform_lim,
420
+ verbose=verbose,
421
+ )
422
+
423
+
424
+ def xavier_normal_param_init_fn_(
425
+ module: nn.Module,
426
+ n_layers: int,
427
+ d_model: Optional[int] = None,
428
+ init_div_is_residual: Union[int, float, str, bool] = True,
429
+ emb_init_std: Optional[float] = None,
430
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
431
+ init_gain: float = 0,
432
+ verbose: int = 0,
433
+ **kwargs,
434
+ ):
435
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
436
+
437
+ if verbose > 1:
438
+ warnings.warn(
439
+ f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +
440
+ f'gain={init_gain}'
441
+ )
442
+
443
+ generic_param_init_fn_(
444
+ module=module,
445
+ init_fn_=xavier_normal_,
446
+ d_model=d_model,
447
+ n_layers=n_layers,
448
+ init_div_is_residual=init_div_is_residual,
449
+ emb_init_std=emb_init_std,
450
+ emb_init_uniform_lim=emb_init_uniform_lim,
451
+ verbose=verbose,
452
+ )
453
+
454
+
455
+ MODEL_INIT_REGISTRY = {
456
+ 'default_': torch_default_param_init_fn_,
457
+ 'baseline_': baseline_param_init_fn_,
458
+ 'kaiming_uniform_': kaiming_uniform_param_init_fn_,
459
+ 'kaiming_normal_': kaiming_normal_param_init_fn_,
460
+ 'neox_init_': neox_param_init_fn_,
461
+ 'small_init_': small_param_init_fn_,
462
+ 'xavier_uniform_': xavier_uniform_param_init_fn_,
463
+ 'xavier_normal_': xavier_normal_param_init_fn_,
464
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4576dc9a84c76799a2b60a25f0b5bb4c13552cd68f7a3032eb58b9cc334452f
3
+ size 5201316581
replit_lm.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Forked from the MosaicGPT model class from the Mosaic Examples codebase of date May 1st, 2023.
5
+ Permalink: https://github.com/mosaicml/examples/blob/52cd4fef69497f225a034fcd10692f8613732d10/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
6
+ """
7
+
8
+ """A simple, flexible implementation of a GPT model.
9
+
10
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import warnings
18
+
19
+ from transformers import PreTrainedModel
20
+ from transformers.modeling_outputs import (
21
+ CausalLMOutputWithPast,
22
+ BaseModelOutputWithPast,
23
+ )
24
+ from typing import List, Optional, Tuple
25
+
26
+ from .attention import (
27
+ attn_bias as module_attn_bias,
28
+ attn_bias_shape as module_attn_bias_shape,
29
+ )
30
+ from .gpt_blocks import GPTBlock
31
+ from .configuration_replit_lm import ReplitLMConfig
32
+ from .param_init_fns import MODEL_INIT_REGISTRY
33
+ from .low_precision_layernorm import LPLayerNorm
34
+
35
+
36
+ class ReplitLM(PreTrainedModel):
37
+ config_class = ReplitLMConfig
38
+ base_model_prefix = "replit_lm"
39
+
40
+ def __init__(self, config: ReplitLMConfig):
41
+ super().__init__(config)
42
+
43
+ if config.attn_impl == "flash" and config.alibi:
44
+ raise RuntimeError(
45
+ "ALiBi is not supported with flash attention. Please use triton or torch."
46
+ )
47
+
48
+ self.attn_impl = config.attn_impl
49
+ self.prefix_lm = config.prefix_lm
50
+ self.attn_uses_sequence_id = config.attn_uses_sequence_id
51
+ self.alibi = config.alibi
52
+ self.alibi_bias_max = config.alibi_bias_max
53
+
54
+ layernorm_class = (
55
+ LPLayerNorm if config.low_precision_layernorm else nn.LayerNorm
56
+ )
57
+
58
+ # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
59
+ # both report this helping with stabilizing training
60
+ self.embedding_fraction = config.embedding_fraction
61
+
62
+ self.transformer = nn.ModuleDict(
63
+ {
64
+ "wte": nn.Embedding(
65
+ config.vocab_size, config.d_model, device=config.init_device
66
+ )
67
+ }
68
+ )
69
+ if not self.alibi:
70
+ self.transformer.update(
71
+ {
72
+ "wpe": nn.Embedding(
73
+ config.max_seq_len, config.d_model, device=config.init_device
74
+ )
75
+ }
76
+ )
77
+ self.transformer.update({"emb_drop": nn.Dropout(config.emb_pdrop)})
78
+ self.transformer.update(
79
+ {
80
+ "blocks": nn.ModuleList(
81
+ [
82
+ GPTBlock(device=config.init_device, **config.to_dict())
83
+ for _ in range(config.n_layers)
84
+ ]
85
+ )
86
+ }
87
+ )
88
+ self.transformer.update(
89
+ {"ln_f": layernorm_class(config.d_model, device=config.init_device)}
90
+ )
91
+
92
+ # enables scaling output logits; similar to a softmax "temperature"
93
+ # PaLM paper uses scale 1/sqrt(config.d_model)
94
+ self.logit_scale = None
95
+ if config.logit_scale is not None:
96
+ logit_scale = config.logit_scale
97
+ if isinstance(logit_scale, str):
98
+ if logit_scale == "inv_sqrt_d_model":
99
+ logit_scale = 1 / math.sqrt(config.d_model)
100
+ else:
101
+ raise ValueError(
102
+ f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
103
+ )
104
+ self.logit_scale = logit_scale
105
+
106
+ if config.init_device != "meta":
107
+ print(
108
+ f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
109
+ )
110
+ self.apply(self.param_init_fn)
111
+
112
+ self.is_causal = not self.prefix_lm
113
+
114
+ # define attn mask
115
+ self._attn_bias_initialized = False
116
+ self.attn_bias = None
117
+ self.attn_bias_shape = module_attn_bias_shape(
118
+ self.attn_impl,
119
+ config.n_heads,
120
+ config.max_seq_len,
121
+ self.alibi,
122
+ prefix_lm=self.prefix_lm,
123
+ causal=self.is_causal,
124
+ use_sequence_id=self.attn_uses_sequence_id,
125
+ )
126
+
127
+ if config.no_bias:
128
+ for module in self.modules():
129
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
130
+ if config.verbose:
131
+ print(f"Removing bias ({module.bias}) from {module}.")
132
+ module.register_parameter("bias", None)
133
+
134
+ if config.verbose and config.verbose > 2:
135
+ print(self)
136
+
137
+ self.logit_scale = None
138
+ if config.logit_scale is not None:
139
+ logit_scale = config.logit_scale
140
+ if isinstance(logit_scale, str):
141
+ if logit_scale == "inv_sqrt_d_model":
142
+ logit_scale = 1 / math.sqrt(config.d_model)
143
+ else:
144
+ raise ValueError(
145
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
146
+ )
147
+ self.logit_scale = logit_scale
148
+
149
+ def get_input_embeddings(self):
150
+ return self.transformer.wte
151
+
152
+ def set_input_embeddings(self, value):
153
+ self.transformer.wte = value
154
+
155
+ @torch.no_grad()
156
+ def _attn_bias(
157
+ self,
158
+ device,
159
+ dtype,
160
+ attention_mask: Optional[torch.ByteTensor] = None,
161
+ prefix_mask: Optional[torch.ByteTensor] = None,
162
+ sequence_id: Optional[torch.LongTensor] = None,
163
+ ):
164
+ if not self._attn_bias_initialized:
165
+ if self.attn_bias_shape:
166
+ self.attn_bias = torch.zeros(
167
+ self.attn_bias_shape, device=device, dtype=dtype
168
+ )
169
+ self.attn_bias = module_attn_bias(
170
+ self.attn_impl,
171
+ self.attn_bias,
172
+ self.config.n_heads,
173
+ self.config.max_seq_len,
174
+ causal=self.is_causal,
175
+ alibi=self.alibi,
176
+ alibi_bias_max=self.alibi_bias_max,
177
+ )
178
+ self._attn_bias_initialized = True
179
+
180
+ # flash does not support prefix_lm and will incorporate any
181
+ # attention_mask inside the attention module
182
+ if self.attn_impl == "flash":
183
+ return self.attn_bias, attention_mask
184
+
185
+ attn_bias = self.attn_bias
186
+
187
+ # If using torch or triton, we incorporate the prefix_mask (if appropriate)
188
+ if self.prefix_lm:
189
+ assert isinstance(attn_bias, torch.Tensor) # pyright
190
+ assert isinstance(prefix_mask, torch.Tensor) # pyright
191
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
192
+
193
+ # If using torch or triton, we incorporate sequence_id (if appropriate)
194
+ if self.attn_uses_sequence_id and sequence_id is not None:
195
+ assert isinstance(attn_bias, torch.Tensor) # pyright
196
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
197
+
198
+ # If using torch or triton, we incorporate attention_mask. This will output
199
+ # None in place of attention_mask since it will not be further needed in the
200
+ # attention modules.
201
+ if attention_mask is not None:
202
+ s_k = attention_mask.shape[-1]
203
+ if attn_bias is None:
204
+ attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
205
+ else:
206
+ attn_bias = attn_bias[:, :, :, -s_k:]
207
+ if prefix_mask is not None and (attention_mask.shape != prefix_mask.shape):
208
+ raise ValueError(
209
+ f"attention_mask shape={attention_mask.shape} "
210
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
211
+ )
212
+ min_val = torch.finfo(attn_bias.dtype).min
213
+ attn_bias = attn_bias.masked_fill(
214
+ ~attention_mask.view(-1, 1, 1, s_k).bool(), min_val
215
+ )
216
+
217
+ return attn_bias, None
218
+
219
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
220
+ s_k, s_q = attn_bias.shape[-2:]
221
+ if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
222
+ raise ValueError(
223
+ "attn_bias does not match the expected shape. "
224
+ + f"The last two dimensions should both be {self.config.max_length} "
225
+ + f"but are {s_k} and {s_q}."
226
+ )
227
+ seq_len = prefix_mask.shape[-1]
228
+ if seq_len > self.config.max_seq_len:
229
+ raise ValueError(
230
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
231
+ )
232
+
233
+ # select seq_len subset of attn mask
234
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
235
+
236
+ # Mix the causal max and the bidirectional mask to get the full
237
+ # allowable attention (i.e. full = not accounting for padding yet)
238
+ causal = torch.tril(
239
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
240
+ ).view(1, 1, seq_len, seq_len)
241
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
242
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
243
+
244
+ min_val = torch.finfo(attn_bias.dtype).min
245
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
246
+
247
+ return attn_bias
248
+
249
+ def _apply_sequence_id(
250
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
251
+ ):
252
+ seq_len = sequence_id.shape[-1]
253
+ if seq_len > self.config.max_seq_len:
254
+ raise ValueError(
255
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
256
+ )
257
+
258
+ # select seq_len subset of attn mask
259
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
260
+
261
+ # Restrict attention to tokens that share the same value
262
+ # in sequence_id
263
+ cannot_attend = torch.logical_not(
264
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
265
+ ).unsqueeze(1)
266
+ min_val = torch.finfo(attn_bias.dtype).min
267
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
268
+
269
+ return attn_bias
270
+
271
+ def forward(
272
+ self,
273
+ input_ids: torch.LongTensor,
274
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
275
+ attention_mask: Optional[torch.ByteTensor] = None,
276
+ prefix_mask: Optional[torch.ByteTensor] = None,
277
+ sequence_id: Optional[torch.LongTensor] = None,
278
+ labels: Optional[torch.LongTensor] = None,
279
+ return_dict: Optional[bool] = None,
280
+ output_attentions: Optional[bool] = None,
281
+ output_hidden_states: Optional[bool] = None,
282
+ use_cache: Optional[bool] = None,
283
+ ):
284
+ return_dict = (
285
+ return_dict if return_dict is not None else self.config.return_dict
286
+ )
287
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
288
+
289
+ # These args are passed in by keyword in huggingface's generate function
290
+ # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
291
+ # but have not yet been fully implemented in ReplitLM
292
+ if not return_dict:
293
+ raise NotImplementedError(
294
+ "return_dict False is not implemented yet for ReplitLM"
295
+ )
296
+ if output_attentions:
297
+ raise NotImplementedError(
298
+ "output_attentions is not implemented yet for ReplitLM"
299
+ )
300
+
301
+ if (
302
+ attention_mask is not None
303
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
304
+ and self.training
305
+ ):
306
+ raise NotImplementedError(
307
+ "ReplitLM does not support training with left padding."
308
+ )
309
+
310
+ if self.prefix_lm and prefix_mask is None:
311
+ raise ValueError(
312
+ "prefix_mask is a required argument when ReplitLM is configured with prefix_lm=True."
313
+ )
314
+
315
+ if self.training:
316
+ if self.attn_uses_sequence_id and sequence_id is None:
317
+ raise ValueError(
318
+ "sequence_id is a required argument when ReplitLM is configured with attn_uses_sequence_id=True "
319
+ + "and the model is in train mode."
320
+ )
321
+ elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
322
+ warnings.warn(
323
+ "ReplitLM received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
324
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
325
+ )
326
+
327
+ S = input_ids.size(1)
328
+
329
+ assert (
330
+ S <= self.config.max_seq_len
331
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
332
+
333
+ tok_emb = self.transformer.wte(input_ids) # type: ignore
334
+ if self.alibi:
335
+ x = tok_emb
336
+ else:
337
+ past_position = 0
338
+ if past_key_values is not None:
339
+ if len(past_key_values) != self.config.n_layers:
340
+ raise ValueError(
341
+ f"past_key_values must provide a past_key_value for each attention "
342
+ + f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
343
+ )
344
+ # get the key tensor whose spec should be (batch, seq, dim), and
345
+ # collect the `seq`, so that the position embedding is shifted
346
+ past_position = past_key_values[0][0].size(1)
347
+
348
+ if S + past_position > self.config.max_seq_len:
349
+ raise ValueError(
350
+ f"Cannot forward input with past sequence length {past_position} and current sequence length "
351
+ f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
352
+ )
353
+ pos = torch.arange(
354
+ past_position,
355
+ S + past_position,
356
+ dtype=torch.long,
357
+ device=input_ids.device,
358
+ ).unsqueeze(0)
359
+ if attention_mask is not None:
360
+ # adjust the position indices to account for padding tokens
361
+ pos = torch.clamp(
362
+ pos
363
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
364
+ :, past_position:
365
+ ],
366
+ min=0,
367
+ )
368
+
369
+ pos_emb = self.transformer.wpe(pos) # type: ignore
370
+ x = tok_emb + pos_emb
371
+
372
+ if self.embedding_fraction == 1:
373
+ x = self.transformer.emb_drop(x) # type: ignore
374
+ else:
375
+ # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
376
+ x_shrunk = (x * self.embedding_fraction) + (
377
+ x.detach() * (1 - self.embedding_fraction)
378
+ )
379
+ assert isinstance(self.transformer.emb_drop, nn.Module) # pyright
380
+ x = self.transformer.emb_drop(x_shrunk)
381
+
382
+ attn_bias, attention_mask = self._attn_bias(
383
+ device=x.device,
384
+ dtype=x.dtype,
385
+ attention_mask=attention_mask,
386
+ prefix_mask=prefix_mask,
387
+ sequence_id=sequence_id,
388
+ )
389
+
390
+ # initialize the past key values cache if it should be used
391
+ if use_cache and past_key_values is None:
392
+ past_key_values = [() for _ in range(self.config.n_layers)] # type: ignore
393
+
394
+ all_hidden_states = () if output_hidden_states else None
395
+ for b_idx, block in enumerate(self.transformer.blocks): # type: ignore
396
+ if output_hidden_states:
397
+ assert all_hidden_states is not None # pyright
398
+ all_hidden_states = all_hidden_states + (x,)
399
+ past_key_value = (
400
+ past_key_values[b_idx] if past_key_values is not None else None
401
+ )
402
+ x, past_key_value = block(
403
+ x,
404
+ past_key_value=past_key_value,
405
+ attn_bias=attn_bias,
406
+ attention_mask=attention_mask,
407
+ is_causal=self.is_causal,
408
+ )
409
+ if past_key_values is not None:
410
+ past_key_values[b_idx] = past_key_value
411
+
412
+ x = self.transformer.ln_f(x) # type: ignore
413
+
414
+ outputs = BaseModelOutputWithPast(
415
+ last_hidden_state=x,
416
+ past_key_values=past_key_values,
417
+ hidden_states=all_hidden_states,
418
+ )
419
+
420
+ logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
421
+ if self.logit_scale is not None:
422
+ if self.logit_scale == 0:
423
+ warnings.warn(
424
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
425
+ )
426
+ logits *= self.logit_scale
427
+ loss = None
428
+ if labels is not None:
429
+ labels = torch.roll(labels, shifts=-1)
430
+ labels[:, -1] = -100
431
+ loss = F.cross_entropy(
432
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
433
+ )
434
+ return CausalLMOutputWithPast(
435
+ loss=loss,
436
+ logits=logits,
437
+ past_key_values=outputs.past_key_values,
438
+ hidden_states=outputs.hidden_states,
439
+ )
440
+
441
+ # Param Initialization, needed for device='meta' fast initialization
442
+ def param_init_fn(self, module):
443
+ init_fn_name = self.config.param_init_fn
444
+ if self.config.verbose > 1:
445
+ warnings.warn(f"Using {init_fn_name} initialization.")
446
+ MODEL_INIT_REGISTRY[init_fn_name](module=module, **self.config.to_dict())
447
+
448
+ # FSDP Wrap function
449
+ def fsdp_wrap_fn(self, module):
450
+ return isinstance(module, GPTBlock)
451
+
452
+ # Activation Checkpointing
453
+ def activation_checkpointing_fn(self, module):
454
+ return isinstance(module, GPTBlock)
455
+
456
+ def prepare_inputs_for_generation(
457
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
458
+ ):
459
+ if inputs_embeds is not None:
460
+ raise NotImplementedError(
461
+ "inputs_embeds is not implemented for ReplitLM yet"
462
+ )
463
+
464
+ attention_mask = kwargs["attention_mask"].bool()
465
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
466
+ raise NotImplementedError(
467
+ "ReplitLM does not support generation with right padding."
468
+ )
469
+
470
+ if self.attn_uses_sequence_id and self.training:
471
+ sequence_id = torch.zeros_like(input_ids[:1])
472
+ else:
473
+ sequence_id = None
474
+
475
+ if past_key_values is not None:
476
+ input_ids = input_ids[:, -1].unsqueeze(-1)
477
+
478
+ if self.prefix_lm:
479
+ # Leverage a convenience of sequential generation!
480
+ prefix_mask = torch.ones_like(attention_mask)
481
+ # This requires that we're using the cache
482
+ if kwargs.get("use_cache") == False:
483
+ raise NotImplementedError(
484
+ "ReplitLM with prefix_lm=True does not support use_cache=False."
485
+ )
486
+ else:
487
+ prefix_mask = None
488
+
489
+ return {
490
+ "input_ids": input_ids,
491
+ "attention_mask": attention_mask,
492
+ "prefix_mask": prefix_mask,
493
+ "sequence_id": sequence_id,
494
+ "past_key_values": past_key_values,
495
+ "use_cache": kwargs.get("use_cache", True),
496
+ }
497
+
498
+ @staticmethod
499
+ def _reorder_cache(past_key_values, beam_idx):
500
+ """Used by HuggingFace generate when using beam search with kv-caching.
501
+
502
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
503
+ for an example in transformers.
504
+ """
505
+ reordered_past = []
506
+ for layer_past in past_key_values:
507
+ reordered_past += [
508
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
509
+ ]
510
+ return reordered_past
511
+
512
+
513
+ # class ReplitLM_2(ReplitLMPreTrainedModel):
514
+ # def __init__(self, config: ReplitLMConfig):
515
+ # super().__init__(config)
516
+ # if not config.tie_word_embeddings:
517
+ # raise ValueError("MPTForCausalLM only supports tied word embeddings")
518
+ # self.transformer = ReplitLM2(config)
519
+ # self.logit_scale = None
520
+ # if config.logit_scale is not None:
521
+ # logit_scale = config.logit_scale
522
+ # if isinstance(logit_scale, str):
523
+ # if logit_scale == "inv_sqrt_d_model":
524
+ # logit_scale = 1 / math.sqrt(config.d_model)
525
+ # else:
526
+ # raise ValueError(
527
+ # f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
528
+ # )
529
+ # self.logit_scale = logit_scale
530
+
531
+ # def get_input_embeddings(self):
532
+ # return self.transformer.transformer.wte
533
+
534
+ # def set_input_embeddings(self, value):
535
+ # self.transformer.transformer.wte = value
536
+
537
+ # def get_output_embeddings(self):
538
+ # return self.transformer.transformer.wte
539
+
540
+ # def set_output_embeddings(self, new_embeddings):
541
+ # self.transformer.transformer.wte = new_embeddings
542
+
543
+ # def set_decoder(self, decoder):
544
+ # self.transformer = decoder
545
+
546
+ # def get_decoder(self):
547
+ # return self.transformer
548
+
549
+ # def forward(
550
+ # self,
551
+ # input_ids: torch.LongTensor,
552
+ # past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
553
+ # attention_mask: Optional[torch.ByteTensor] = None,
554
+ # prefix_mask: Optional[torch.ByteTensor] = None,
555
+ # sequence_id: Optional[torch.LongTensor] = None,
556
+ # labels: Optional[torch.LongTensor] = None,
557
+ # return_dict: Optional[bool] = None,
558
+ # output_attentions: Optional[bool] = None,
559
+ # output_hidden_states: Optional[bool] = None,
560
+ # use_cache: Optional[bool] = None,
561
+ # ):
562
+ # return_dict = (
563
+ # return_dict if return_dict is not None else self.config.return_dict
564
+ # )
565
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
566
+ # outputs = self.transformer(
567
+ # input_ids=input_ids,
568
+ # past_key_values=past_key_values,
569
+ # attention_mask=attention_mask,
570
+ # prefix_mask=prefix_mask,
571
+ # sequence_id=sequence_id,
572
+ # return_dict=return_dict,
573
+ # output_attentions=output_attentions,
574
+ # output_hidden_states=output_hidden_states,
575
+ # use_cache=use_cache,
576
+ # )
577
+ # logits = F.linear(
578
+ # outputs.last_hidden_state, self.transformer.transformer.wte.weight
579
+ # )
580
+ # if self.logit_scale is not None:
581
+ # if self.logit_scale == 0:
582
+ # warnings.warn(
583
+ # f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
584
+ # )
585
+ # logits *= self.logit_scale
586
+ # loss = None
587
+ # if labels is not None:
588
+ # labels = torch.roll(labels, shifts=-1)
589
+ # labels[:, -1] = -100
590
+ # loss = F.cross_entropy(
591
+ # logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
592
+ # )
593
+ # return CausalLMOutputWithPast(
594
+ # loss=loss,
595
+ # logits=logits,
596
+ # past_key_values=outputs.past_key_values,
597
+ # hidden_states=outputs.hidden_states,
598
+ # )
599
+
600
+ # def param_init_fn(self, module):
601
+ # init_fn_name = self.config.param_init_fn
602
+ # if self.config.verbose > 1:
603
+ # warnings.warn(f"Using {init_fn_name} initialization.")
604
+ # MODEL_INIT_REGISTRY[init_fn_name](module=module, **self.config.to_dict())
605
+
606
+ # # FSDP Wrap function
607
+ # def fsdp_wrap_fn(self, module):
608
+ # return isinstance(module, GPTBlock)
609
+
610
+ # # Activation Checkpointing
611
+ # def activation_checkpointing_fn(self, module):
612
+ # return isinstance(module, GPTBlock)
613
+
614
+ # def prepare_inputs_for_generation(
615
+ # self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
616
+ # ):
617
+ # if inputs_embeds is not None:
618
+ # raise NotImplementedError(
619
+ # "inputs_embeds is not implemented for ReplitLM yet"
620
+ # )
621
+
622
+ # attention_mask = kwargs["attention_mask"].bool()
623
+ # if attention_mask[:, -1].sum() != attention_mask.shape[0]:
624
+ # raise NotImplementedError(
625
+ # "ReplitLM does not support generation with right padding."
626
+ # )
627
+
628
+ # if self.attn_uses_sequence_id and self.training:
629
+ # sequence_id = torch.zeros_like(input_ids[:1])
630
+ # else:
631
+ # sequence_id = None
632
+
633
+ # if past_key_values is not None:
634
+ # input_ids = input_ids[:, -1].unsqueeze(-1)
635
+
636
+ # if self.prefix_lm:
637
+ # # Leverage a convenience of sequential generation!
638
+ # prefix_mask = torch.ones_like(attention_mask)
639
+ # # This requires that we're using the cache
640
+ # if kwargs.get("use_cache") == False:
641
+ # raise NotImplementedError(
642
+ # "ReplitLM with prefix_lm=True does not support use_cache=False."
643
+ # )
644
+ # else:
645
+ # prefix_mask = None
646
+
647
+ # return {
648
+ # "input_ids": input_ids,
649
+ # "attention_mask": attention_mask,
650
+ # "prefix_mask": prefix_mask,
651
+ # "sequence_id": sequence_id,
652
+ # "past_key_values": past_key_values,
653
+ # "use_cache": kwargs.get("use_cache", True),
654
+ # }
655
+
656
+ # @staticmethod
657
+ # def _reorder_cache(past_key_values, beam_idx):
658
+ # """Used by HuggingFace generate when using beam search with kv-caching.
659
+
660
+ # See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
661
+ # for an example in transformers.
662
+ # """
663
+ # reordered_past = []
664
+ # for layer_past in past_key_values:
665
+ # reordered_past += [
666
+ # tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
667
+ # ]
668
+ # return reordered_past
replit_lm_tokenizer.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
17
+ Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py
18
+
19
+ Class is modified for compatibility with custom vocabulary and to achieve desired encode/decode behavior for Replit Code v1.3b model.
20
+ """
21
+
22
+ """ Tokenizer class for ReplitLM"""
23
+
24
+
25
+ import os
26
+ import sentencepiece as spm
27
+ from shutil import copyfile
28
+ from transformers import PreTrainedTokenizer
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
+ class ReplitLMTokenizer(PreTrainedTokenizer):
34
+ """
35
+ Construct a ReplitLMTokenizer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.
37
+
38
+ Args:
39
+ vocab_file (`str`):
40
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
41
+ contains the vocabulary necessary to instantiate a tokenizer.
42
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
43
+ The end of sequence token.
44
+ bos_token (`str`, *optional*, defaults to `None`):
45
+ The begin of sequence token.
46
+ unk_token (`str`, *optional*, defaults to `"<|unk|>"`):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ pad_token (`str`, *optional*, defaults to `"<|pad|>"`):
50
+ The token used for padding, for example when batching sequences of different lengths.
51
+ sp_model_kwargs (`dict`, *optional*):
52
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
53
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
54
+ to set:
55
+ - `enable_sampling`: Enable subword regularization.
56
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
57
+ - `nbest_size = {0,1}`: No sampling is performed.
58
+ - `nbest_size > 1`: samples from the nbest_size results.
59
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
60
+ using forward-filtering-and-backward-sampling algorithm.
61
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
62
+ BPE-dropout.
63
+ """
64
+
65
+ vocab_files_names = VOCAB_FILES_NAMES
66
+ prefix_tokens: List[int] = []
67
+ model_input_names = ["input_ids", "attention_mask"]
68
+
69
+ def __init__(
70
+ self,
71
+ vocab_file,
72
+ bos_token=None,
73
+ eos_token="<|endoftext|>",
74
+ unk_token="<|unk|>",
75
+ pad_token="<|pad|>",
76
+ sep_token=None,
77
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
78
+ **kwargs,
79
+ ) -> None:
80
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
81
+
82
+ # Add extra_ids to the special token list
83
+ super().__init__(
84
+ bos_token=bos_token,
85
+ eos_token=eos_token,
86
+ unk_token=unk_token,
87
+ pad_token=pad_token,
88
+ sep_token=sep_token,
89
+ sp_model_kwargs=self.sp_model_kwargs,
90
+ **kwargs,
91
+ )
92
+
93
+ self.vocab_file = vocab_file
94
+
95
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
96
+ self.sp_model.Load(vocab_file)
97
+
98
+ @property
99
+ def vocab_size(self):
100
+ return self.sp_model.get_piece_size()
101
+
102
+ def get_vocab(self):
103
+ vocab = {self.convert_ids_to_tokens(
104
+ i): i for i in range(self.vocab_size)}
105
+ vocab.update(self.added_tokens_encoder)
106
+ return vocab
107
+
108
+ def __getstate__(self):
109
+ state = self.__dict__.copy()
110
+ state["sp_model"] = None
111
+ return state
112
+
113
+ def __setstate__(self, d):
114
+ self.__dict__ = d
115
+
116
+ # for backward compatibility
117
+ if not hasattr(self, "sp_model_kwargs"):
118
+ self.sp_model_kwargs = {}
119
+
120
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
121
+ self.sp_model.load(self.vocab_file)
122
+
123
+ def _tokenize(self, text: str) -> List[str]:
124
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
125
+ return self.sp_model.encode(text, out_type=str)
126
+
127
+ def _convert_token_to_id(self, token):
128
+ """Converts a token (str) in an id using the vocab."""
129
+ return self.sp_model.piece_to_id(token)
130
+
131
+ def _convert_id_to_token(self, index):
132
+ """Converts an index (integer) in a token (str) using the vocab."""
133
+ token = self.sp_model.id_to_piece(index)
134
+ return token
135
+
136
+ def convert_tokens_to_string(self, tokens):
137
+ """Converts a sequence of tokens (string) in a single string."""
138
+ return self.sp_model.decode(tokens)
139
+
140
+ def save_vocabulary(self,
141
+ save_directory: str,
142
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
143
+
144
+ if not os.path.isdir(save_directory):
145
+ raise ValueError(
146
+ f"Vocabulary path ({save_directory}) should be a directory")
147
+
148
+ out_vocab_file = os.path.join(
149
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
150
+ VOCAB_FILES_NAMES["vocab_file"])
151
+
152
+ if os.path.abspath(
153
+ self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
154
+ self.vocab_file):
155
+ copyfile(self.vocab_file, out_vocab_file)
156
+ elif not os.path.isfile(self.vocab_file):
157
+ with open(out_vocab_file, "wb") as fi:
158
+ content_spiece_model = self.sp_model.serialized_model_proto()
159
+ fi.write(content_spiece_model)
160
+
161
+ return (out_vocab_file, )
special_tokens_map.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|system|>",
4
+ "<|user|>",
5
+ "<|assistant|>",
6
+ "<|end|>"
7
+ ],
8
+ "eos_token": "<|endoftext|>",
9
+ "pad_token": "<|pad|>",
10
+ "unk_token": "<|unk|>"
11
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e1ba8b7df0701723d2d901c7a42182fe77bf0045173f2cdb474ca6ea3eb1c02
3
+ size 707660
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "replit/replit-code-v1-3b--replit_lm_tokenizer.ReplitLMTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "bos_token": null,
9
+ "clean_up_tokenization_spaces": false,
10
+ "eos_token": "<|endoftext|>",
11
+ "model_max_length": 2048,
12
+ "pad_token": "<|pad|>",
13
+ "padding_side": "right",
14
+ "sep_token": null,
15
+ "sp_model_kwargs": {},
16
+ "tokenizer_class": "ReplitLMTokenizer",
17
+ "unk_token": "<|unk|>"
18
+ }