# Author: Girinath11
# Commit 9c59aeb (verified): "Update model_slm.py"
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Union, List
# ============================================================================
# TRANSFORMERS COMPATIBILITY
# ============================================================================
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
class MixtureOfRecursionsConfig(PretrainedConfig):
    """Hyper-parameters for the MixtureOfRecursions model.

    The native names (``d_model``, ``n_layers``, ``n_heads``,
    ``dim_feedforward``, ``max_seq_len``) are what the model reads when it is
    built. The Hugging Face standard names (``hidden_size``,
    ``num_hidden_layers``, ``num_attention_heads``, ``intermediate_size``,
    ``max_position_embeddings``) are stored as well. Each one mirrors its
    native counterpart unless an explicit truthy value is supplied.
    Any remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "mixture_of_recursions"

    def __init__(
        self,
        vocab_size=31985,
        d_model=384,
        n_layers=12,
        n_heads=6,
        max_steps=4,
        dim_feedforward=2048,
        dropout=0.1,
        max_seq_len=128,
        router_type="adaptive",
        padding_idx=0,
        pos_encoding="learned",
        hidden_size=None,
        num_hidden_layers=None,
        num_attention_heads=None,
        intermediate_size=None,
        max_position_embeddings=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        native_fields = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "max_steps": max_steps,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            "max_seq_len": max_seq_len,
            "router_type": router_type,
            "padding_idx": padding_idx,
            "pos_encoding": pos_encoding,
        }
        for field_name, field_value in native_fields.items():
            setattr(self, field_name, field_value)

        # HF-standard aliases: a truthy explicit value wins; otherwise (None or
        # any other falsy value) the alias mirrors the native field.
        alias_fields = {
            "hidden_size": (hidden_size, d_model),
            "num_hidden_layers": (num_hidden_layers, n_layers),
            "num_attention_heads": (num_attention_heads, n_heads),
            "intermediate_size": (intermediate_size, dim_feedforward),
            "max_position_embeddings": (max_position_embeddings, max_seq_len),
        }
        for field_name, (explicit, fallback) in alias_fields.items():
            setattr(self, field_name, explicit if explicit else fallback)
# ============================================================================
# EMBEDDINGS MODULE
# ============================================================================
# Base of the geometric frequency progression used by the sinusoidal and
# rotary positional encodings below.
DEFAULT_BASE = 1e4
# NOTE(review): the two constants below are not referenced by any code in
# this file; confirm whether an external module still uses them.
DEFAULT_CUTOFFS = [2_000, 10_000]
DEFAULT_DIV_VAL = 4.
class PositionalEncoding(nn.Module):
    """Sinusoidal (fixed, non-learned) positional encoding.

    Precomputes a buffer ``pe`` of shape ``(1, max_seq_len, d_model)``:
    even channels hold ``sin(pos * w_i)`` and odd channels hold
    ``cos(pos * w_i)``, with ``w_i = DEFAULT_BASE ** (-2i / d_model)``.
    ``forward`` adds the first ``seq_len`` rows to its input and applies
    dropout.
    """

    def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # ceil(d_model / 2) frequencies, one per sine channel.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 cosine channels, one fewer than div_term when
        # d_model is odd, so trim div_term to match. The frequency must be
        # scaled by the position exactly as in the sine half.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if the last dimension is not ``d_model`` or ``seq_len``
                exceeds the precomputed table length.
        """
        batch_size, seq_len, d_model = x.size()
        if d_model != self.d_model:
            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
        max_seq_len = self.pe.size(1)
        if seq_len > max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {max_seq_len}")
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class LearnedPositionalEmbedding(nn.Module):
    """Trainable absolute-position table added to token embeddings.

    Position ``p`` is looked up as row ``p`` of an
    ``nn.Embedding(max_seq_len, d_model)`` whose weights are initialised from
    N(0, 0.02). The sum of input and position rows is passed through dropout.
    """

    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pos_table[:seq_len])`` for ``x`` of shape (batch, seq, d_model).

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len`` or the
                feature size differs from ``d_model``.
        """
        n_batch, n_pos, width = x.size()
        if n_pos > self.max_seq_len:
            raise ValueError(f"Sequence length {n_pos} exceeds maximum {self.max_seq_len}")
        if width != self.d_model:
            raise ValueError(f"Input dimension {width} does not match d_model {self.d_model}")
        index = torch.arange(n_pos, device=x.device)
        position_rows = self.pos_embedding(index.unsqueeze(0).expand(n_batch, -1))
        return self.dropout(x + position_rows)
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE) for transformer models.

    Rotates channel pairs ``(x[..., i], x[..., i + head_dim // 2])`` of the
    query and key tensors by the angle ``position * inv_freq[i]`` (the
    "rotate-half" layout). Because q and k are rotated by the same per-position
    angles, their dot product depends on the difference of their positions.

    Args:
        d_model: Size used to derive the frequency table ``inv_freq``
            (``ceil(d_model / 2)`` entries). ``forward`` uses only the first
            ``head_dim // 2`` of them.
        max_seq_len: Stored for reference; ``forward`` does not enforce it.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.base = base
        # NOTE(review): frequencies are derived from d_model rather than the
        # per-head dimension (standard RoPE uses head_dim). Kept as-is so the
        # persisted ``inv_freq`` buffer keeps its shape.
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)
        self._seq_len_cached = 0
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """(Re)build the cos/sin tables when they are too short or on the wrong device/dtype."""
        cache = self._cos_cached
        stale = (
            cache is None
            or seq_len > self._seq_len_cached
            or cache.device != device
            or cache.dtype != dtype
        )
        if not stale:
            return
        # Never shrink the cache when rebuilding only for a device/dtype change.
        length = max(seq_len, self._seq_len_cached)
        t = torch.arange(length, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
        self._cos_cached = freqs.cos().to(dtype)
        self._sin_cached = freqs.sin().to(dtype)
        self._seq_len_cached = length

    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate the (first half, second half) channel pairs of ``x`` by the given angles."""
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        start_pos: int = 0,
        seq_dim: int = -2,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to ``q`` and ``k`` and return them in their input shape.

        Args:
            q, k: Tensors whose last axis is ``head_dim`` (must be even) and
                whose ``seq_dim`` axis indexes positions. Both must share that
                sequence length. ``MultiHeadAttention`` in this file passes
                ``(batch, n_heads, seq_len, head_dim)``, which matches the
                default ``seq_dim=-2``.
            start_pos: Absolute position of the first element along ``seq_dim``.
            seq_dim: Axis of ``q``/``k`` that indexes sequence positions.

        Raises:
            ValueError: if ``seq_dim`` is the last axis, ``head_dim`` is odd,
                or ``head_dim // 2`` exceeds the available frequencies.
        """
        ndim = q.dim()
        seq_axis = seq_dim % ndim
        if seq_axis == ndim - 1:
            raise ValueError("seq_dim must not be the last (head_dim) axis")
        seq_len = q.shape[seq_axis]
        head_dim = q.shape[-1]
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim ({head_dim}) must be even for rotary embeddings")
        half = head_dim // 2
        if half > self.inv_freq.numel():
            raise ValueError(
                f"head_dim ({head_dim}) needs {half} frequencies but only "
                f"{self.inv_freq.numel()} are available"
            )
        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
        # Shape cos/sin so they broadcast along the sequence axis and the
        # channel axis only (size 1 everywhere else, e.g. batch and heads).
        shape = [1] * ndim
        shape[seq_axis] = seq_len
        shape[-1] = half
        cos = self._cos_cached[start_pos:start_pos + seq_len, :half].reshape(shape)
        sin = self._sin_cached[start_pos:start_pos + seq_len, :half].reshape(shape)
        return self._rotate_half(q, cos, sin), self._rotate_half(k, cos, sin)
class TechEmbeddingLayer(nn.Module):
    """Token embedding combined with positional information, LayerNorm and dropout.

    ``pos_encoding`` (case-insensitive) selects the positional scheme:

    * ``"sinusoidal"`` / ``"learned"``: the positional module is added to the
      token embeddings inside ``forward``.
    * ``"rope"``: nothing is added here. The rotary module is exposed through
      ``get_positional_encoding`` so attention layers can apply it to q/k.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pos_encoding: str = "learned",
        layer_norm: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        # Normalise once and dispatch on the normalised value, so that e.g.
        # "Learned" or "RoPE" is accepted here and handled the same way in
        # forward() / get_positional_encoding().
        self.pos_encoding_type = pos_encoding.lower()
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        if self.pos_encoding_type == "sinusoidal":
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        elif self.pos_encoding_type == "learned":
            self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
        elif self.pos_encoding_type == "rope":
            self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
        else:
            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
        self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
        # NOTE(review): the sinusoidal/learned positional modules already apply
        # their own dropout, so those paths get dropout twice (before and after
        # LayerNorm). Left unchanged to preserve training behaviour.
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise token embeddings from N(0, 0.02) and zero the padding row."""
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.token_embedding.weight[self.padding_idx].zero_()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and return ``(*input_ids.shape, d_model)`` features.

        Raises:
            ValueError: if any id is negative or ``>= vocab_size``.
        """
        # One combined check keeps a single host sync on the common (valid) path.
        out_of_range = (input_ids < 0) | (input_ids >= self.vocab_size)
        if out_of_range.any():
            if (input_ids >= self.vocab_size).any():
                raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
            raise ValueError("Input IDs contain negative values")
        embeddings = self.token_embedding(input_ids)
        if self.pos_encoding_type != "rope":
            embeddings = self.pos_encoding(embeddings)
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

    def get_positional_encoding(self) -> Optional[nn.Module]:
        """Return the rotary module for attention layers, or None for additive schemes."""
        return self.pos_encoding if self.pos_encoding_type == "rope" else None
