import math
from typing import Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel


class MixtureOfRecursionsConfig(PretrainedConfig):
    """Configuration class for MixtureOfRecursions model."""

    model_type = "mixture_of_recursions"

    def __init__(
        self,
        vocab_size=31985,
        d_model=384,
        n_layers=12,
        n_heads=6,
        max_steps=4,
        dim_feedforward=2048,
        dropout=0.1,
        max_seq_len=128,
        router_type="adaptive",
        padding_idx=0,
        pos_encoding="learned",
        hidden_size=None,
        num_hidden_layers=None,
        num_attention_heads=None,
        intermediate_size=None,
        max_position_embeddings=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_steps = max_steps
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.router_type = router_type
        self.padding_idx = padding_idx
        self.pos_encoding = pos_encoding
        self.hidden_size = hidden_size or d_model
        self.num_hidden_layers = num_hidden_layers or n_layers
        self.num_attention_heads = num_attention_heads or n_heads
        self.intermediate_size = intermediate_size or dim_feedforward
        self.max_position_embeddings = max_position_embeddings or max_seq_len


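# Illustrative sketch, not part of the original module and never called at
# import time: it shows how the Hugging Face-style aliases (hidden_size,
# num_hidden_layers, num_attention_heads, intermediate_size) fall back to the
# native hyperparameters when left as None. The helper name
# `_demo_config_aliases` is an assumption made only for this example.
def _demo_config_aliases() -> MixtureOfRecursionsConfig:
    config = MixtureOfRecursionsConfig(d_model=256, n_layers=4, n_heads=4)
    assert config.hidden_size == config.d_model == 256
    assert config.num_hidden_layers == config.n_layers == 4
    assert config.num_attention_heads == config.n_heads == 4
    return config

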
DEFAULT_BASE = 10000.0
DEFAULT_CUTOFFS = [2000, 10000]
DEFAULT_DIV_VAL = 4.0


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models."""

    def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model the cosine channels have one fewer column than div_term.
        pe[:, 1::2] = torch.cos(position * (div_term[:-1] if d_model % 2 == 1 else div_term))
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.size()
        if d_model != self.d_model:
            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embeddings for transformer models."""

    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.size()
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
        if d_model != self.d_model:
            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_embedding(positions)
        x = x + pos_emb
        return self.dropout(x)


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE) for transformer models."""

    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        # Rebuild the cache when it is missing, too short, or on the wrong device.
        if (
            self._cos_cached is None
            or seq_len > self._seq_len_cached
            or self._cos_cached.device != device
        ):
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            self._cos_cached = freqs.cos().to(dtype)
            self._sin_cached = freqs.sin().to(dtype)

    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        # q and k arrive head-split from MultiHeadAttention: (batch, n_heads, seq_len, head_dim).
        batch_size, num_heads, seq_len, head_dim = q.shape
        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
        cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2][None, None, :, :]
        sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2][None, None, :, :]
        q_rot = self._rotate_half(q, cos, sin)
        k_rot = self._rotate_half(k, cos, sin)
        return q_rot, k_rot


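# Illustrative sketch, not part of the original module and never called at
# import time: RotaryPositionalEmbedding is applied to queries/keys after the
# head split in MultiHeadAttention, i.e. on tensors shaped
# (batch, n_heads, seq_len, head_dim). The toy sizes below are assumptions.
def _demo_rope_shapes() -> None:
    batch, heads, seq, head_dim = 2, 4, 16, 32
    rope = RotaryPositionalEmbedding(d_model=heads * head_dim, max_seq_len=64)
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    q_rot, k_rot = rope(q, k)
    # The rotation preserves shapes; only the relative phase of q/k changes.
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

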
class TechEmbeddingLayer(nn.Module):
    """Comprehensive embedding layer with token and positional embeddings."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pos_encoding: str = "learned",
        layer_norm: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.pos_encoding_type = pos_encoding.lower()
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)

        if self.pos_encoding_type == "sinusoidal":
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        elif self.pos_encoding_type == "learned":
            self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
        elif self.pos_encoding_type == "rope":
            self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
        else:
            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")

        self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self) -> None:
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        if self.padding_idx is not None:
            nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        if (input_ids >= self.vocab_size).any():
            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
        embeddings = self.token_embedding(input_ids)
        # RoPE is applied inside attention on Q/K, not on the token embeddings.
        if self.pos_encoding_type != "rope":
            embeddings = self.pos_encoding(embeddings)
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

    def get_positional_encoding(self) -> Optional[nn.Module]:
        return self.pos_encoding if self.pos_encoding_type == "rope" else None


def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
    """Boolean mask that is True at padding positions."""
    return input_ids == padding_idx


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Boolean upper-triangular mask that is True at future (disallowed) positions."""
    return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()


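# Illustrative sketch, not part of the original module and never called at
# import time: it shows how the two helpers above are combined inside
# MixtureOfRecursions.forward into one boolean mask where True means
# "do not attend". The toy input_ids are an assumption for this example.
def _demo_combined_mask() -> torch.Tensor:
    input_ids = torch.tensor([[5, 8, 3, 0, 0]])  # last two positions are padding
    batch_size, seq_len = input_ids.shape
    padding_mask = create_padding_mask(input_ids)                # (batch, seq)
    causal_mask = create_causal_mask(seq_len, input_ids.device)  # (seq, seq)
    combined = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
    assert combined.shape == (batch_size, seq_len, seq_len)
    return combined

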
DEFAULT_D_MODEL = 512
DEFAULT_N_HEADS = 8
DEFAULT_N_LAYERS = 6
DEFAULT_MAX_STEPS = 4
DEFAULT_DIM_FEEDFORWARD = 2048
DEFAULT_DROPOUT = 0.1
DEFAULT_MAX_SEQ_LEN = 512
DEFAULT_PADDING_IDX = 0
DEFAULT_ROUTER_TYPE = "adaptive"
DEFAULT_VOCAB_SIZE = 10000


class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism optimized for technical content."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self) -> None:
        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        pos_encoding: Optional[nn.Module] = None
    ) -> torch.Tensor:
        batch_size, seq_len, _ = query.size()
        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        if pos_encoding is not None:
            Q, K = pos_encoding(Q, K)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
            scores = scores.masked_fill(mask, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attended = torch.matmul(attention_weights, V)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attended)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""

    def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.linear1(x))
        x = self.dropout(x)
        return self.linear2(x)


class RecursionRouter(nn.Module):
    """Router to determine recursion steps for technical problem processing."""

    def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
        super().__init__()
        self.max_steps = max_steps
        self.router_type = router_type.lower()

        if self.router_type == "adaptive":
            self.complexity_classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 4),
                nn.GELU(),
                nn.Dropout(DEFAULT_DROPOUT),
                nn.Linear(d_model // 4, max_steps + 1),
                nn.Softmax(dim=-1)
            )
        elif self.router_type == "fixed":
            self.register_buffer('fixed_steps', torch.tensor(max_steps, dtype=torch.long))
        else:
            raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
        if self.router_type == "adaptive":
            seq_repr = x.mean(dim=1)
            step_probs = self.complexity_classifier(seq_repr)
            return torch.argmax(step_probs, dim=-1)
        return self.fixed_steps.item()


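# Illustrative sketch, not part of the original module and never called at
# import time: an adaptive router mean-pools each sequence and predicts an
# integer number of recursion steps in [0, max_steps] per sequence. The toy
# dimensions are assumptions for this example.
def _demo_adaptive_router() -> torch.Tensor:
    router = RecursionRouter(d_model=64, max_steps=4, router_type="adaptive")
    x = torch.randn(3, 10, 64)
    steps = router(x)  # LongTensor of shape (3,), values in {0, ..., 4}
    assert steps.shape == (3,) and int(steps.max()) <= 4
    return steps

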
class RecursiveTransformerLayer(nn.Module):
    """Transformer layer with recursive computation capability."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dim_feedforward: int,
        max_steps: int = DEFAULT_MAX_STEPS,
        dropout: float = DEFAULT_DROPOUT,
        router_type: str = DEFAULT_ROUTER_TYPE
    ):
        super().__init__()
        self.max_steps = max_steps
        self.d_model = d_model
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.router = RecursionRouter(d_model, max_steps, router_type)
        self.step_projections = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(max_steps)
        ])
        for proj in self.step_projections:
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        pos_encoding: Optional[nn.Module] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        steps = self.router(x)
        # A fixed router returns a plain int; an adaptive router returns a per-sequence tensor.
        if not torch.is_tensor(steps):
            return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)

    def _recursive_forward_fixed(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        num_steps: int,
        pos_encoding: Optional[nn.Module]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = x.device
        batch_size = x.shape[0]
        computation_loss = torch.tensor(0.0, device=device)
        for step in range(min(num_steps, self.max_steps)):
            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            x = self.norm1(x + self.dropout(attended))
            fed_forward = self.feedforward(x)
            x = self.norm2(x + self.dropout(fed_forward))
            # Flat per-step computation penalty, scaled by the batch size.
            computation_loss = computation_loss + 0.1 * batch_size
        return x, computation_loss

    def _recursive_forward_adaptive(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        steps: torch.Tensor,
        pos_encoding: Optional[nn.Module]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, d_model = x.shape
        device = x.device
        max_batch_steps = int(steps.max().item())
        computation_loss = torch.tensor(0.0, device=device)
        active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
        for step in range(min(max_batch_steps, self.max_steps)):
            step_mask = (steps > step) & active_batches
            if not step_mask.any():
                break
            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            # Zero the residual update for sequences that have already finished;
            # they still pass through the LayerNorms below.
            attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
            x = self.norm1(x + self.dropout(attended))
            fed_forward = self.feedforward(x)
            fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
            x = self.norm2(x + self.dropout(fed_forward))
            # Penalize computation in proportion to the number of still-active sequences.
            computation_loss = computation_loss + 0.1 * step_mask.sum()
            active_batches &= (steps > step)
        return x, computation_loss


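# Illustrative sketch, not part of the original module and never called at
# import time: a single RecursiveTransformerLayer call returns the updated
# hidden states plus a scalar computation-cost term that the full model adds
# to its training loss. The toy sizes and the "fixed" router are assumptions.
def _demo_recursive_layer() -> None:
    batch_size, seq_len, d_model = 3, 8, 64
    layer = RecursiveTransformerLayer(
        d_model=d_model,
        n_heads=4,
        dim_feedforward=128,
        max_steps=2,
        router_type="fixed",
    )
    x = torch.randn(batch_size, seq_len, d_model)
    mask = create_causal_mask(seq_len, x.device).unsqueeze(0).expand(batch_size, seq_len, seq_len)
    out, comp_loss = layer(x, mask)
    assert out.shape == x.shape and comp_loss.ndim == 0

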
class MixtureOfRecursionsPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for MixtureOfRecursions."""

    config_class = MixtureOfRecursionsConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class MixtureOfRecursions(MixtureOfRecursionsPreTrainedModel):
    """Transformer model with mixture of recursive layers for technical content."""

    def __init__(self, config: MixtureOfRecursionsConfig):
        super().__init__(config)
        self.config = config
        self.d_model = config.d_model
        self.vocab_size = config.vocab_size
        self.padding_idx = config.padding_idx

        self.embeddings = TechEmbeddingLayer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            padding_idx=config.padding_idx,
            pos_encoding=config.pos_encoding
        )

        self.layers = nn.ModuleList([
            RecursiveTransformerLayer(
                d_model=config.d_model,
                n_heads=config.n_heads,
                dim_feedforward=config.dim_feedforward,
                max_steps=config.max_steps,
                dropout=config.dropout,
                router_type=config.router_type
            ) for _ in range(config.n_layers)
        ])

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ):
        batch_size, seq_len = input_ids.shape

        # The combined mask is True wherever attention must be blocked: padding keys or future tokens.
        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
        causal_mask = create_causal_mask(seq_len, input_ids.device)
        combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)

        x = self.embeddings(input_ids)
        pos_encoding = self.embeddings.get_positional_encoding()

        total_computation_loss = torch.tensor(0.0, device=x.device)
        for layer in self.layers:
            x, comp_loss = layer(x, combined_mask, pos_encoding)
            total_computation_loss += comp_loss

        x = self.final_norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens at position t predict position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
            loss += 0.01 * total_computation_loss

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def generate_step(
        self,
        input_ids: torch.Tensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None
    ) -> torch.Tensor:
        self.eval()
        with torch.no_grad():
            outputs = self.forward(input_ids, return_dict=True)
            logits = outputs.logits
            last_logits = logits[:, -1, :] / temperature

            if top_k is not None:
                indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))

            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))

            probs = F.softmax(last_logits, dim=-1)
            return torch.multinomial(probs, num_samples=1)


MixtureOfRecursions.register_for_auto_class("AutoModelForCausalLM")


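# Illustrative sketch, not part of the original module and never called at
# import time: the config/model pair round-trips through the standard
# save_pretrained / from_pretrained API. The directory name "mor-checkpoint"
# and the tiny hyperparameters are assumptions for this example.
def _demo_save_and_reload(save_dir: str = "mor-checkpoint") -> "MixtureOfRecursions":
    config = MixtureOfRecursionsConfig(
        vocab_size=1000, d_model=64, n_layers=2, n_heads=2, max_seq_len=32
    )
    model = MixtureOfRecursions(config)
    model.save_pretrained(save_dir)
    return MixtureOfRecursions.from_pretrained(save_dir)

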
def count_parameters(model: nn.Module) -> Tuple[int, int]:
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


def main():
    """Test the MixtureOfRecursions model and its components."""
    print("Initializing MixtureOfRecursions model...")
    config = MixtureOfRecursionsConfig(
        vocab_size=DEFAULT_VOCAB_SIZE,
        d_model=DEFAULT_D_MODEL,
        n_layers=DEFAULT_N_LAYERS,
        n_heads=DEFAULT_N_HEADS,
        max_steps=DEFAULT_MAX_STEPS,
        dim_feedforward=DEFAULT_DIM_FEEDFORWARD,
        dropout=DEFAULT_DROPOUT,
        router_type=DEFAULT_ROUTER_TYPE
    )
    model = MixtureOfRecursions(config)

    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nTesting forward pass...")
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0

    outputs = model(input_ids, attention_mask, return_dict=True)
    logits = outputs.logits

    assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
    print(f"Input shape: {input_ids.shape}")
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")

    print("\nTesting generation step...")
    next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
    print(f"Generated next token: {next_token.item()}")

    print("\nModel test completed successfully!")


if __name__ == "__main__":
    main()