gary-boon and Claude Opus 4.5 committed
Commit 66a46b6 · Parent: d0b7e29

Add avg_entropy calculation for attention heads


- Compute normalized attention entropy averaged over query positions
- Normalize by log(k_i) where k_i = number of keys each position can attend to
- Average over latter half of positions for more stable signal
- Return both entropy (last-token) and avg_entropy fields for each head

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
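
For intuition, here is a minimal standalone sketch of the normalization described above (the toy uniform-attention matrix and variable names below are illustrative, not taken from model_service.py):

```python
import torch

# Toy causal attention for one head: each position i attends uniformly
# over the k_i = i + 1 keys it can see, so its entropy is exactly log(k_i).
q_len = 8
head_attn = torch.tril(torch.ones(q_len, q_len))
head_attn = head_attn / head_attn.sum(dim=-1, keepdim=True)  # rows sum to 1

# Raw entropy per query position: -sum_j p_ij * log(p_ij)
token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)

# Divide by the maximum possible entropy log(k_i); the epsilons keep
# position 0 (where log(1) = 0) from dividing by zero.
positions = torch.arange(1, q_len + 1, dtype=head_attn.dtype)
normalized = token_entropies / (torch.log(positions + 1e-10) + 1e-10)

# Average over the latter half of positions, as the commit does.
avg_entropy = normalized[q_len // 2:].mean().item()
print(round(avg_entropy, 4))  # ~1.0: uniform attention is maximally diffuse
```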

Files changed (1)
  1. backend/model_service.py +42 -0
backend/model_service.py CHANGED
@@ -1713,9 +1713,30 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             max_weight = head_weights.max().item()
             entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()
 
+            # Normalized attention entropy averaged over the latter half of query positions
+            # Normalized by log(k_i) where k_i = number of keys position i can attend to
+            # This produces values in [0, 1] with better spread across heads
+            # layer_attn[head_idx] shape: [q_len, k_len]
+            head_attn = layer_attn[head_idx]  # [q_len, k_len]
+            q_len = head_attn.shape[0]
+
+            # Compute raw entropy per query position
+            token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)  # [q_len]
+
+            # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask)
+            # Skip position 0 where log(1) = 0
+            positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
+            max_entropies = torch.log(positions + 1e-10)  # log(k_i), with epsilon for position 0
+            normalized_entropies = token_entropies / (max_entropies + 1e-10)  # [0, 1] range
+
+            # Average over the latter half of positions (where there's enough context)
+            start_idx = q_len // 2
+            avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
+
             # Sanitize to prevent NaN/Inf in JSON
             max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
             entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+            avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
 
             # Classify pattern
             pattern_type = None
@@ -1759,6 +1780,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
             critical_heads.append({
                 "head_idx": head_idx,
                 "entropy": entropy,
+                "avg_entropy": avg_entropy,  # Averaged over the latter half of query positions
                 "max_weight": max_weight,
                 "attention_weights": attention_matrix,  # Full attention matrix for spreadsheet
                 "q_matrix": q_matrix,  # [seq_len, head_dim]
@@ -2161,8 +2183,27 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             max_weight = head_weights.max().item()
             entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()
 
+            # Normalized attention entropy averaged over the latter half of query positions
+            # Normalized by log(k_i) where k_i = number of keys position i can attend to
+            # This produces values in [0, 1] with better spread across heads
+            head_attn = layer_attn[head_idx]  # [q_len, k_len]
+            q_len = head_attn.shape[0]
+
+            # Compute raw entropy per query position
+            token_entropies = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)  # [q_len]
+
+            # Normalize by max possible entropy: log(k_i) where k_i = i + 1 (causal mask)
+            positions = torch.arange(1, q_len + 1, device=head_attn.device, dtype=head_attn.dtype)
+            max_entropies = torch.log(positions + 1e-10)
+            normalized_entropies = token_entropies / (max_entropies + 1e-10)  # [0, 1] range
+
+            # Average over the latter half of positions
+            start_idx = q_len // 2
+            avg_entropy = normalized_entropies[start_idx:].mean().item() if start_idx < q_len else normalized_entropies.mean().item()
+
             max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
             entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy
+            avg_entropy = 0.0 if math.isnan(avg_entropy) or math.isinf(avg_entropy) else avg_entropy
 
             pattern_type = None
             confidence = 0.0
@@ -2195,6 +2236,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             critical_heads.append({
                 "head_idx": head_idx,
                 "entropy": entropy,
+                "avg_entropy": avg_entropy,  # Averaged over the latter half of query positions
                 "max_weight": max_weight,
                 "attention_weights": attention_matrix,
                 "q_matrix": q_matrix,