Antreas committed on
Commit
0be3e32
·
verified ·
1 Parent(s): 418668b

Enable AutoModel loading

Browse files
Files changed (1) hide show
  1. transformer.py +104 -0
transformer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Variant A: Tiny Transformer — 1-2 layer standard transformer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .configuration_ogma import OgmaConfig
12
+ from .embeddings import RotaryPositionalEncoding, apply_rope
13
+
14
+ __all__ = ["TransformerVariant"]
15
+
16
+
17
class SwiGLU(nn.Module):
    """Gated feed-forward block using a SiLU-activated gate (SwiGLU).

    Two bias-free projections map the input from ``d_model`` to ``d_hidden``:
    ``w1`` feeds the SiLU gate and ``w2`` provides the values being gated.
    Their element-wise product is projected back to ``d_model`` by ``w3``
    and then passed through dropout.

    Args:
        d_model: Size of the last dimension of the input and output.
        d_hidden: Size of the gated hidden representation.
        dropout: Dropout probability applied to the output projection.
    """

    def __init__(self, d_model: int, d_hidden: int, dropout: float = 0.0) -> None:
        super().__init__()
        # Creation order (w1, w2, w3) is kept so parameter initialisation
        # consumes the RNG in the same sequence and state-dict keys match.
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)  # value branch
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)  # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(w3(silu(w1(x)) * w2(x)))``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        projected: torch.Tensor = self.w3(gate * value)
        return self.dropout(projected)
32
+
33
+
34
class TransformerLayer(nn.Module):
    """Single pre-norm transformer encoder layer with RoPE and SwiGLU.

    The layer applies multi-head self-attention followed by a SwiGLU
    feed-forward network. Each sub-layer is preceded by its own LayerNorm
    and wrapped in a residual connection.

    Args:
        config: Model configuration. The layer reads ``n_heads``, ``d_head``,
            ``d_model``, ``ffn_hidden`` and ``dropout`` from it. The head
            reshape in ``forward`` requires ``d_model == n_heads * d_head``.
        rope: Rotary positional encoding module. It is called on the
            normalised hidden states and must return a ``(cos, sin)`` pair
            accepted by ``apply_rope``.
    """

    def __init__(self, config: OgmaConfig, rope: RotaryPositionalEncoding) -> None:
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.rope = rope

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = SwiGLU(config.d_model, config.ffn_hidden, config.dropout)
        self.attn_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run self-attention and the feed-forward network on ``x``.

        Args:
            x: Hidden states with three dimensions ``(B, S, D)``.
            attention_mask: Optional key-padding mask, assumed to be of shape
                ``(B, S)``. Key positions whose mask value equals 0 are
                excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, S, D = x.shape

        # Pre-norm attention
        h = self.norm1(x)
        # (B, S, D) -> (B, n_heads, S, d_head)
        q = self.q_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)

        # Rotary embeddings are applied to queries and keys only.
        cos, sin = self.rope(h)
        q, k = apply_rope(q, k, cos, sin)

        scale = 1.0 / math.sqrt(self.d_head)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # attention_mask: (B, S) -> (B, 1, 1, S) for broadcasting
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Fill with the most negative finite value rather than -inf.
            # If every key of a sequence is masked, a row of only -inf makes
            # softmax return NaN, which then reaches the outputs and the
            # gradients. With a finite fill such rows become a uniform
            # distribution, while masked keys in partially masked rows still
            # receive exactly zero weight after softmax.
            attn = attn.masked_fill(mask == 0, torch.finfo(attn.dtype).min)

        attn = self.attn_dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v)
        # (B, n_heads, S, d_head) -> (B, S, D)
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        x = x + self.o_proj(out)

        # Pre-norm FFN
        x = x + self.ffn(self.norm2(x))
        return x
85
+
86
+
87
class TransformerVariant(nn.Module):
    """Variant A: a stack of ``config.n_layers`` transformer encoder layers.

    One ``RotaryPositionalEncoding`` instance is created here and handed to
    every layer, so all layers share the same rotary module.

    Args:
        config: Model configuration. This class reads ``d_head``,
            ``max_seq_len`` and ``n_layers``; the remaining fields are
            consumed by ``TransformerLayer``.
    """

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__()
        # NOTE(review): the "+ 1" presumably leaves room for one extra
        # position beyond max_seq_len (e.g. a prepended token) -- confirm.
        shared_rope = RotaryPositionalEncoding(
            config.d_head, config.max_seq_len + 1
        )
        encoder_layers = [
            TransformerLayer(config, shared_rope) for _ in range(config.n_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Pass ``x`` through every layer in order and return the result.

        ``attention_mask`` is forwarded unchanged to each layer.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, attention_mask)
        return hidden