#!/usr/bin/env python3
"""
Phase 3b: Text Decoder Export for ExecuTorch

Extracts language_model + lm_head into a standalone nn.Module with static
KV cache tensors for torch.export compatibility.

Architecture: Qwen3 decoder (28 layers, GQA 16/8 heads, head_dim=128)
Fixed max_seq_len: 4096 (MAX_SEQ_LEN)
"""

import os
import sys
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model constants from config
HIDDEN_SIZE = 1024
NUM_LAYERS = 28
NUM_HEADS = 16
NUM_KV_HEADS = 8
HEAD_DIM = 128
INTERMEDIATE_SIZE = 3072
VOCAB_SIZE = 151936
MAX_SEQ_LEN = 4096
RMS_EPS = 1e-6
ROPE_THETA = 1000000.0
NUM_KV_GROUPS = NUM_HEADS // NUM_KV_HEADS  # 2

MODEL_DIR = "./models/LightOnOCR-2-1B"


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = RMS_EPS) -> torch.Tensor:
    """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator.

    Normalizes ``x`` over its last dimension in float32, casts back to the
    input dtype, then scales by ``weight``.
    """
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)


def precompute_rope_freqs(max_seq_len: int, head_dim: int, theta: float = ROPE_THETA):
    """Precompute RoPE cos/sin tables for all positions up to max_seq_len.

    Returns:
        (cos, sin), each of shape [max_seq_len, head_dim] in float32.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    cos = freqs.cos()
    sin = freqs.sin()
    # Duplicate for full head_dim: [seq_len, head_dim/2] -> [seq_len, head_dim]
    cos = torch.cat([cos, cos], dim=-1)
    sin = torch.cat([sin, sin], dim=-1)
    return cos, sin  # [max_seq_len, head_dim]


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to query and key states.

    q, k: [batch, num_heads, seq_len, head_dim]
    cos, sin: [max_seq_len, head_dim]
    position_ids: [batch, seq_len]
    """
    # Gather cos/sin for the given positions
    cos = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]

    # Rotate
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    """Rotate half the hidden dims of the input: [x1, x2] -> [-x2, x1]."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class Qwen3AttentionFixed(nn.Module):
    """
    Fixed Qwen3 attention with static KV cache, inline QK-norm, and no dynamic dispatch.
    Designed for torch.export compatibility.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.scaling = HEAD_DIM ** -0.5

        # Projections
        self.q_proj = nn.Linear(HIDDEN_SIZE, NUM_HEADS * HEAD_DIM, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False)
        self.o_proj = nn.Linear(NUM_HEADS * HEAD_DIM, HIDDEN_SIZE, bias=False)

        # QK-norm weights (RMSNorm per head)
        self.q_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))
        self.k_norm_weight = nn.Parameter(torch.ones(HEAD_DIM))

    def forward(
        self,
        hidden_states: torch.Tensor,  # [batch, seq_len, hidden_size]
        cos: torch.Tensor,  # [max_seq_len, head_dim]
        sin: torch.Tensor,  # [max_seq_len, head_dim]
        position_ids: torch.Tensor,  # [batch, seq_len]
        attention_mask: torch.Tensor,  # additive, broadcastable to [batch, 1, seq_len, max_seq_len]
        k_cache: torch.Tensor,  # [batch, num_kv_heads, max_seq_len, head_dim]
        v_cache: torch.Tensor,  # [batch, num_kv_heads, max_seq_len, head_dim]
        cache_position: torch.Tensor,  # [seq_len] — positions to write into cache
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (output, updated_k_cache, updated_v_cache).

        Attention runs over the whole static cache. ``attention_mask`` must
        therefore mask every slot past each query's absolute position,
        including slots that have not been written yet.
        """
        batch, seq_len, _ = hidden_states.shape

        # Project Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape: [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
        q = q.view(batch, seq_len, NUM_HEADS, HEAD_DIM)
        k = k.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)
        v = v.view(batch, seq_len, NUM_KV_HEADS, HEAD_DIM)

        # Apply QK-norm (RMSNorm per head, inline)
        q = rms_norm(q, self.q_norm_weight)
        k = rms_norm(k, self.k_norm_weight)

        q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch, num_kv_heads, seq_len, head_dim]

        # Apply RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Update KV cache using scatter (index_put). Clone first so the
        # caller's cache tensors are not mutated in place.
        k_cache = k_cache.clone()
        v_cache = v_cache.clone()
        k_cache[:, :, cache_position, :] = k
        v_cache[:, :, cache_position, :] = v

        # Expand KV heads for GQA: repeat each KV head for its group of Q heads
        cache_len = k_cache.shape[2]  # dynamic, works for any MAX_SEQ_LEN
        k_expanded = k_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        k_expanded = k_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)
        v_expanded = v_cache.unsqueeze(2).expand(-1, -1, NUM_KV_GROUPS, -1, -1)
        v_expanded = v_expanded.reshape(batch, NUM_HEADS, cache_len, HEAD_DIM)

        # Attention: Q @ K^T / sqrt(head_dim)
        attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scaling

        # Apply attention mask
        attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)

        # Attention output
        attn_output = torch.matmul(attn_weights, v_expanded)

        # Reshape back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch, seq_len, -1)

        # Output projection
        attn_output = self.o_proj(attn_output)

        return attn_output, k_cache, v_cache


class Qwen3MLPFixed(nn.Module):
    """Fixed Qwen3 MLP (SiLU gate + up projection)."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Qwen3DecoderLayerFixed(nn.Module):
    """Fixed Qwen3 decoder layer with static KV cache."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.self_attn = Qwen3AttentionFixed(layer_idx)
        self.mlp = Qwen3MLPFixed()
        self.input_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.post_attention_layernorm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run pre-norm attention and MLP with residuals.

        Returns (hidden_states, updated_k_cache, updated_v_cache).
        """
        # Pre-norm + self attention
        residual = hidden_states
        hidden_states = rms_norm(hidden_states, self.input_layernorm_weight)
        hidden_states, k_cache, v_cache = self.self_attn(
            hidden_states, cos, sin, position_ids, attention_mask,
            k_cache, v_cache, cache_position,
        )
        hidden_states = residual + hidden_states

        # Pre-norm + MLP
        residual = hidden_states
        hidden_states = rms_norm(hidden_states, self.post_attention_layernorm_weight)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, k_cache, v_cache


class TextDecoderFixed(nn.Module):
    """
    Complete text decoder for ExecuTorch export.
    Includes embedding, all decoder layers with static KV cache, and LM head.

    For prefill: input_ids has seq_len > 1, cache_position starts at 0
    For decode: input_ids has seq_len = 1, cache_position = current position

    The attention_mask must hide every cache slot beyond each query's absolute
    position, e.g. ``create_causal_mask(seq_len, start_pos=<first position>)``.
    Otherwise queries attend to future or unwritten (zero) cache slots.
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.layers = nn.ModuleList([
            Qwen3DecoderLayerFixed(i) for i in range(NUM_LAYERS)
        ])
        self.norm_weight = nn.Parameter(torch.ones(HIDDEN_SIZE))
        self.lm_head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)

        # Pre-compute RoPE frequencies
        cos, sin = precompute_rope_freqs(MAX_SEQ_LEN, HEAD_DIM, ROPE_THETA)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(
        self,
        input_ids: torch.Tensor,  # [batch, seq_len]
        attention_mask: torch.Tensor,  # [batch, 1, seq_len, max_seq_len]
        position_ids: torch.Tensor,  # [batch, seq_len]
        cache_position: torch.Tensor,  # [seq_len]
        *kv_caches: torch.Tensor,  # 28 * (k_cache, v_cache) flattened
    ) -> tuple:
        """
        Returns: (logits, *updated_kv_caches)
        kv_caches: 56 tensors total (28 layers * 2 for k,v)
        Each cache: [batch, num_kv_heads, max_seq_len, head_dim]
        logits: [batch, 1, vocab_size] for the last input token only.
        """
        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Process through all layers, updating KV caches
        updated_caches = []
        for i, layer in enumerate(self.layers):
            k_cache = kv_caches[i * 2]
            v_cache = kv_caches[i * 2 + 1]
            hidden_states, new_k, new_v = layer(
                hidden_states, self.rope_cos, self.rope_sin, position_ids,
                attention_mask, k_cache, v_cache, cache_position,
            )
            updated_caches.append(new_k)
            updated_caches.append(new_v)

        # Final norm
        hidden_states = rms_norm(hidden_states, self.norm_weight)

        # LM head — only compute logits for the last token
        logits = self.lm_head(hidden_states[:, -1:, :])  # [batch, 1, vocab_size]

        return (logits, *updated_caches)


def load_original_model():
    """Load the original model with proper weight remapping.

    Loads via ``AutoModelForImageTextToText`` in bfloat16 on CPU. It then
    reloads ``model.safetensors`` with vision-module key prefixes renamed to
    the names the transformers model uses. Key mismatches are reported, not
    silently dropped.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("Loading original model...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {}
    for k, v in state_dict.items():
        new_k = k.replace("model.vision_encoder.", "model.vision_projection." if False else "model.vision_tower.")
        new_k = new_k.replace("model.vision_projection.", "model.multi_modal_projector.")
        remapped[new_k] = v
    incompatible = model.load_state_dict(remapped, strict=False)
    # strict=False tolerates naming differences; surface them so a bad remap
    # (e.g. renamed prefixes) does not go unnoticed.
    if incompatible.missing_keys:
        print(f"  Warning: {len(incompatible.missing_keys)} missing keys after remap, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"  Warning: {len(incompatible.unexpected_keys)} unexpected keys after remap, "
              f"e.g. {incompatible.unexpected_keys[:5]}")
    return model


def build_decoder_module(original_model):
    """Build the fixed decoder module from the original model's weights.

    Copies embeddings, final norm, lm_head and every per-layer weight from
    ``original_model.model.language_model`` and ``original_model.lm_head``
    into a fresh TextDecoderFixed. ``copy_`` casts to the destination dtype.
    Returns the decoder in eval mode.
    """
    print("\nBuilding fixed text decoder...")
    orig_lm = original_model.model.language_model
    orig_lm_head = original_model.lm_head

    decoder = TextDecoderFixed()

    # Copy embedding weights
    decoder.embed_tokens.weight.data.copy_(orig_lm.embed_tokens.weight.data)

    # Copy final norm weight
    decoder.norm_weight.data.copy_(orig_lm.norm.weight.data)

    # Copy LM head from the actual lm_head module. When word embeddings are
    # tied this is the same tensor as embed_tokens. When they are not, copying
    # from embed_tokens would silently produce wrong logits.
    decoder.lm_head.weight.data.copy_(orig_lm_head.weight.data)

    # Copy layer weights
    for i in range(NUM_LAYERS):
        orig_layer = orig_lm.layers[i]
        fixed_layer = decoder.layers[i]

        # Attention projections
        fixed_layer.self_attn.q_proj.weight.data.copy_(orig_layer.self_attn.q_proj.weight.data)
        fixed_layer.self_attn.k_proj.weight.data.copy_(orig_layer.self_attn.k_proj.weight.data)
        fixed_layer.self_attn.v_proj.weight.data.copy_(orig_layer.self_attn.v_proj.weight.data)
        fixed_layer.self_attn.o_proj.weight.data.copy_(orig_layer.self_attn.o_proj.weight.data)

        # QK-norm weights
        fixed_layer.self_attn.q_norm_weight.data.copy_(orig_layer.self_attn.q_norm.weight.data)
        fixed_layer.self_attn.k_norm_weight.data.copy_(orig_layer.self_attn.k_norm.weight.data)

        # Layer norms
        fixed_layer.input_layernorm_weight.data.copy_(orig_layer.input_layernorm.weight.data)
        fixed_layer.post_attention_layernorm_weight.data.copy_(
            orig_layer.post_attention_layernorm.weight.data
        )

        # MLP
        fixed_layer.mlp.gate_proj.weight.data.copy_(orig_layer.mlp.gate_proj.weight.data)
        fixed_layer.mlp.up_proj.weight.data.copy_(orig_layer.mlp.up_proj.weight.data)
        fixed_layer.mlp.down_proj.weight.data.copy_(orig_layer.mlp.down_proj.weight.data)

    decoder.eval()

    total_params = sum(p.numel() for p in decoder.parameters())
    print(f"  Decoder parameters: {total_params/1e6:.2f}M")

    return decoder


def create_empty_kv_caches(batch_size: int = 1, dtype=torch.float32, device="cpu"):
    """Create zero-filled KV cache tensors for all layers.

    Returns a flat tuple (k0, v0, k1, v1, ...) of NUM_LAYERS * 2 tensors, each
    [batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM].
    """
    caches = []
    for _ in range(NUM_LAYERS):
        k = torch.zeros(batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
        v = torch.zeros(batch_size, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
        caches.extend([k, v])
    return tuple(caches)


def create_causal_mask(
    seq_len: int,
    cache_len: int = MAX_SEQ_LEN,
    dtype=torch.float32,
    start_pos: int = 0,
):
    """Create an additive causal attention mask over the static KV cache.

    Query row ``i`` sits at absolute position ``start_pos + i``. It may attend
    to cache slots ``0 .. start_pos + i``. Every later slot (future tokens and
    not-yet-written zero slots) gets ``-inf``.

    Args:
        seq_len: number of query tokens in this forward pass.
        cache_len: length of the static KV cache (last mask dimension).
        dtype: dtype of the returned mask.
        start_pos: absolute position of the first query token
            (0 for prefill, the current position for a decode step).

    Returns:
        Tensor of shape [1, 1, seq_len, cache_len] holding 0 (visible) or -inf.

    Raises:
        ValueError: if the query positions do not fit inside the cache.
    """
    if seq_len < 0 or start_pos < 0 or start_pos + seq_len > cache_len:
        raise ValueError(
            f"queries at positions [{start_pos}, {start_pos + seq_len}) "
            f"do not fit in a cache of length {cache_len}"
        )
    mask = torch.full((seq_len, cache_len), float("-inf"), dtype=dtype)
    # triu keeps -inf where key j >= (start_pos + i) + 1, i.e. strictly after
    # the query's own absolute position; everything at or before becomes 0.
    mask = torch.triu(mask, diagonal=start_pos + 1)
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, cache_len]


def test_decoder_module(decoder, original_model):
    """Test that the fixed decoder produces same output as original.

    Manual consistency check: it moves both models to CUDA (if available),
    casts the decoder to bfloat16, and prints logit differences and top-5
    token overlap for a short prefill.
    """
    print("\nTesting decoder output consistency...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    decoder = decoder.to(device).to(torch.bfloat16)
    original_model = original_model.to(device)

    # Test input
    input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    cache_position = torch.arange(seq_len, device=device)

    # Causal mask for a prefill starting at position 0
    mask = create_causal_mask(seq_len, dtype=torch.bfloat16, start_pos=0).to(device)

    # Empty KV caches
    kv_caches = create_empty_kv_caches(1, torch.bfloat16, device)

    with torch.no_grad():
        # Fixed decoder
        result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
        fixed_logits = result[0]
        print(f"  Fixed decoder output shape: {fixed_logits.shape}")

        # Original model (text-only, no image)
        orig_outputs = original_model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            use_cache=False,
        )
        orig_logits = orig_outputs.logits[:, -1:, :]
        print(f"  Original model output shape: {orig_logits.shape}")

    # Compare
    diff = (fixed_logits.float() - orig_logits.float()).abs()
    print(f"  Max absolute difference: {diff.max().item():.6f}")
    print(f"  Mean absolute difference: {diff.mean().item():.6f}")

    # Check top-k predictions match
    fixed_ids = fixed_logits.float().topk(5, dim=-1).indices[0, 0].tolist()
    orig_ids = orig_logits.float().topk(5, dim=-1).indices[0, 0].tolist()
    print(f"  Fixed top-5 token IDs: {fixed_ids}")
    print(f"  Original top-5 token IDs: {orig_ids}")
    matching = len(set(fixed_ids) & set(orig_ids))
    print(f"  Top-5 overlap: {matching}/5")


# The name starts with "test_" but the function needs real models; keep pytest
# from collecting it (and failing on missing fixtures) when this module is imported.
test_decoder_module.__test__ = False


def try_torch_export(decoder):
    """Attempt torch.export.export() on the decoder (single-token decode shape).

    Falls back to torch.jit.trace if export fails. Returns the exported
    program, the traced module, or None if both fail.
    """
    print("\n" + "=" * 60)
    print("ATTEMPTING torch.export.export() on decoder")
    print("=" * 60)

    # Export on CPU with float32 for XNNPACK
    decoder = decoder.to("cpu").to(torch.float32)
    decoder.eval()

    batch_size = 1
    seq_len = 1  # Export for single-token decode step (simpler)

    input_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))
    # Example inputs describe a decode step at position 0; mask/positions agree.
    attention_mask = create_causal_mask(seq_len, MAX_SEQ_LEN, torch.float32, start_pos=0)
    position_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    cache_position = torch.zeros(seq_len, dtype=torch.long)
    kv_caches = create_empty_kv_caches(batch_size, torch.float32, "cpu")

    example_args = (input_ids, attention_mask, position_ids, cache_position, *kv_caches)

    try:
        print(f"  Exporting with seq_len={seq_len}, max_cache={MAX_SEQ_LEN}...")
        print(f"  Number of input tensors: {len(example_args)} (4 + {NUM_LAYERS}*2 KV caches)")

        exported = torch.export.export(
            decoder,
            example_args,
            strict=False,
        )
        print("  SUCCESS! torch.export completed!")
        return exported

    except Exception as e:
        print(f"  FAILED: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()

        # Try with trace as fallback
        print("\n  Trying torch.jit.trace as fallback...")
        try:
            traced = torch.jit.trace(decoder, example_args)
            print("  torch.jit.trace succeeded!")
            return traced
        except Exception as e2:
            print(f"  torch.jit.trace also failed: {type(e2).__name__}: {e2}")
            return None


def export_to_pte(exported_model):
    """Convert exported model to .pte using XNNPACK backend.

    Requires a torch.export result (anything with ``graph_module``). Writes
    ``text_decoder.pte`` and returns its path, or None on failure.
    """
    print("\n" + "=" * 60)
    print("EXPORTING DECODER TO .pte (XNNPACK)")
    print("=" * 60)

    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        if not hasattr(exported_model, 'graph_module'):
            print("  Need torch.export.export() result for .pte export")
            return None

        print("  Running to_edge_transform_and_lower...")
        edge = to_edge_transform_and_lower(
            exported_model,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        print("  Running to_executorch()...")
        pte = edge.to_executorch()

        output_path = "text_decoder.pte"
        with open(output_path, "wb") as f:
            f.write(pte.buffer)

        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"  Saved to {output_path} ({file_size:.1f} MB)")
        return output_path

    except ImportError as e:
        print(f"  ExecuTorch import failed: {e}")
        return None
    except Exception as e:
        print(f"  Export failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None


def main():
    """Load the checkpoint, build and verify the fixed decoder, then export it."""
    print("=" * 60)
    print("Text Decoder Export for ExecuTorch")
    print(f"Architecture: Qwen3 {NUM_LAYERS}L, {NUM_HEADS}H/{NUM_KV_HEADS}KV, dim={HIDDEN_SIZE}")
    print(f"Max seq len: {MAX_SEQ_LEN}")
    print(f"KV cache size per layer: {NUM_KV_HEADS}x{MAX_SEQ_LEN}x{HEAD_DIM} = "
          f"{NUM_KV_HEADS*MAX_SEQ_LEN*HEAD_DIM/1e6:.2f}M elements")
    print("=" * 60)

    # Load original model
    original_model = load_original_model()

    # Build fixed decoder
    decoder = build_decoder_module(original_model)

    # Test consistency
    test_decoder_module(decoder, original_model)

    # Free original model memory
    del original_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Try torch.export
    exported = try_torch_export(decoder)

    if exported is not None:
        export_to_pte(exported)

    # Save the PyTorch module for later use
    torch.save(decoder.state_dict(), "text_decoder_fixed.pt")
    print("\nSaved fixed decoder state dict to text_decoder_fixed.pt")
    print("Decoder export script complete!")


if __name__ == "__main__":
    main()