"""
Positional Encoding for Transformer models.
Provides sinusoidal position embeddings that inject sequential order information
into token representations. Required because self-attention is permutation-invariant
and has no inherent notion of token position.
Author: Oliver Perrin
Date: December 2025
"""
import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
    """
    Implements the sinusoidal positional encoding from "Attention Is All You Need".

    Formula:
        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Where:
        pos: position in the sequence (0 to max_len - 1)
        i: dimension-pair index (0 to d_model/2 - 1)

    Args:
        d_model: Dimension of the model embeddings
        max_len: Maximum sequence length to pre-compute
        dropout: Dropout probability applied after adding the positional encoding

    Shape:
        Input: (batch, seq_len, d_model)
        Output: (batch, seq_len, d_model)

    Example:
        >>> pos_enc = PositionalEncoding(d_model=512, max_len=5000)
        >>> x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
        >>> output = pos_enc(x)
        >>> output.shape
        torch.Size([32, 100, 512])
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Positions [0, 1, ..., max_len - 1] as a (max_len, 1) column vector.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies 1 / 10000^(2i/d_model), computed in log-space for numerical stability.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Interleave sin/cos: even dimensions get sin, odd dimensions get cos.
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch

        # Register as a buffer: not a trainable parameter, but saved in the state_dict.
        self.register_buffer("pe", pe)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding to the input embeddings.

        Args:
            x: Input embeddings (batch, seq_len, d_model)

        Returns:
            x with positional encoding added (batch, seq_len, d_model)
        """
        # self.pe holds pre-computed encodings for all max_len positions;
        # slice out the first seq_len positions and add them to x.
        pe: torch.Tensor = self.pe  # type: ignore[assignment]
        x = x + pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings (used by BERT, GPT, etc.).

    Note: T5/FLAN-T5 uses relative position bias instead of absolute positional
    embeddings, so nothing is transferred from a T5 checkpoint; these embeddings
    are trained from scratch.

    Args:
        d_model: Dimension of the model embeddings
        max_len: Maximum sequence length
        dropout: Dropout probability
        padding_idx: Index of the padding token (accepted for padding-aware masking,
            but not used by the current forward pass)
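
    Example (illustrative usage; the output shape matches the input):
        >>> pos_enc = LearnedPositionalEncoding(d_model=512, max_len=1024)
        >>> x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
        >>> pos_enc(x).shape
        torch.Size([32, 100, 512])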
"""
    def __init__(
        self, d_model: int, max_len: int = 1024, dropout: float = 0.1, padding_idx: int = 1
    ):
        super().__init__()
        # Standard learned positional embeddings.
        # Note: T5's relative position bias is NOT transferred - we train these from scratch.
        self.embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input embeddings (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        # Broadcast position indices over the batch dimension.
        positions = positions.unsqueeze(0).expand(x.size(0), -1)
        pos_embeds = self.embeddings(positions)
        return self.dropout(x + pos_embeds)
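

# Minimal smoke test (an illustrative sketch with assumed sizes, not part of the module's API):
# checks that both encoders preserve the input shape and that the sinusoidal buffer matches
# the closed-form formula PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) for one sample entry.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len, d_model = 2, 16, 64
    x = torch.randn(batch, seq_len, d_model)

    sin_enc = PositionalEncoding(d_model=d_model, max_len=128, dropout=0.0)
    learned_enc = LearnedPositionalEncoding(d_model=d_model, max_len=128, dropout=0.0)

    assert sin_enc(x).shape == (batch, seq_len, d_model)
    assert learned_enc(x).shape == (batch, seq_len, d_model)

    # Spot-check one even-dimension entry of the pre-computed buffer against the formula.
    pos, i = 3, 4
    expected = math.sin(pos / (10000 ** (2 * i / d_model)))
    actual = sin_enc.pe[0, pos, 2 * i].item()
    assert abs(expected - actual) < 1e-4
    print("positional encoding smoke test passed")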