Girinath11 committed on
Commit d6816d4 · verified · 1 Parent(s): 752c496

Merge embeddings and add Transformers support

Files changed (1)
  1. model_slm.py +758 -200
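A minimal usage sketch of the merged module (not part of the committed diff), assuming model_slm.py imports cleanly and the constructor and forward signatures are as shown in the listing below; the config defaults (d_model=384, n_heads=6, max_seq_len=128) come from MixtureOfRecursionsConfig:

import torch
from model_slm import MixtureOfRecursionsConfig, MixtureOfRecursions  # assumed import path

config = MixtureOfRecursionsConfig()
model = MixtureOfRecursions(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    n_layers=config.n_layers,
    n_heads=config.n_heads,
    max_steps=config.max_steps,
    dim_feedforward=config.dim_feedforward,
    dropout=config.dropout,
    max_seq_len=config.max_seq_len,
    router_type=config.router_type,
    padding_idx=config.padding_idx,
    pos_encoding=config.pos_encoding,
)
input_ids = torch.randint(1, config.vocab_size, (2, 16))  # toy batch, no padding tokens
logits, computation_loss = model(input_ids)               # logits: (2, 16, vocab_size)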
model_slm.py CHANGED
@@ -2,11 +2,10 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
- from typing import Optional, Tuple, Union
6
- from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
7
 
8
  # ============================================================================
9
- # TRANSFORMERS COMPATIBILITY - ADD THIS SECTION
10
  # ============================================================================
11
  from transformers import PretrainedConfig
12
  from transformers.modeling_utils import PreTrainedModel
@@ -29,7 +28,6 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
29
  router_type="adaptive",
30
  padding_idx=0,
31
  pos_encoding="learned",
32
- # Transformers standard names (for compatibility)
33
  hidden_size=None,
34
  num_hidden_layers=None,
35
  num_attention_heads=None,
@@ -38,8 +36,6 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
38
  **kwargs
39
  ):
40
  super().__init__(**kwargs)
41
-
42
- # Your model's parameters
43
  self.vocab_size = vocab_size
44
  self.d_model = d_model
45
  self.n_layers = n_layers
@@ -51,8 +47,6 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
51
  self.router_type = router_type
52
  self.padding_idx = padding_idx
53
  self.pos_encoding = pos_encoding
54
-
55
- # Transformers standard aliases (for compatibility)
56
  self.hidden_size = hidden_size or d_model
57
  self.num_hidden_layers = num_hidden_layers or n_layers
58
  self.num_attention_heads = num_attention_heads or n_heads
@@ -60,10 +54,155 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
60
  self.max_position_embeddings = max_position_embeddings or max_seq_len
61
 
62
  # ============================================================================
63
- # END TRANSFORMERS COMPATIBILITY SECTION
64
  # ============================================================================
65
 
66
- # Constants for default configuration
67
  DEFAULT_D_MODEL = 512
68
  DEFAULT_N_HEADS = 8
69
  DEFAULT_N_LAYERS = 6
@@ -75,29 +214,20 @@ DEFAULT_PADDING_IDX = 0
75
  DEFAULT_ROUTER_TYPE = "adaptive"
76
  DEFAULT_VOCAB_SIZE = 10000
77
 
 
 
 
 
78
  class MultiHeadAttention(nn.Module):
79
  """Multi-head attention mechanism optimized for technical content."""
80
 
81
  def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
82
- """
83
- Initialize multi-head attention.
84
-
85
- Args:
86
- d_model (int): Dimension of the model embeddings.
87
- n_heads (int): Number of attention heads.
88
- dropout (float): Dropout rate for regularization.
89
-
90
- Raises:
91
- ValueError: If d_model is not divisible by n_heads.
92
- """
93
  super().__init__()
94
  if d_model % n_heads != 0:
95
  raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
96
-
97
  self.d_model = d_model
98
  self.n_heads = n_heads
99
  self.d_k = d_model // n_heads
100
-
101
  self.w_q = nn.Linear(d_model, d_model, bias=False)
102
  self.w_k = nn.Linear(d_model, d_model, bias=False)
103
  self.w_v = nn.Linear(d_model, d_model, bias=False)
@@ -106,7 +236,6 @@ class MultiHeadAttention(nn.Module):
106
  self._init_weights()
107
 
108
  def _init_weights(self) -> None:
109
- """Initialize weights with Xavier uniform initialization."""
110
  for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
111
  nn.init.xavier_uniform_(module.weight)
112
  if hasattr(module, 'bias') and module.bias is not None:
@@ -120,21 +249,7 @@ class MultiHeadAttention(nn.Module):
120
  mask: Optional[torch.Tensor] = None,
121
  pos_encoding: Optional[nn.Module] = None
122
  ) -> torch.Tensor:
123
- """
124
- Forward pass for multi-head attention.
125
-
126
- Args:
127
- query (torch.Tensor): Query tensor of shape (batch_size, seq_len, d_model).
128
- key (torch.Tensor): Key tensor of shape (batch_size, seq_len, d_model).
129
- value (torch.Tensor): Value tensor of shape (batch_size, seq_len, d_model).
130
- mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len, seq_len).
131
- pos_encoding (Optional[nn.Module]): Positional encoding module (e.g., RoPE).
132
-
133
- Returns:
134
- torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
135
- """
136
  batch_size, seq_len, _ = query.size()
137
-
138
  Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
139
  K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
140
  V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
@@ -158,34 +273,16 @@ class FeedForward(nn.Module):
158
  """Position-wise feed-forward network with GELU activation."""
159
 
160
  def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
161
- """
162
- Initialize feed-forward network.
163
-
164
- Args:
165
- d_model (int): Dimension of the model embeddings.
166
- dim_feedforward (int): Dimension of the feed-forward layer.
167
- dropout (float): Dropout rate for regularization.
168
- """
169
  super().__init__()
170
  self.linear1 = nn.Linear(d_model, dim_feedforward)
171
  self.linear2 = nn.Linear(dim_feedforward, d_model)
172
  self.dropout = nn.Dropout(dropout)
173
-
174
  nn.init.xavier_uniform_(self.linear1.weight)
175
  nn.init.zeros_(self.linear1.bias)
176
  nn.init.xavier_uniform_(self.linear2.weight)
177
  nn.init.zeros_(self.linear2.bias)
178
 
179
  def forward(self, x: torch.Tensor) -> torch.Tensor:
180
- """
181
- Forward pass for feed-forward network.
182
-
183
- Args:
184
- x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
185
-
186
- Returns:
187
- torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
188
- """
189
  x = F.gelu(self.linear1(x))
190
  x = self.dropout(x)
191
  return self.linear2(x)
@@ -194,17 +291,6 @@ class RecursionRouter(nn.Module):
194
  """Router to determine recursion steps for technical problem processing."""
195
 
196
  def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
197
- """
198
- Initialize recursion router.
199
-
200
- Args:
201
- d_model (int): Dimension of the model embeddings.
202
- max_steps (int): Maximum number of recursion steps.
203
- router_type (str): Type of router ('adaptive' or 'fixed').
204
-
205
- Raises:
206
- ValueError: If router_type is invalid.
207
- """
208
  super().__init__()
209
  self.max_steps = max_steps
210
  self.router_type = router_type.lower()
@@ -223,15 +309,6 @@ class RecursionRouter(nn.Module):
223
  raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")
224
 
225
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
226
- """
227
- Determine the number of recursion steps.
228
-
229
- Args:
230
- x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
231
-
232
- Returns:
233
- Union[torch.Tensor, int]: Number of steps (tensor for adaptive, int for fixed).
234
- """
235
  if self.router_type == "adaptive":
236
  seq_repr = x.mean(dim=1)
237
  step_probs = self.complexity_classifier(seq_repr)
@@ -250,21 +327,9 @@ class RecursiveTransformerLayer(nn.Module):
250
  dropout: float = DEFAULT_DROPOUT,
251
  router_type: str = DEFAULT_ROUTER_TYPE
252
  ):
253
- """
254
- Initialize recursive transformer layer.
255
-
256
- Args:
257
- d_model (int): Dimension of the model embeddings.
258
- n_heads (int): Number of attention heads.
259
- dim_feedforward (int): Dimension of the feed-forward layer.
260
- max_steps (int): Maximum number of recursion steps.
261
- dropout (float): Dropout rate for regularization.
262
- router_type (str): Type of router ('adaptive' or 'fixed').
263
- """
264
  super().__init__()
265
  self.max_steps = max_steps
266
  self.d_model = d_model
267
-
268
  self.attention = MultiHeadAttention(d_model, n_heads, dropout)
269
  self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
270
  self.norm1 = nn.LayerNorm(d_model)
@@ -274,7 +339,6 @@ class RecursiveTransformerLayer(nn.Module):
274
  self.step_projections = nn.ModuleList([
275
  nn.Linear(d_model, d_model) for _ in range(max_steps)
276
  ])
277
-
278
  for proj in self.step_projections:
279
  nn.init.xavier_uniform_(proj.weight)
280
  nn.init.zeros_(proj.bias)
@@ -285,17 +349,6 @@ class RecursiveTransformerLayer(nn.Module):
285
  mask: Optional[torch.Tensor] = None,
286
  pos_encoding: Optional[nn.Module] = None
287
  ) -> Tuple[torch.Tensor, torch.Tensor]:
288
- """
289
- Forward pass for recursive transformer layer.
290
-
291
- Args:
292
- x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
293
- mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len, seq_len).
294
- pos_encoding (Optional[nn.Module]): Positional encoding module (e.g., RoPE).
295
-
296
- Returns:
297
- Tuple[torch.Tensor, torch.Tensor]: Output tensor and computation loss.
298
- """
299
  steps = self.router(x)
300
  if isinstance(steps, (int, torch.Tensor)) and not torch.is_tensor(steps):
301
  return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
@@ -308,11 +361,9 @@ class RecursiveTransformerLayer(nn.Module):
308
  num_steps: int,
309
  pos_encoding: Optional[nn.Module]
310
  ) -> Tuple[torch.Tensor, torch.Tensor]:
311
- """Fixed recursion forward pass."""
312
  device = x.device
313
  batch_size = x.shape[0]
314
  computation_loss = torch.tensor(0.0, device=device)
315
-
316
  for step in range(min(num_steps, self.max_steps)):
317
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
318
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
@@ -320,7 +371,6 @@ class RecursiveTransformerLayer(nn.Module):
320
  fed_forward = self.feedforward(x)
321
  x = self.norm2(x + self.dropout(fed_forward))
322
  computation_loss += torch.tensor(0.1, device=device) * batch_size
323
-
324
  return x, computation_loss
325
 
326
  def _recursive_forward_adaptive(
@@ -330,18 +380,15 @@ class RecursiveTransformerLayer(nn.Module):
330
  steps: torch.Tensor,
331
  pos_encoding: Optional[nn.Module]
332
  ) -> Tuple[torch.Tensor, torch.Tensor]:
333
- """Adaptive recursion forward pass."""
334
  batch_size, seq_len, d_model = x.shape
335
  device = x.device
336
  max_batch_steps = int(steps.max().item())
337
  computation_loss = torch.tensor(0.0, device=device)
338
  active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
339
-
340
  for step in range(min(max_batch_steps, self.max_steps)):
341
  step_mask = (steps > step) & active_batches
342
  if not step_mask.any():
343
  break
344
-
345
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
346
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
347
  attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
@@ -351,7 +398,6 @@ class RecursiveTransformerLayer(nn.Module):
351
  x = self.norm2(x + self.dropout(fed_forward))
352
  computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
353
  active_batches &= (steps > step)
354
-
355
  return x, computation_loss
356
 
357
  class MixtureOfRecursions(nn.Module):
@@ -371,27 +417,10 @@ class MixtureOfRecursions(nn.Module):
371
  padding_idx: int = DEFAULT_PADDING_IDX,
372
  pos_encoding: str = "learned"
373
  ):
374
- """
375
- Initialize the Mixture of Recursions model.
376
-
377
- Args:
378
- vocab_size (int): Size of the vocabulary.
379
- d_model (int): Dimension of the model embeddings.
380
- n_layers (int): Number of transformer layers.
381
- n_heads (int): Number of attention heads.
382
- max_steps (int): Maximum number of recursion steps.
383
- dim_feedforward (int): Dimension of the feed-forward layer.
384
- dropout (float): Dropout rate for regularization.
385
- max_seq_len (int): Maximum sequence length.
386
- router_type (str): Type of router ('adaptive' or 'fixed').
387
- padding_idx (int): Index for padding token.
388
- pos_encoding (str): Type of positional encoding ('learned', 'sinusoidal', 'rope').
389
- """
390
  super().__init__()
391
  self.d_model = d_model
392
  self.vocab_size = vocab_size
393
  self.padding_idx = padding_idx
394
-
395
  self.embeddings = TechEmbeddingLayer(
396
  vocab_size=vocab_size,
397
  d_model=d_model,
@@ -400,7 +429,6 @@ class MixtureOfRecursions(nn.Module):
400
  padding_idx=padding_idx,
401
  pos_encoding=pos_encoding
402
  )
403
-
404
  self.layers = nn.ModuleList([
405
  RecursiveTransformerLayer(
406
  d_model=d_model,
@@ -411,39 +439,24 @@ class MixtureOfRecursions(nn.Module):
411
  router_type=router_type
412
  ) for _ in range(n_layers)
413
  ])
414
-
415
  self.final_norm = nn.LayerNorm(d_model)
416
  self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
417
  self._init_weights()
418
 
419
  def _init_weights(self) -> None:
420
- """Initialize weights for the language model head."""
421
  nn.init.xavier_uniform_(self.lm_head.weight)
422
 
423
  def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
424
- """
425
- Forward pass for the model.
426
-
427
- Args:
428
- input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
429
- attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len).
430
-
431
- Returns:
432
- Tuple[torch.Tensor, torch.Tensor]: Logits and total computation loss.
433
- """
434
  batch_size, seq_len = input_ids.shape
435
  padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
436
  causal_mask = create_causal_mask(seq_len, input_ids.device)
437
  combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
438
-
439
  x = self.embeddings(input_ids)
440
  pos_encoding = self.embeddings.get_positional_encoding()
441
-
442
  total_computation_loss = torch.tensor(0.0, device=x.device)
443
  for layer in self.layers:
444
  x, comp_loss = layer(x, combined_mask, pos_encoding)
445
  total_computation_loss += comp_loss
446
-
447
  x = self.final_norm(x)
448
  logits = self.lm_head(x)
449
  return logits, total_computation_loss
@@ -455,27 +468,13 @@ class MixtureOfRecursions(nn.Module):
455
  top_k: Optional[int] = None,
456
  top_p: Optional[float] = None
457
  ) -> torch.Tensor:
458
- """
459
- Generate the next token for a given input sequence.
460
-
461
- Args:
462
- input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
463
- temperature (float): Temperature for softmax scaling.
464
- top_k (Optional[int]): Number of top-k tokens to sample from.
465
- top_p (Optional[float]): Cumulative probability for nucleus sampling.
466
-
467
- Returns:
468
- torch.Tensor: Next token IDs of shape (batch_size, 1).
469
- """
470
  self.eval()
471
  with torch.no_grad():
472
  logits, _ = self.forward(input_ids)
473
  last_logits = logits[:, -1, :] / temperature
474
-
475
  if top_k is not None:
476
  indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
477
  last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
478
-
479
  if top_p is not None:
480
  sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
481
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -484,7 +483,6 @@ class MixtureOfRecursions(nn.Module):
484
  sorted_indices_to_remove[..., 0] = False
485
  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
486
  last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
487
-
488
  probs = F.softmax(last_logits, dim=-1)
489
  return torch.multinomial(probs, num_samples=1)
490
 
@@ -492,15 +490,6 @@ class TextGenerator:
492
  """Text generation utility for the MixtureOfRecursions model."""
493
 
494
  def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
495
- """
496
- Initialize the text generator.
497
-
498
- Args:
499
- model (nn.Module): The transformer model.
500
- tokenizer (Tokenizer): Tokenizer for encoding/decoding text.
501
- max_length (int): Maximum sequence length for generation.
502
- device (Optional[torch.device]): Device to run the model on.
503
- """
504
  self.model = model
505
  self.tokenizer = tokenizer
506
  self.max_length = max_length
@@ -518,36 +507,16 @@ class TextGenerator:
518
  top_p: Optional[float] = 0.9,
519
  max_new_tokens: Optional[int] = None
520
  ) -> str:
521
- """
522
- Generate text based on a prompt.
523
-
524
- Args:
525
- prompt (str): Input prompt for generation.
526
- method (str): Generation method ('greedy', 'sample', 'top_k', 'nucleus').
527
- temperature (float): Temperature for softmax scaling.
528
- top_k (Optional[int]): Number of top-k tokens to sample from.
529
- top_p (Optional[float]): Cumulative probability for nucleus sampling.
530
- max_new_tokens (Optional[int]): Maximum number of new tokens to generate.
531
-
532
- Returns:
533
- str: Generated text response.
534
-
535
- Raises:
536
- ValueError: If the generation method is invalid.
537
- """
538
  max_new_tokens = max_new_tokens or self.max_length
539
  input_text = f"<|user|> {prompt}"
540
  input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
541
  input_tensor = torch.tensor([input_ids], device=self.device)
542
-
543
  self.model.eval()
544
  generated_ids = []
545
-
546
  with torch.no_grad():
547
  for _ in range(max_new_tokens):
548
  if input_tensor.size(1) > self.max_length:
549
  input_tensor = input_tensor[:, -self.max_length:]
550
-
551
  if method == "greedy":
552
  next_token = self._greedy_generate(input_tensor)
553
  elif method == "sample":
@@ -558,38 +527,30 @@ class TextGenerator:
558
  next_token = self._nucleus_generate(input_tensor, temperature, top_p)
559
  else:
560
  raise ValueError(f"Unknown generation method: {method}")
561
-
562
  next_token_id = next_token.item()
563
  generated_ids.append(next_token_id)
564
  input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
565
-
566
  if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
567
  break
568
-
569
  full_ids = input_ids + generated_ids
570
  full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
571
-
572
  if "<|assistant|>" in full_text:
573
  response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
574
  else:
575
  response = full_text.split("<|endoftext|>")[0].strip()
576
-
577
  return response if response else "No response generated."
578
 
579
  def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
580
- """Generate the next token using greedy decoding."""
581
  logits, _ = self.model(input_tensor)
582
  return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
583
 
584
  def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
585
- """Generate the next token using random sampling."""
586
  logits, _ = self.model(input_tensor)
587
  logits = logits[:, -1, :] / temperature
588
  probs = F.softmax(logits, dim=-1)
589
  return torch.multinomial(probs, num_samples=1)
590
 
591
  def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
592
- """Generate the next token using top-k sampling."""
593
  logits, _ = self.model(input_tensor)
594
  logits = logits[:, -1, :] / temperature
595
  top_k_logits, top_k_indices = torch.topk(logits, top_k)
@@ -598,19 +559,616 @@ class TextGenerator:
598
  return top_k_indices.gather(-1, next_token_idx)
599
 
600
  def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
601
- """Generate the next token using nucleus (top-p) sampling."""
602
  return self.model.generate_step(input_tensor, temperature, top_p=top_p)
603
 
604
  def count_parameters(model: nn.Module) -> Tuple[int, int]:
605
- """
606
- Count total and trainable parameters in the model.
607
 
608
- Args:
609
- model (nn.Module): The model to analyze.
610
 
611
- Returns:
612
- Tuple[int, int]: Total and trainable parameter counts.
613
- """
614
  total_params = sum(p.numel() for p in model.parameters())
615
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
616
  return total_params, trainable_params
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
+ from typing import Optional, Tuple, Union, List
 
6
 
7
  # ============================================================================
8
+ # TRANSFORMERS COMPATIBILITY
9
  # ============================================================================
10
  from transformers import PretrainedConfig
11
  from transformers.modeling_utils import PreTrainedModel
 
28
  router_type="adaptive",
29
  padding_idx=0,
30
  pos_encoding="learned",
 
31
  hidden_size=None,
32
  num_hidden_layers=None,
33
  num_attention_heads=None,
 
36
  **kwargs
37
  ):
38
  super().__init__(**kwargs)
 
 
39
  self.vocab_size = vocab_size
40
  self.d_model = d_model
41
  self.n_layers = n_layers
 
47
  self.router_type = router_type
48
  self.padding_idx = padding_idx
49
  self.pos_encoding = pos_encoding
 
 
50
  self.hidden_size = hidden_size or d_model
51
  self.num_hidden_layers = num_hidden_layers or n_layers
52
  self.num_attention_heads = num_attention_heads or n_heads
 
54
  self.max_position_embeddings = max_position_embeddings or max_seq_len
55
 
56
  # ============================================================================
57
+ # EMBEDDINGS MODULE (merged from embeddings.py)
58
+ # ============================================================================
59
+
60
+ DEFAULT_BASE = 10000.0
61
+ DEFAULT_CUTOFFS = [2000, 10000]
62
+ DEFAULT_DIV_VAL = 4.0
63
+
64
+ class PositionalEncoding(nn.Module):
65
+ """Sinusoidal positional encoding for transformer models."""
66
+
67
+ def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.dropout = nn.Dropout(dropout)
71
+ pe = torch.zeros(max_seq_len, d_model)
72
+ position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
73
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
74
+ pe[:, 0::2] = torch.sin(position * div_term)
75
+ pe[:, 1::2] = torch.cos(position * div_term[:, :-1] if d_model % 2 == 1 else div_term)
76
+ self.register_buffer('pe', pe.unsqueeze(0))
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ batch_size, seq_len, d_model = x.size()
80
+ if d_model != self.d_model:
81
+ raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
82
+ x = x + self.pe[:, :seq_len]
83
+ return self.dropout(x)
84
+
85
+ class LearnedPositionalEmbedding(nn.Module):
86
+ """Learned positional embeddings for transformer models."""
87
+
88
+ def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
89
+ super().__init__()
90
+ self.max_seq_len = max_seq_len
91
+ self.d_model = d_model
92
+ self.pos_embedding = nn.Embedding(max_seq_len, d_model)
93
+ self.dropout = nn.Dropout(dropout)
94
+ nn.init.normal_(self.pos_embedding.weight, std=0.02)
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ batch_size, seq_len, d_model = x.size()
98
+ if seq_len > self.max_seq_len:
99
+ raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
100
+ if d_model != self.d_model:
101
+ raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
102
+ positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
103
+ pos_emb = self.pos_embedding(positions)
104
+ x = x + pos_emb
105
+ return self.dropout(x)
106
+
107
+ class RotaryPositionalEmbedding(nn.Module):
108
+ """Rotary Positional Embedding (RoPE) for transformer models."""
109
+
110
+ def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.max_seq_len = max_seq_len
114
+ self.base = base
115
+ inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
116
+ self.register_buffer('inv_freq', inv_freq)
117
+ self._seq_len_cached = 0
118
+ self._cos_cached = None
119
+ self._sin_cached = None
120
+
121
+ def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
122
+ if seq_len > self._seq_len_cached:
123
+ self._seq_len_cached = seq_len
124
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
125
+ freqs = torch.outer(t, self.inv_freq)
126
+ self._cos_cached = freqs.cos().to(dtype)
127
+ self._sin_cached = freqs.sin().to(dtype)
128
+
129
+ def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
130
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
131
+ return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
132
+
133
+ def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ batch_size, seq_len, num_heads, head_dim = q.shape
135
+ self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
136
+ cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
137
+ sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
138
+ q = q.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
139
+ k = k.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
140
+ q_rot = self._rotate_half(q, cos, sin)
141
+ k_rot = self._rotate_half(k, cos, sin)
142
+ q_rot = q_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
143
+ k_rot = k_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
144
+ return q_rot, k_rot
145
+
146
+ class TechEmbeddingLayer(nn.Module):
147
+ """Comprehensive embedding layer with token and positional embeddings."""
148
+
149
+ def __init__(
150
+ self,
151
+ vocab_size: int,
152
+ d_model: int,
153
+ max_seq_len: int = 512,
154
+ dropout: float = 0.1,
155
+ padding_idx: int = 0,
156
+ pos_encoding: str = "learned",
157
+ layer_norm: bool = True,
158
+ ):
159
+ super().__init__()
160
+ self.d_model = d_model
161
+ self.vocab_size = vocab_size
162
+ self.padding_idx = padding_idx
163
+ self.pos_encoding_type = pos_encoding.lower()
164
+ self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
165
+
166
+ if pos_encoding == "sinusoidal":
167
+ self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
168
+ elif pos_encoding == "learned":
169
+ self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
170
+ elif pos_encoding == "rope":
171
+ self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
172
+ else:
173
+ raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
174
+
175
+ self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
176
+ self.dropout = nn.Dropout(dropout)
177
+ self._init_weights()
178
+
179
+ def _init_weights(self) -> None:
180
+ nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
181
+ if self.padding_idx is not None:
182
+ nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
183
+
184
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
185
+ if (input_ids >= self.vocab_size).any():
186
+ raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
187
+ embeddings = self.token_embedding(input_ids)
188
+ if self.pos_encoding_type != "rope":
189
+ embeddings = self.pos_encoding(embeddings)
190
+ embeddings = self.layer_norm(embeddings)
191
+ return self.dropout(embeddings)
192
+
193
+ def get_positional_encoding(self) -> Optional[nn.Module]:
194
+ return self.pos_encoding if self.pos_encoding_type == "rope" else None
195
+
196
+ def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
197
+ return input_ids == padding_idx
198
+
199
+ def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
200
+ return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
201
+
202
+ # ============================================================================
203
+ # MODEL CONSTANTS
204
  # ============================================================================
205
 
 
206
  DEFAULT_D_MODEL = 512
207
  DEFAULT_N_HEADS = 8
208
  DEFAULT_N_LAYERS = 6
 
214
  DEFAULT_ROUTER_TYPE = "adaptive"
215
  DEFAULT_VOCAB_SIZE = 10000
216
 
217
+ # ============================================================================
218
+ # MODEL COMPONENTS
219
+ # ============================================================================
220
+
221
  class MultiHeadAttention(nn.Module):
222
  """Multi-head attention mechanism optimized for technical content."""
223
 
224
  def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
 
225
  super().__init__()
226
  if d_model % n_heads != 0:
227
  raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
 
228
  self.d_model = d_model
229
  self.n_heads = n_heads
230
  self.d_k = d_model // n_heads
 
231
  self.w_q = nn.Linear(d_model, d_model, bias=False)
232
  self.w_k = nn.Linear(d_model, d_model, bias=False)
233
  self.w_v = nn.Linear(d_model, d_model, bias=False)
 
236
  self._init_weights()
237
 
238
  def _init_weights(self) -> None:
 
239
  for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
240
  nn.init.xavier_uniform_(module.weight)
241
  if hasattr(module, 'bias') and module.bias is not None:
 
249
  mask: Optional[torch.Tensor] = None,
250
  pos_encoding: Optional[nn.Module] = None
251
  ) -> torch.Tensor:
252
  batch_size, seq_len, _ = query.size()
 
253
  Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
254
  K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
255
  V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
 
273
  """Position-wise feed-forward network with GELU activation."""
274
 
275
  def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
 
276
  super().__init__()
277
  self.linear1 = nn.Linear(d_model, dim_feedforward)
278
  self.linear2 = nn.Linear(dim_feedforward, d_model)
279
  self.dropout = nn.Dropout(dropout)
 
280
  nn.init.xavier_uniform_(self.linear1.weight)
281
  nn.init.zeros_(self.linear1.bias)
282
  nn.init.xavier_uniform_(self.linear2.weight)
283
  nn.init.zeros_(self.linear2.bias)
284
 
285
  def forward(self, x: torch.Tensor) -> torch.Tensor:
286
  x = F.gelu(self.linear1(x))
287
  x = self.dropout(x)
288
  return self.linear2(x)
 
291
  """Router to determine recursion steps for technical problem processing."""
292
 
293
  def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
294
  super().__init__()
295
  self.max_steps = max_steps
296
  self.router_type = router_type.lower()
 
309
  raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")
310
 
311
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
312
  if self.router_type == "adaptive":
313
  seq_repr = x.mean(dim=1)
314
  step_probs = self.complexity_classifier(seq_repr)
 
327
  dropout: float = DEFAULT_DROPOUT,
328
  router_type: str = DEFAULT_ROUTER_TYPE
329
  ):
330
  super().__init__()
331
  self.max_steps = max_steps
332
  self.d_model = d_model
 
333
  self.attention = MultiHeadAttention(d_model, n_heads, dropout)
334
  self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
335
  self.norm1 = nn.LayerNorm(d_model)
 
339
  self.step_projections = nn.ModuleList([
340
  nn.Linear(d_model, d_model) for _ in range(max_steps)
341
  ])
 
342
  for proj in self.step_projections:
343
  nn.init.xavier_uniform_(proj.weight)
344
  nn.init.zeros_(proj.bias)
 
349
  mask: Optional[torch.Tensor] = None,
350
  pos_encoding: Optional[nn.Module] = None
351
  ) -> Tuple[torch.Tensor, torch.Tensor]:
352
  steps = self.router(x)
353
  if isinstance(steps, (int, torch.Tensor)) and not torch.is_tensor(steps):
354
  return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
 
361
  num_steps: int,
362
  pos_encoding: Optional[nn.Module]
363
  ) -> Tuple[torch.Tensor, torch.Tensor]:
 
364
  device = x.device
365
  batch_size = x.shape[0]
366
  computation_loss = torch.tensor(0.0, device=device)
 
367
  for step in range(min(num_steps, self.max_steps)):
368
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
369
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
 
371
  fed_forward = self.feedforward(x)
372
  x = self.norm2(x + self.dropout(fed_forward))
373
  computation_loss += torch.tensor(0.1, device=device) * batch_size
 
374
  return x, computation_loss
375
 
376
  def _recursive_forward_adaptive(
 
380
  steps: torch.Tensor,
381
  pos_encoding: Optional[nn.Module]
382
  ) -> Tuple[torch.Tensor, torch.Tensor]:
 
383
  batch_size, seq_len, d_model = x.shape
384
  device = x.device
385
  max_batch_steps = int(steps.max().item())
386
  computation_loss = torch.tensor(0.0, device=device)
387
  active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
 
388
  for step in range(min(max_batch_steps, self.max_steps)):
389
  step_mask = (steps > step) & active_batches
390
  if not step_mask.any():
391
  break
 
392
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
393
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
394
  attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
 
398
  x = self.norm2(x + self.dropout(fed_forward))
399
  computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
400
  active_batches &= (steps > step)
 
401
  return x, computation_loss
402
 
403
  class MixtureOfRecursions(nn.Module):
 
417
  padding_idx: int = DEFAULT_PADDING_IDX,
418
  pos_encoding: str = "learned"
419
  ):
420
  super().__init__()
421
  self.d_model = d_model
422
  self.vocab_size = vocab_size
423
  self.padding_idx = padding_idx
 
424
  self.embeddings = TechEmbeddingLayer(
425
  vocab_size=vocab_size,
426
  d_model=d_model,
 
429
  padding_idx=padding_idx,
430
  pos_encoding=pos_encoding
431
  )
 
432
  self.layers = nn.ModuleList([
433
  RecursiveTransformerLayer(
434
  d_model=d_model,
 
439
  router_type=router_type
440
  ) for _ in range(n_layers)
441
  ])
 
442
  self.final_norm = nn.LayerNorm(d_model)
443
  self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
444
  self._init_weights()
445
 
446
  def _init_weights(self) -> None:
 
447
  nn.init.xavier_uniform_(self.lm_head.weight)
448
 
449
  def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
450
  batch_size, seq_len = input_ids.shape
451
  padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
452
  causal_mask = create_causal_mask(seq_len, input_ids.device)
453
  combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
 
454
  x = self.embeddings(input_ids)
455
  pos_encoding = self.embeddings.get_positional_encoding()
 
456
  total_computation_loss = torch.tensor(0.0, device=x.device)
457
  for layer in self.layers:
458
  x, comp_loss = layer(x, combined_mask, pos_encoding)
459
  total_computation_loss += comp_loss
 
460
  x = self.final_norm(x)
461
  logits = self.lm_head(x)
462
  return logits, total_computation_loss
 
468
  top_k: Optional[int] = None,
469
  top_p: Optional[float] = None
470
  ) -> torch.Tensor:
471
  self.eval()
472
  with torch.no_grad():
473
  logits, _ = self.forward(input_ids)
474
  last_logits = logits[:, -1, :] / temperature
 
475
  if top_k is not None:
476
  indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
477
  last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
 
478
  if top_p is not None:
479
  sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
480
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
 
483
  sorted_indices_to_remove[..., 0] = False
484
  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
485
  last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
 
486
  probs = F.softmax(last_logits, dim=-1)
487
  return torch.multinomial(probs, num_samples=1)
488
 
 
490
  """Text generation utility for the MixtureOfRecursions model."""
491
 
492
  def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
493
  self.model = model
494
  self.tokenizer = tokenizer
495
  self.max_length = max_length
 
507
  top_p: Optional[float] = 0.9,
508
  max_new_tokens: Optional[int] = None
509
  ) -> str:
510
  max_new_tokens = max_new_tokens or self.max_length
511
  input_text = f"<|user|> {prompt}"
512
  input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
513
  input_tensor = torch.tensor([input_ids], device=self.device)
 
514
  self.model.eval()
515
  generated_ids = []
 
516
  with torch.no_grad():
517
  for _ in range(max_new_tokens):
518
  if input_tensor.size(1) > self.max_length:
519
  input_tensor = input_tensor[:, -self.max_length:]
 
520
  if method == "greedy":
521
  next_token = self._greedy_generate(input_tensor)
522
  elif method == "sample":
 
527
  next_token = self._nucleus_generate(input_tensor, temperature, top_p)
528
  else:
529
  raise ValueError(f"Unknown generation method: {method}")
 
530
  next_token_id = next_token.item()
531
  generated_ids.append(next_token_id)
532
  input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
 
533
  if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
534
  break
 
535
  full_ids = input_ids + generated_ids
536
  full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
 
537
  if "<|assistant|>" in full_text:
538
  response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
539
  else:
540
  response = full_text.split("<|endoftext|>")[0].strip()
 
541
  return response if response else "No response generated."
542
 
543
  def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
 
544
  logits, _ = self.model(input_tensor)
545
  return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
546
 
547
  def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
 
548
  logits, _ = self.model(input_tensor)
549
  logits = logits[:, -1, :] / temperature
550
  probs = F.softmax(logits, dim=-1)
551
  return torch.multinomial(probs, num_samples=1)
552
 
553
  def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
 
554
  logits, _ = self.model(input_tensor)
555
  logits = logits[:, -1, :] / temperature
556
  top_k_logits, top_k_indices = torch.topk(logits, top_k)
 
559
  return top_k_indices.gather(-1, next_token_idx)
560
 
561
  def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
 
562
  return self.model.generate_step(input_tensor, temperature, top_p=top_p)
563
 
564
  def count_parameters(model: nn.Module) -> Tuple[int, int]:
565
+ total_params = sum(p.numel() for p in model.parameters())
566
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
567
+ return total_params, trainable_params
568
+
569
+ def main():
570
+ """Test the MixtureOfRecursions model and its components."""
571
+ print("Initializing MixtureOfRecursions model...")
572
+ model = MixtureOfRecursions(
573
+ vocab_size=DEFAULT_VOCAB_SIZE,
574
+ d_model=DEFAULT_D_MODEL,
575
+ n_layers=DEFAULT_N_LAYERS,
576
+ n_heads=DEFAULT_N_HEADS,
577
+ max_steps=DEFAULT_MAX_STEPS,
578
+ dim_feedforward=DEFAULT_DIM_FEEDFORWARD,
579
+ dropout=DEFAULT_DROPOUT,
580
+ router_type=DEFAULT_ROUTER_TYPE
581
+ )
582
+
583
+ total_params, trainable_params = count_parameters(model)
584
+ print(f"Total parameters: {total_params:,}")
585
+ print(f"Trainable parameters: {trainable_params:,}")
586
+
587
+ print("\nTesting forward pass...")
588
+ batch_size, seq_len = 4, 128
589
+ input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (batch_size, seq_len))
590
+ attention_mask = torch.ones_like(input_ids)
591
+ attention_mask[:, -10:] = 0
592
+
593
+ logits, comp_loss = model(input_ids, attention_mask)
594
+
595
+ assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
596
+ print(f"Input shape: {input_ids.shape}")
597
+ print(f"Output logits shape: {logits.shape}")
598
+ print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
599
+ print(f"Computation loss: {comp_loss:.4f}")
600
+
601
+ print("\nTesting generation step...")
602
+ next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
603
+ print(f"Generated next token: {next_token.item()}")
604
+
605
+ print("\nModel test completed successfully!")
606
 
607
+ if __name__ == "__main__":
608
+ main()
+ import torch
609
+ import torch.nn as nn
610
+ import torch.nn.functional as F
611
+ import math
612
+ from typing import Optional, Tuple, Union, List
613
+
614
+ # ============================================================================
615
+ # TRANSFORMERS COMPATIBILITY
616
+ # ============================================================================
617
+ from transformers import PretrainedConfig
618
+ from transformers.modeling_utils import PreTrainedModel
619
+
620
+ class MixtureOfRecursionsConfig(PretrainedConfig):
621
+ """Configuration class for MixtureOfRecursions model."""
622
+
623
+ model_type = "mixture_of_recursions"
624
+
625
+ def __init__(
626
+ self,
627
+ vocab_size=31985,
628
+ d_model=384,
629
+ n_layers=12,
630
+ n_heads=6,
631
+ max_steps=4,
632
+ dim_feedforward=2048,
633
+ dropout=0.1,
634
+ max_seq_len=128,
635
+ router_type="adaptive",
636
+ padding_idx=0,
637
+ pos_encoding="learned",
638
+ hidden_size=None,
639
+ num_hidden_layers=None,
640
+ num_attention_heads=None,
641
+ intermediate_size=None,
642
+ max_position_embeddings=None,
643
+ **kwargs
644
+ ):
645
+ super().__init__(**kwargs)
646
+ self.vocab_size = vocab_size
647
+ self.d_model = d_model
648
+ self.n_layers = n_layers
649
+ self.n_heads = n_heads
650
+ self.max_steps = max_steps
651
+ self.dim_feedforward = dim_feedforward
652
+ self.dropout = dropout
653
+ self.max_seq_len = max_seq_len
654
+ self.router_type = router_type
655
+ self.padding_idx = padding_idx
656
+ self.pos_encoding = pos_encoding
657
+ self.hidden_size = hidden_size or d_model
658
+ self.num_hidden_layers = num_hidden_layers or n_layers
659
+ self.num_attention_heads = num_attention_heads or n_heads
660
+ self.intermediate_size = intermediate_size or dim_feedforward
661
+ self.max_position_embeddings = max_position_embeddings or max_seq_len
662
+
663
+ # ============================================================================
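A quick sketch (illustrative, not from the diff) of the Transformers-standard aliases: when hidden_size and friends are not given, the config falls back to its native values.

from model_slm import MixtureOfRecursionsConfig  # assumed import path

cfg = MixtureOfRecursionsConfig(d_model=512, n_layers=6, n_heads=8)
assert cfg.hidden_size == 512                 # falls back to d_model
assert cfg.num_hidden_layers == 6             # falls back to n_layers
assert cfg.num_attention_heads == 8           # falls back to n_heads
assert cfg.intermediate_size == cfg.dim_feedforward
assert cfg.max_position_embeddings == cfg.max_seq_len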
664
+ # EMBEDDINGS MODULE (merged from embeddings.py)
665
+ # ============================================================================
666
+
667
+ DEFAULT_BASE = 10000.0
668
+ DEFAULT_CUTOFFS = [2000, 10000]
669
+ DEFAULT_DIV_VAL = 4.0
670
+
671
+ class PositionalEncoding(nn.Module):
672
+ """Sinusoidal positional encoding for transformer models."""
673
+
674
+ def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
675
+ super().__init__()
676
+ self.d_model = d_model
677
+ self.dropout = nn.Dropout(dropout)
678
+ pe = torch.zeros(max_seq_len, d_model)
679
+ position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
680
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
681
+ pe[:, 0::2] = torch.sin(position * div_term)
682
+ pe[:, 1::2] = torch.cos(position * div_term[:, :-1] if d_model % 2 == 1 else div_term)
683
+ self.register_buffer('pe', pe.unsqueeze(0))
684
+
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ batch_size, seq_len, d_model = x.size()
687
+ if d_model != self.d_model:
688
+ raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
689
+ x = x + self.pe[:, :seq_len]
690
+ return self.dropout(x)
691
+
692
+ class LearnedPositionalEmbedding(nn.Module):
693
+ """Learned positional embeddings for transformer models."""
694
+
695
+ def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
696
+ super().__init__()
697
+ self.max_seq_len = max_seq_len
698
+ self.d_model = d_model
699
+ self.pos_embedding = nn.Embedding(max_seq_len, d_model)
700
+ self.dropout = nn.Dropout(dropout)
701
+ nn.init.normal_(self.pos_embedding.weight, std=0.02)
702
+
703
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
704
+ batch_size, seq_len, d_model = x.size()
705
+ if seq_len > self.max_seq_len:
706
+ raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
707
+ if d_model != self.d_model:
708
+ raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
709
+ positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
710
+ pos_emb = self.pos_embedding(positions)
711
+ x = x + pos_emb
712
+ return self.dropout(x)
713
+
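Both positional modules take and return a (batch, seq_len, d_model) tensor; a small shape-check sketch with illustrative sizes (even d_model, sequence shorter than max_seq_len):

import torch
from model_slm import PositionalEncoding, LearnedPositionalEmbedding  # assumed import path

x = torch.randn(2, 16, 384)                               # (batch, seq_len, d_model)
sin_pe = PositionalEncoding(d_model=384, max_seq_len=128)
learned_pe = LearnedPositionalEmbedding(max_seq_len=128, d_model=384)
assert sin_pe(x).shape == x.shape      # adds fixed sinusoidal offsets, then dropout
assert learned_pe(x).shape == x.shape  # adds trained position vectors, then dropout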
714
+ class RotaryPositionalEmbedding(nn.Module):
715
+ """Rotary Positional Embedding (RoPE) for transformer models."""
716
+
717
+ def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
718
+ super().__init__()
719
+ self.d_model = d_model
720
+ self.max_seq_len = max_seq_len
721
+ self.base = base
722
+ inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
723
+ self.register_buffer('inv_freq', inv_freq)
724
+ self._seq_len_cached = 0
725
+ self._cos_cached = None
726
+ self._sin_cached = None
727
+
728
+ def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
729
+ if seq_len > self._seq_len_cached:
730
+ self._seq_len_cached = seq_len
731
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
732
+ freqs = torch.outer(t, self.inv_freq)
733
+ self._cos_cached = freqs.cos().to(dtype)
734
+ self._sin_cached = freqs.sin().to(dtype)
735
+
736
+ def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
737
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
738
+ return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
739
+
740
+ def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
741
+ batch_size, seq_len, num_heads, head_dim = q.shape
742
+ self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
743
+ cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
744
+ sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
745
+ q = q.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
746
+ k = k.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
747
+ q_rot = self._rotate_half(q, cos, sin)
748
+ k_rot = self._rotate_half(k, cos, sin)
749
+ q_rot = q_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
750
+ k_rot = k_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
751
+ return q_rot, k_rot
752
+
753
+ class TechEmbeddingLayer(nn.Module):
754
+ """Comprehensive embedding layer with token and positional embeddings."""
755
+
756
+ def __init__(
757
+ self,
758
+ vocab_size: int,
759
+ d_model: int,
760
+ max_seq_len: int = 512,
761
+ dropout: float = 0.1,
762
+ padding_idx: int = 0,
763
+ pos_encoding: str = "learned",
764
+ layer_norm: bool = True,
765
+ ):
766
+ super().__init__()
767
+ self.d_model = d_model
768
+ self.vocab_size = vocab_size
769
+ self.padding_idx = padding_idx
770
+ self.pos_encoding_type = pos_encoding.lower()
771
+ self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
772
+
773
+ if pos_encoding == "sinusoidal":
774
+ self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
775
+ elif pos_encoding == "learned":
776
+ self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
777
+ elif pos_encoding == "rope":
778
+ self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
779
+ else:
780
+ raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
781
+
782
+ self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
783
+ self.dropout = nn.Dropout(dropout)
784
+ self._init_weights()
785
+
786
+ def _init_weights(self) -> None:
787
+ nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
788
+ if self.padding_idx is not None:
789
+ nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
790
+
791
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
792
+ if (input_ids >= self.vocab_size).any():
793
+ raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
794
+ embeddings = self.token_embedding(input_ids)
795
+ if self.pos_encoding_type != "rope":
796
+ embeddings = self.pos_encoding(embeddings)
797
+ embeddings = self.layer_norm(embeddings)
798
+ return self.dropout(embeddings)
799
+
800
+ def get_positional_encoding(self) -> Optional[nn.Module]:
801
+ return self.pos_encoding if self.pos_encoding_type == "rope" else None
802
+
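A minimal sketch of the merged TechEmbeddingLayer with illustrative sizes: with "learned" (or "sinusoidal") encodings the positions are added inside this layer, so get_positional_encoding() returns None and nothing extra is handed to attention.

import torch
from model_slm import TechEmbeddingLayer  # assumed import path

emb = TechEmbeddingLayer(vocab_size=1000, d_model=128, max_seq_len=64, pos_encoding="learned")
ids = torch.randint(1, 1000, (2, 32))     # token ids below vocab_size
out = emb(ids)                            # (2, 32, 128): token + position, LayerNorm, dropout
assert emb.get_positional_encoding() is None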
803
+ def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
804
+ return input_ids == padding_idx
805
+
806
+ def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
807
+ return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
808
+
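The two helpers are combined in MixtureOfRecursions.forward; a sketch with a toy batch (True marks positions attention must ignore):

import torch
from model_slm import create_padding_mask, create_causal_mask  # assumed import path

input_ids = torch.tensor([[5, 7, 9, 0]])             # trailing 0 is the padding id
pad = create_padding_mask(input_ids)                 # (1, 4), True at padded tokens
causal = create_causal_mask(4, input_ids.device)     # (4, 4), True above the diagonal
combined = pad.unsqueeze(1).expand(1, 4, 4) | causal.unsqueeze(0)   # (1, 4, 4), as in forward()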
809
+ # ============================================================================
810
+ # MODEL CONSTANTS
811
+ # ============================================================================
812
+
813
+ DEFAULT_D_MODEL = 512
814
+ DEFAULT_N_HEADS = 8
815
+ DEFAULT_N_LAYERS = 6
816
+ DEFAULT_MAX_STEPS = 4
817
+ DEFAULT_DIM_FEEDFORWARD = 2048
818
+ DEFAULT_DROPOUT = 0.1
819
+ DEFAULT_MAX_SEQ_LEN = 512
820
+ DEFAULT_PADDING_IDX = 0
821
+ DEFAULT_ROUTER_TYPE = "adaptive"
822
+ DEFAULT_VOCAB_SIZE = 10000
823
+
824
+ # ============================================================================
825
+ # MODEL COMPONENTS
826
+ # ============================================================================
827
 
828
+ class MultiHeadAttention(nn.Module):
829
+ """Multi-head attention mechanism optimized for technical content."""
830
+
831
+ def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
832
+ super().__init__()
833
+ if d_model % n_heads != 0:
834
+ raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
835
+ self.d_model = d_model
836
+ self.n_heads = n_heads
837
+ self.d_k = d_model // n_heads
838
+ self.w_q = nn.Linear(d_model, d_model, bias=False)
839
+ self.w_k = nn.Linear(d_model, d_model, bias=False)
840
+ self.w_v = nn.Linear(d_model, d_model, bias=False)
841
+ self.w_o = nn.Linear(d_model, d_model)
842
+ self.dropout = nn.Dropout(dropout)
843
+ self._init_weights()
844
+
845
+ def _init_weights(self) -> None:
846
+ for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
847
+ nn.init.xavier_uniform_(module.weight)
848
+ if hasattr(module, 'bias') and module.bias is not None:
849
+ nn.init.zeros_(module.bias)
850
+
851
+ def forward(
852
+ self,
853
+ query: torch.Tensor,
854
+ key: torch.Tensor,
855
+ value: torch.Tensor,
856
+ mask: Optional[torch.Tensor] = None,
857
+ pos_encoding: Optional[nn.Module] = None
858
+ ) -> torch.Tensor:
859
+ batch_size, seq_len, _ = query.size()
860
+ Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
861
+ K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
862
+ V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
863
+
864
+ if pos_encoding is not None:
865
+ Q, K = pos_encoding(Q, K)
866
+
867
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
868
+
869
+ if mask is not None:
870
+ mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
871
+ scores = scores.masked_fill(mask, float('-inf'))
872
+
873
+ attention_weights = F.softmax(scores, dim=-1)
874
+ attention_weights = self.dropout(attention_weights)
875
+ attended = torch.matmul(attention_weights, V)
876
+ attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
877
+ return self.w_o(attended)
878
+
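A self-attention sketch with illustrative sizes; pos_encoding is left as None here, since only the "rope" configuration supplies a rotary module:

import torch
from model_slm import MultiHeadAttention  # assumed import path

attn = MultiHeadAttention(d_model=128, n_heads=8)
x = torch.randn(1, 4, 128)
causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1).unsqueeze(0)  # (1, 4, 4)
out = attn(x, x, x, mask=causal)          # mask is expanded per head inside forward()
assert out.shape == (1, 4, 128)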
879
+ class FeedForward(nn.Module):
880
+ """Position-wise feed-forward network with GELU activation."""
881
+
882
+ def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
883
+ super().__init__()
884
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
885
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
886
+ self.dropout = nn.Dropout(dropout)
887
+ nn.init.xavier_uniform_(self.linear1.weight)
888
+ nn.init.zeros_(self.linear1.bias)
889
+ nn.init.xavier_uniform_(self.linear2.weight)
890
+ nn.init.zeros_(self.linear2.bias)
891
+
892
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
893
+ x = F.gelu(self.linear1(x))
894
+ x = self.dropout(x)
895
+ return self.linear2(x)
896
+
897
+ class RecursionRouter(nn.Module):
898
+ """Router to determine recursion steps for technical problem processing."""
899
+
900
+ def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
901
+ super().__init__()
902
+ self.max_steps = max_steps
903
+ self.router_type = router_type.lower()
904
+
905
+ if self.router_type == "adaptive":
906
+ self.complexity_classifier = nn.Sequential(
907
+ nn.Linear(d_model, d_model // 4),
908
+ nn.GELU(),
909
+ nn.Dropout(DEFAULT_DROPOUT),
910
+ nn.Linear(d_model // 4, max_steps + 1),
911
+ nn.Softmax(dim=-1)
912
+ )
913
+ elif self.router_type == "fixed":
914
+ self.register_buffer('fixed_steps', torch.tensor(max_steps, dtype=torch.long))
915
+ else:
916
+ raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")
917
+
918
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
+        if self.router_type == "adaptive":
+            seq_repr = x.mean(dim=1)
+            step_probs = self.complexity_classifier(seq_repr)
+            return torch.argmax(step_probs, dim=-1)
+        return self.fixed_steps.item()
+
+class RecursiveTransformerLayer(nn.Module):
+    """Transformer layer with recursive computation capability."""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        dim_feedforward: int,
+        max_steps: int = DEFAULT_MAX_STEPS,
+        dropout: float = DEFAULT_DROPOUT,
+        router_type: str = DEFAULT_ROUTER_TYPE
+    ):
+        super().__init__()
+        self.max_steps = max_steps
+        self.d_model = d_model
+        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
+        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.router = RecursionRouter(d_model, max_steps, router_type)
+        self.step_projections = nn.ModuleList([
+            nn.Linear(d_model, d_model) for _ in range(max_steps)
+        ])
+        for proj in self.step_projections:
+            nn.init.xavier_uniform_(proj.weight)
+            nn.init.zeros_(proj.bias)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        pos_encoding: Optional[nn.Module] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
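+        """Run the routed number of recursion steps and return (hidden_states,
+        computation_loss), where the loss term penalizes additional steps.
+        """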
+        steps = self.router(x)
+        if not torch.is_tensor(steps):
+            return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
+        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)
+
+    def _recursive_forward_fixed(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor],
+        num_steps: int,
+        pos_encoding: Optional[nn.Module]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
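+        """Apply the same number of recursion steps to every example in the batch."""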
+        device = x.device
+        batch_size = x.shape[0]
+        computation_loss = torch.tensor(0.0, device=device)
+        for step in range(min(num_steps, self.max_steps)):
+            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
+            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
+            x = self.norm1(x + self.dropout(attended))
+            fed_forward = self.feedforward(x)
+            x = self.norm2(x + self.dropout(fed_forward))
+            computation_loss += torch.tensor(0.1, device=device) * batch_size
+        return x, computation_loss
+
+    def _recursive_forward_adaptive(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor],
+        steps: torch.Tensor,
+        pos_encoding: Optional[nn.Module]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
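+        """Run per-example recursion: attention and feed-forward updates are
+        zeroed for examples whose routed step budget is already exhausted.
+        """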
+        batch_size, seq_len, d_model = x.shape
+        device = x.device
+        max_batch_steps = int(steps.max().item())
+        computation_loss = torch.tensor(0.0, device=device)
+        active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
+        for step in range(min(max_batch_steps, self.max_steps)):
+            step_mask = (steps > step) & active_batches
+            if not step_mask.any():
+                break
+            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
+            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
+            attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
+            x = self.norm1(x + self.dropout(attended))
+            fed_forward = self.feedforward(x)
+            fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
+            x = self.norm2(x + self.dropout(fed_forward))
+            computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
+            active_batches &= (steps > step)
+        return x, computation_loss
+
1010
+ class MixtureOfRecursions(nn.Module):
1011
+ """Transformer model with mixture of recursive layers for technical content."""
1012
+
1013
+ def __init__(
1014
+ self,
1015
+ vocab_size: int,
1016
+ d_model: int = DEFAULT_D_MODEL,
1017
+ n_layers: int = DEFAULT_N_LAYERS,
1018
+ n_heads: int = DEFAULT_N_HEADS,
1019
+ max_steps: int = DEFAULT_MAX_STEPS,
1020
+ dim_feedforward: int = DEFAULT_DIM_FEEDFORWARD,
1021
+ dropout: float = DEFAULT_DROPOUT,
1022
+ max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
1023
+ router_type: str = DEFAULT_ROUTER_TYPE,
1024
+ padding_idx: int = DEFAULT_PADDING_IDX,
1025
+ pos_encoding: str = "learned"
1026
+ ):
1027
+ super().__init__()
1028
+ self.d_model = d_model
1029
+ self.vocab_size = vocab_size
1030
+ self.padding_idx = padding_idx
1031
+ self.embeddings = TechEmbeddingLayer(
1032
+ vocab_size=vocab_size,
1033
+ d_model=d_model,
1034
+ max_seq_len=max_seq_len,
1035
+ dropout=dropout,
1036
+ padding_idx=padding_idx,
1037
+ pos_encoding=pos_encoding
1038
+ )
1039
+ self.layers = nn.ModuleList([
1040
+ RecursiveTransformerLayer(
1041
+ d_model=d_model,
1042
+ n_heads=n_heads,
1043
+ dim_feedforward=dim_feedforward,
1044
+ max_steps=max_steps,
1045
+ dropout=dropout,
1046
+ router_type=router_type
1047
+ ) for _ in range(n_layers)
1048
+ ])
1049
+ self.final_norm = nn.LayerNorm(d_model)
1050
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
1051
+ self._init_weights()
1052
+
1053
+ def _init_weights(self) -> None:
1054
+ nn.init.xavier_uniform_(self.lm_head.weight)
1055
+
1056
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
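+        """Embed input_ids, apply the recursive layers under a combined causal
+        and padding mask, and return (logits, total_computation_loss).
+        """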
+        batch_size, seq_len = input_ids.shape
+        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
+        causal_mask = create_causal_mask(seq_len, input_ids.device)
+        combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
+        x = self.embeddings(input_ids)
+        pos_encoding = self.embeddings.get_positional_encoding()
+        total_computation_loss = torch.tensor(0.0, device=x.device)
+        for layer in self.layers:
+            x, comp_loss = layer(x, combined_mask, pos_encoding)
+            total_computation_loss += comp_loss
+        x = self.final_norm(x)
+        logits = self.lm_head(x)
+        return logits, total_computation_loss
+
+    def generate_step(
+        self,
+        input_ids: torch.Tensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None
+    ) -> torch.Tensor:
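+        """Sample one next-token id per sequence from temperature-scaled logits,
+        optionally filtered by top-k and/or nucleus (top-p) truncation.
+        """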
+        self.eval()
+        with torch.no_grad():
+            logits, _ = self.forward(input_ids)
+            last_logits = logits[:, -1, :] / temperature
+            if top_k is not None:
+                indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
+                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+            if top_p is not None:
+                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
+                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                sorted_indices_to_remove[..., 0] = False
+                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+            probs = F.softmax(last_logits, dim=-1)
+            return torch.multinomial(probs, num_samples=1)
+
+class TextGenerator:
+    """Text generation utility for the MixtureOfRecursions model."""
+
+    def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
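+        """Wrap a model and tokenizer for autoregressive decoding.
+
+        The tokenizer is expected to expose `vocab`, `encode_ids` and
+        `decode_ids`, with `<|endoftext|>` and `<|assistant|>` special tokens
+        used for stopping and response extraction.
+        """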
+        self.model = model
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.device = device if device else next(model.parameters()).device
+        self.model.to(self.device)
+        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
+        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
+
+    def generate(
+        self,
+        prompt: str,
+        method: str = "nucleus",
+        temperature: float = 1.0,
+        top_k: Optional[int] = 50,
+        top_p: Optional[float] = 0.9,
+        max_new_tokens: Optional[int] = None
+    ) -> str:
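+        """Generate a response to `prompt` using the chosen decoding method
+        ("greedy", "sample", "top_k", or "nucleus"/"top_p").
+
+        Example (sketch; assumes a trained model and a compatible tokenizer):
+            generator = TextGenerator(model, tokenizer)
+            answer = generator.generate("Explain mutexes", method="top_k", top_k=40)
+        """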
+        max_new_tokens = max_new_tokens or self.max_length
+        input_text = f"<|user|> {prompt}"
+        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
+        input_tensor = torch.tensor([input_ids], device=self.device)
+        self.model.eval()
+        generated_ids = []
+        with torch.no_grad():
+            for _ in range(max_new_tokens):
+                if input_tensor.size(1) > self.max_length:
+                    input_tensor = input_tensor[:, -self.max_length:]
+                if method == "greedy":
+                    next_token = self._greedy_generate(input_tensor)
+                elif method == "sample":
+                    next_token = self._sample_generate(input_tensor, temperature)
+                elif method == "top_k":
+                    next_token = self._top_k_generate(input_tensor, temperature, top_k)
+                elif method in ("nucleus", "top_p"):
+                    next_token = self._nucleus_generate(input_tensor, temperature, top_p)
+                else:
+                    raise ValueError(f"Unknown generation method: {method}")
+                next_token_id = next_token.item()
+                generated_ids.append(next_token_id)
+                # next_token already has shape (1, 1); append it along the sequence dimension
+                input_tensor = torch.cat([input_tensor, next_token], dim=1)
+                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
+                    break
+        full_ids = input_ids + generated_ids
+        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
+        if "<|assistant|>" in full_text:
+            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
+        else:
+            response = full_text.split("<|endoftext|>")[0].strip()
+        return response if response else "No response generated."
+
+    def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        logits, _ = self.model(input_tensor)
+        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+
+    def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
+        logits, _ = self.model(input_tensor)
+        logits = logits[:, -1, :] / temperature
+        probs = F.softmax(logits, dim=-1)
+        return torch.multinomial(probs, num_samples=1)
+
+    def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
+        logits, _ = self.model(input_tensor)
+        logits = logits[:, -1, :] / temperature
+        top_k_logits, top_k_indices = torch.topk(logits, top_k)
+        probs = F.softmax(top_k_logits, dim=-1)
+        next_token_idx = torch.multinomial(probs, num_samples=1)
+        return top_k_indices.gather(-1, next_token_idx)
+
+    def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
+        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
+
+def count_parameters(model: nn.Module) -> Tuple[int, int]:
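+    """Return (total, trainable) parameter counts for the given model."""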
      total_params = sum(p.numel() for p in model.parameters())
      trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
      return total_params, trainable_params
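+
+# Usage sketch: build a model with default hyperparameters and report its size.
+#     model = MixtureOfRecursions(vocab_size=DEFAULT_VOCAB_SIZE)
+#     total, trainable = count_parameters(model)
+#     print(f"{total:,} parameters ({trainable:,} trainable)")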