def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
    """Return a bool tensor shaped like ``input_ids`` that is True at padding tokens."""
    return torch.eq(input_ids, padding_idx)
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Return a ``(seq_len, seq_len)`` bool mask that is True where key index > query index.

    Row ``i`` (query) is True at every column ``j > i``, i.e. at the
    positions a causal model must not attend to.
    """
    idx = torch.arange(seq_len, device=device)
    return idx.unsqueeze(0) > idx.unsqueeze(1)
# ============================================================================
# MODEL CONSTANTS
# ============================================================================
# Architecture defaults used by main() and by the component constructors below.
DEFAULT_D_MODEL: int = 512
DEFAULT_N_HEADS: int = 8
DEFAULT_N_LAYERS: int = 6
DEFAULT_MAX_STEPS: int = 4
DEFAULT_DIM_FEEDFORWARD: int = 2048
# Regularisation / sequence defaults.
DEFAULT_DROPOUT: float = 0.1
DEFAULT_MAX_SEQ_LEN: int = 512
# Tokenisation / routing defaults.
DEFAULT_PADDING_IDX: int = 0
DEFAULT_ROUTER_TYPE: str = "adaptive"
DEFAULT_VOCAB_SIZE: int = 10000
# ============================================================================
# MODEL COMPONENTS
# ============================================================================
class MultiHeadAttention(nn.Module):
    """Standard scaled dot-product multi-head attention.

    ``mask`` is a boolean tensor that is True at positions that must NOT be
    attended to. It may be ``(q_len, k_len)``, ``(batch, q_len, k_len)`` or
    already 4-D ``(batch, 1 or n_heads, q_len, k_len)``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform projection weights, zero biases where present."""
        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
            nn.init.xavier_uniform_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        pos_encoding: Optional[nn.Module] = None
    ) -> torch.Tensor:
        """Attend ``query`` (batch, q_len, d_model) over ``key``/``value`` (batch, k_len, d_model).

        ``pos_encoding``, when given, is called as ``pos_encoding(Q, K)`` with
        tensors of shape ``(batch, n_heads, len, d_k)`` and must return them
        in the same shapes. Returns ``(batch, q_len, d_model)``.
        """
        batch_size, q_len, _ = query.size()
        k_len = key.size(1)
        Q = self.w_q(query).view(batch_size, q_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, k_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, k_len, self.n_heads, self.d_k).transpose(1, 2)
        if pos_encoding is not None:
            Q, K = pos_encoding(Q, K)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, 1, q_len, k_len): broadcast over heads
            # Use the most negative finite value rather than -inf: a query whose
            # keys are ALL masked (e.g. a left-padded position 0 under a causal
            # mask) would otherwise give softmax(all -inf) = NaN. That NaN then
            # spreads to every position through 0 * NaN in the V matmul.
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=-1)
        if mask is not None:
            # Masked keys get exactly zero weight; fully-masked rows become all-zero.
            attention_weights = attention_weights.masked_fill(mask, 0.0)
        attention_weights = self.dropout(attention_weights)
        attended = torch.matmul(attention_weights, V)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_o(attended)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear (back to d_model)."""

    def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)
        # Xavier-uniform weights and zero biases, first layer then second.
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP independently at every position of ``x``."""
        hidden = self.dropout(F.gelu(self.linear1(x)))
        return self.linear2(hidden)
class RecursionRouter(nn.Module):
    """Decide how many recursion steps a RecursiveTransformerLayer should run.

    * ``"fixed"``: always returns the Python int ``max_steps``.
    * ``"adaptive"``: mean-pools the sequence, classifies it into one of
      ``max_steps + 1`` classes and returns the argmax per sample as a
      LongTensor of step counts in ``[0, max_steps]``.

    NOTE(review): argmax is not differentiable, so as far as this file shows
    the complexity classifier receives no gradient from the model loss.
    Confirm whether it is trained elsewhere.
    """

    def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
        super().__init__()
        self.max_steps = max_steps
        self.router_type = router_type.lower()
        if self.router_type == "fixed":
            self.register_buffer('fixed_steps', torch.tensor(max_steps, dtype=torch.long))
        elif self.router_type == "adaptive":
            bottleneck = d_model // 4
            self.complexity_classifier = nn.Sequential(
                nn.Linear(d_model, bottleneck),
                nn.GELU(),
                nn.Dropout(DEFAULT_DROPOUT),
                nn.Linear(bottleneck, max_steps + 1),
                nn.Softmax(dim=-1),
            )
        else:
            raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
        """Return the step count(s) for ``x``; mean-pooling is over dim 1."""
        if self.router_type != "adaptive":
            return self.fixed_steps.item()
        pooled = x.mean(dim=1)
        return self.complexity_classifier(pooled).argmax(dim=-1)
class RecursiveTransformerLayer(nn.Module):
    """Transformer layer that can apply its attention + feed-forward step repeatedly.

    A ``RecursionRouter`` decides how many steps to run. It returns either a
    single Python int for the whole batch ("fixed" routing) or a per-sample
    LongTensor of step counts ("adaptive" routing). Step ``s`` feeds
    ``step_projections[s](x)`` into self-attention and then applies post-norm
    residual attention and feed-forward blocks. The attention, feed-forward
    and norm weights are shared across steps.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dim_feedforward: int,
        max_steps: int = DEFAULT_MAX_STEPS,
        dropout: float = DEFAULT_DROPOUT,
        router_type: str = DEFAULT_ROUTER_TYPE
    ):
        super().__init__()
        self.max_steps = max_steps
        self.d_model = d_model
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.router = RecursionRouter(d_model, max_steps, router_type)
        self.step_projections = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(max_steps)
        ])
        for proj in self.step_projections:
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        pos_encoding: Optional[nn.Module] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, computation_loss)``.

        ``computation_loss`` adds 0.1 per sample per executed step. It is built
        from constants and step counts, so it carries no gradient.
        """
        steps = self.router(x)
        # The router returns a plain int for "fixed" routing and a per-sample
        # tensor for "adaptive" routing.
        if isinstance(steps, int):
            return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)

    def _apply_step(
        self,
        x: torch.Tensor,
        step: int,
        mask: Optional[torch.Tensor],
        pos_encoding: Optional[nn.Module]
    ) -> torch.Tensor:
        """Run one recursion step (projected self-attention then FFN, post-norm residuals)."""
        step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
        attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
        x = self.norm1(x + self.dropout(attended))
        fed_forward = self.feedforward(x)
        return self.norm2(x + self.dropout(fed_forward))

    def _recursive_forward_fixed(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        num_steps: int,
        pos_encoding: Optional[nn.Module]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the same number of steps (capped at ``max_steps``) to every sample."""
        device = x.device
        batch_size = x.shape[0]
        computation_loss = torch.tensor(0.0, device=device)
        for step in range(min(num_steps, self.max_steps)):
            x = self._apply_step(x, step, mask, pos_encoding)
            computation_loss += torch.tensor(0.1, device=device) * batch_size
        return x, computation_loss

    def _recursive_forward_adaptive(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        steps: torch.Tensor,
        pos_encoding: Optional[nn.Module]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``steps[b]`` steps (capped at ``max_steps``) to sample ``b``.

        Samples whose budget is exhausted are left exactly as they were.
        Previously they still went through ``norm1``/``norm2`` on every later
        step, which alters them once the LayerNorm affine weights are trained.
        """
        device = x.device
        computation_loss = torch.tensor(0.0, device=device)
        if steps.numel() == 0:
            return x, computation_loss
        max_batch_steps = int(steps.max().item())
        for step in range(min(max_batch_steps, self.max_steps)):
            # Since steps > step implies steps > s for all s < step, this mask
            # is exactly the set of samples that are still recursing.
            step_mask = steps > step
            if not step_mask.any():
                break
            updated = self._apply_step(x, step, mask, pos_encoding)
            x = torch.where(step_mask.view(-1, 1, 1), updated, x)
            computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
        return x, computation_loss
# ============================================================================
# PRETRAINED MODEL WRAPPER
# ============================================================================
class MixtureOfRecursionsPreTrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` base class for MixtureOfRecursions.

    Links the model to ``MixtureOfRecursionsConfig`` and defines the weight
    initialisation that HF's ``post_init`` applies to every submodule.
    """

    config_class = MixtureOfRecursionsConfig
    base_model_prefix = "model"
    # NOTE(review): no module in this file defines a ``gradient_checkpointing``
    # flag or calls torch.utils.checkpoint; confirm that enabling gradient
    # checkpointing through the HF API actually works before relying on it.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear and Embedding weights are drawn from N(0, d_model ** -0.5),
        biases and embedding padding rows are zeroed, and LayerNorm is reset
        to identity (weight 1, bias 0).
        """
        std = self.config.d_model ** -0.5
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) (or bias=False) leaves these as None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class MixtureOfRecursions(MixtureOfRecursionsPreTrainedModel):
    """Causal language model built from recursive transformer layers.

    Pipeline: ``TechEmbeddingLayer`` -> ``config.n_layers`` x
    ``RecursiveTransformerLayer`` -> final ``LayerNorm`` -> bias-free linear
    LM head over ``config.vocab_size`` tokens.
    """

    def __init__(self, config: MixtureOfRecursionsConfig):
        super().__init__(config)
        self.config = config
        self.d_model = config.d_model
        self.vocab_size = config.vocab_size
        self.padding_idx = config.padding_idx
        self.embeddings = TechEmbeddingLayer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            padding_idx=config.padding_idx,
            pos_encoding=config.pos_encoding
        )
        self.layers = nn.ModuleList([
            RecursiveTransformerLayer(
                d_model=config.d_model,
                n_heads=config.n_heads,
                dim_feedforward=config.dim_feedforward,
                max_steps=config.max_steps,
                dropout=config.dropout,
                router_type=config.router_type
            ) for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # HF hook; presumably applies _init_weights to every submodule,
        # replacing the per-module initialisations above — confirm against
        # the installed transformers version.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ):
        """Run the model.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            attention_mask: Optional tensor of the same shape; entries equal
                to 0 mark padding. When omitted, tokens equal to
                ``config.padding_idx`` are treated as padding.
            labels: Optional targets of the same shape. Logits at position
                ``t`` are scored against ``labels[:, t + 1]``; -100 entries
                are ignored (``nn.CrossEntropyLoss`` default).
            return_dict: Return a ``CausalLMOutput`` when True, otherwise a
                tuple ``(loss, logits)`` or ``(logits,)``.
        """
        batch_size, seq_len = input_ids.shape
        # Boolean "do not attend" mask: padded keys OR future positions.
        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
        causal_mask = create_causal_mask(seq_len, input_ids.device)
        combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
        x = self.embeddings(input_ids)
        pos_encoding = self.embeddings.get_positional_encoding()
        total_computation_loss = torch.tensor(0.0, device=x.device)
        for layer in self.layers:
            x, comp_loss = layer(x, combined_mask, pos_encoding)
            total_computation_loss += comp_loss
        x = self.final_norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
            # NOTE(review): the computation penalty is built from step counts
            # and constants, so it shifts the reported loss but adds no gradient.
            loss += 0.01 * total_computation_loss
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        from transformers.modeling_outputs import CausalLMOutput
        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def generate_step(
        self,
        input_ids: torch.Tensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None
    ) -> torch.Tensor:
        """Sample one next token per sequence; returns ``(batch, 1)`` token ids.

        Only the last ``config.max_seq_len`` tokens are fed to the model, so
        long prompts do not exceed the positional table. The module's
        train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``temperature <= 0`` or ``top_k < 1``.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        max_len = self.config.max_seq_len
        if input_ids.size(1) > max_len:
            input_ids = input_ids[:, -max_len:]
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(input_ids, return_dict=True).logits
                last_logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    k = min(top_k, last_logits.size(-1))
                    kth_best = torch.topk(last_logits, k)[0][..., -1, None]
                    last_logits = last_logits.masked_fill(last_logits < kth_best, float('-inf'))
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses top_p is kept; always keep the best.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
                probs = F.softmax(last_logits, dim=-1)
                return torch.multinomial(probs, num_samples=1)
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)
# Record MixtureOfRecursions as the AutoModelForCausalLM implementation so
# that saved checkpoints can be loaded through the Auto API with custom code.
MixtureOfRecursions.register_for_auto_class(auto_class="AutoModelForCausalLM")
def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable
def main():
    """Smoke-test MixtureOfRecursions: build a model, run one forward pass and one sampling step."""
    torch.manual_seed(0)  # reproducible smoke test
    print("Initializing MixtureOfRecursions model...")
    config = MixtureOfRecursionsConfig(
        vocab_size=DEFAULT_VOCAB_SIZE,
        d_model=DEFAULT_D_MODEL,
        n_layers=DEFAULT_N_LAYERS,
        n_heads=DEFAULT_N_HEADS,
        max_steps=DEFAULT_MAX_STEPS,
        dim_feedforward=DEFAULT_DIM_FEEDFORWARD,
        dropout=DEFAULT_DROPOUT,
        router_type=DEFAULT_ROUTER_TYPE
    )
    model = MixtureOfRecursions(config)
    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print("\nTesting forward pass...")
    # seq_len must not exceed config.max_seq_len (128 by default) for learned positions.
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0
    # Inference-only check: disable dropout and don't build an autograd graph
    # for the whole batch.
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask, return_dict=True)
    logits = outputs.logits
    assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
    print(f"Input shape: {input_ids.shape}")
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
    print("\nTesting generation step...")
    next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
    print(f"Generated next token: {next_token.item()}")
    print("\nModel test completed successfully!")


if __name__ == "__main__":
    main()