"""
Positional Encoding for Transformer models.

Provides sinusoidal position embeddings that inject sequential order information
into token representations. Required because self-attention is permutation-invariant
and has no inherent notion of token position.

Author: Oliver Perrin
Date: December 2025
"""

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """
    Implements the sinusoidal positional encoding from "Attention Is All You Need".

    Formula:
        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Where:
        pos: position in sequence (0 to max_len-1)
        i: dimension-pair index (0 to d_model/2 - 1)

    Args:
        d_model: Dimension of the model embeddings
        max_len: Maximum sequence length to pre-compute
        dropout: Dropout probability to apply after adding positional encoding

    Shape:
        Input: (batch, seq_len, d_model)
        Output: (batch, seq_len, d_model)

    Example:
        >>> pos_enc = PositionalEncoding(d_model=512, max_len=5000)
        >>> x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
        >>> output = pos_enc(x)
        >>> output.shape
        torch.Size([32, 100, 512])
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Positions as a column vector: [0, 1, ..., max_len-1], shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies 1 / 10000^(2i/d_model) for the even dimension indices,
        # computed in log space for numerical stability.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions get sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions get cos
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        # Register as a buffer: saved in state_dict, moved with .to(), but not trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding to input embeddings.

        Args:
            x: Input embeddings (batch, seq_len, d_model)

        Returns:
            x with positional encoding added (batch, seq_len, d_model)
        """
        # self.pe holds pre-computed encodings for every position up to max_len;
        # add the first seq_len of them to x (no gradient flows through the buffer).
        pe: torch.Tensor = self.pe  # type: ignore[assignment]
        x = x + pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)


class LearnedPositionalEncoding(nn.Module):
    """
    Learned positional embeddings (used by BERT, GPT, etc.).

    Note: T5/FLAN-T5 use relative position bias instead of absolute positional embeddings,
    so there is nothing to transfer here; when initializing from T5, these learned
    positional embeddings are trained from scratch.

    Args:
        d_model: Dimension of the model embeddings
        max_len: Maximum sequence length
        dropout: Dropout probability
        padding_idx: Index of the padding token (accepted for API compatibility;
            not used by the current implementation)
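
    Example:
        >>> pos_enc = LearnedPositionalEncoding(d_model=512, max_len=1024)
        >>> x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
        >>> output = pos_enc(x)
        >>> output.shape
        torch.Size([32, 100, 512])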
    """

    def __init__(
        self, d_model: int, max_len: int = 1024, dropout: float = 0.1, padding_idx: int = 1
    ):
        super().__init__()
        # Standard learned positional embeddings.
        # Note: T5's relative position bias is NOT transferred - we train these from scratch.
        self.embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input embeddings (batch, seq_len, d_model)
        """
        seq_len = x.size(1)
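        # Position ids 0..seq_len-1; seq_len must not exceed max_len, or the
        # embedding lookup below will raise an index error.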
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        # Broadcast to batch
        positions = positions.unsqueeze(0).expand(x.size(0), -1)

        pos_embeds = self.embeddings(positions)
        return self.dropout(x + pos_embeds)
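

if __name__ == "__main__":
    # Minimal illustrative sanity check (a sketch, not part of the module's API):
    # both encodings should preserve the (batch, seq_len, d_model) shape of their input.
    batch, seq_len, d_model = 2, 16, 64
    x = torch.randn(batch, seq_len, d_model)

    sinusoidal = PositionalEncoding(d_model=d_model, max_len=128, dropout=0.0)
    learned = LearnedPositionalEncoding(d_model=d_model, max_len=128, dropout=0.0)

    assert sinusoidal(x).shape == (batch, seq_len, d_model)
    assert learned(x).shape == (batch, seq_len, d_model)

    # With dropout disabled, the sinusoidal encoding is deterministic.
    assert torch.allclose(sinusoidal(x), sinusoidal(x))
    print("Positional encoding shape checks passed.")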