gary-boon and Claude Opus 4.5 committed on
Commit 66a46b6 · 1 Parent(s): d0b7e29
Add avg_entropy calculation for attention heads
- Compute normalized attention entropy averaged over query positions
- Normalize by log(k_i) where k_i = number of keys each position can attend to
- Average over latter half of positions for more stable signal
- Return both entropy (last-token) and avg_entropy fields for each head
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
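For reference, the added computation pulled out of the handler as a standalone sketch — the function name and the toy uniform-attention check are illustrative, not part of the commit:

import torch

def avg_normalized_entropy(head_attn: torch.Tensor) -> float:
    # head_attn: [q_len, k_len] causal attention weights for one head
    q_len = head_attn.shape[0]
    # Raw entropy per query position
    token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)
    # Max possible entropy per position under a causal mask: log(i + 1)
    positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
    normalized = token_entropies / (torch.log(positions + 1e-10) + 1e-10)
    # Average over the latter half of positions, where each query has context
    return normalized[q_len // 2:].mean().item()

# Toy check: uniform causal attention should give avg_entropy close to 1.0
attn = torch.tril(torch.ones(8, 8))
attn = attn / attn.sum(dim=-1, keepdim=True)
print(avg_normalized_entropy(attn))  # ≈ 1.0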
- backend/model_service.py +42 -0
backend/model_service.py
CHANGED
@@ -1713,9 +1713,30 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
         max_weight = head_weights.max().item()
         entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()

+        # Normalized attention entropy averaged over latter half of query positions
+        # Normalized by log(k_i) where k_i = number of keys position i can attend to
+        # This produces values in [0,1] with better spread across heads
+        # layer_attn[head_idx] shape: [q_len, k_len]
+        head_attn = layer_attn[head_idx]  # [q_len, k_len]
+        q_len = head_attn.shape[0]
+
+        # Compute raw entropy per query position
+        token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)  # [q_len]
+
+        # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask)
+        # Skip position 0 where log(1) = 0
+        positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
+        max_entropies = torch.log(positions + 1e-10)  # log(k_i), with epsilon for position 0
+        normalized_entropies = token_entropies / (max_entropies + 1e-10)  # [0, 1] range
+
+        # Average over latter half of positions (where there's enough context)
+        start_idx = q_len // 2
+        avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
+
         # Sanitize to prevent NaN/Inf in JSON
         max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
         entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+        avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy

         # Classify pattern
         pattern_type = None
@@ -1759,6 +1780,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
         critical_heads.append({
             "head_idx": head_idx,
             "entropy": entropy,
+            "avg_entropy": avg_entropy,  # Averaged over latter half of query positions
             "max_weight": max_weight,
             "attention_weights": attention_matrix,  # Full attention matrix for spreadsheet
             "q_matrix": q_matrix,  # [seq_len, head_dim]
@@ -2161,8 +2183,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
         max_weight = head_weights.max().item()
         entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()

+        # Normalized attention entropy averaged over latter half of query positions
+        # Normalized by log(k_i) where k_i = number of keys position i can attend to
+        # This produces values in [0,1] with better spread across heads
+        head_attn = layer_attn[head_idx]  # [q_len, k_len]
+        q_len = head_attn.shape[0]
+
+        # Compute raw entropy per query position
+        token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)  # [q_len]
+
+        # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask)
+        positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
+        max_entropies = torch.log(positions + 1e-10)
+        normalized_entropies = token_entropies / (max_entropies + 1e-10)  # [0, 1] range
+
+        # Average over latter half of positions
+        start_idx = q_len // 2
+        avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
+
         max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
         entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+        avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy

         pattern_type = None
         confidence = 0.0
@@ -2195,6 +2236,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
         critical_heads.append({
             "head_idx": head_idx,
             "entropy": entropy,
+            "avg_entropy": avg_entropy,  # Averaged over latter half of query positions
             "max_weight": max_weight,
             "attention_weights": attention_matrix,
             "q_matrix": q_matrix,