#!/usr/bin/env python3 """ Phase 3a: Vision Encoder Export for ExecuTorch Extracts vision_encoder + vision_projection into a standalone nn.Module with fixed-size input for torch.export compatibility. Fixed resolution: 1120x1540 (snapped to patch_size=14 multiples) -> patch grid: 80 x 110 = 8800 patches -> after PatchMerger (2x2): 40 x 55 = 2200 tokens """ import os import sys import torch import torch.nn as nn import torch.nn.functional as F # Fixed image dimensions (must be multiples of patch_size=14) FIXED_H = 1120 # 1120 / 14 = 80 patches FIXED_W = 1540 # 1540 / 14 = 110 patches PATCH_SIZE = 14 SPATIAL_MERGE = 2 # Derived constants PATCHES_H = FIXED_H // PATCH_SIZE # 80 PATCHES_W = FIXED_W // PATCH_SIZE # 110 NUM_PATCHES = PATCHES_H * PATCHES_W # 8800 MERGED_H = PATCHES_H // SPATIAL_MERGE # 40 MERGED_W = PATCHES_W // SPATIAL_MERGE # 55 NUM_MERGED = MERGED_H * MERGED_W # 2200 MODEL_DIR = "./models/LightOnOCR-2-1B" class FixedPatchMerger(nn.Module): """ Rewritten PatchMerger that works with fixed single-image input. No Python loops, no dynamic shapes. Original: loops over variable-size images, dynamic unfold This: single fixed-size image, vectorized unfold """ def __init__(self, hidden_size: int, spatial_merge_size: int = 2): super().__init__() self.spatial_merge_size = spatial_merge_size self.merging_layer = nn.Linear( hidden_size * spatial_merge_size ** 2, hidden_size, bias=False ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: """ Args: image_features: [num_patches, hidden_size] where num_patches = PATCHES_H * PATCHES_W Returns: [num_merged, hidden_size] where num_merged = MERGED_H * MERGED_W """ d = image_features.shape[-1] # Reshape flat patches into spatial grid: [d, H_patches, W_patches] image_grid = image_features.view(PATCHES_H, PATCHES_W, d).permute(2, 0, 1).unsqueeze(0) # Use unfold to merge spatial_merge_size x spatial_merge_size patches # Input: [1, d, 80, 110] -> unfold with kernel=2, stride=2 # Output: [1, d*4, 40*55] = [1, d*4, 2200] grid = F.unfold( image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size ) # Reshape: [1, d*4, 2200] -> [2200, d*4] grid = grid.squeeze(0).t() # Apply merging linear: [2200, d*4] -> [2200, d] return self.merging_layer(grid) class FixedMultiModalProjector(nn.Module): """Fixed-size multimodal projector (RMSNorm + PatchMerger + MLP).""" def __init__(self, vision_hidden_size: int, text_hidden_size: int, spatial_merge_size: int = 2, rms_eps: float = 1e-6): super().__init__() self.norm_weight = nn.Parameter(torch.ones(vision_hidden_size)) self.norm_eps = rms_eps self.patch_merger = FixedPatchMerger(vision_hidden_size, spatial_merge_size) self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=False) self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=False) def _rms_norm(self, x: torch.Tensor) -> torch.Tensor: """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator.""" input_dtype = x.dtype x = x.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.norm_eps) return self.norm_weight * x.to(input_dtype) def forward(self, image_features: torch.Tensor) -> torch.Tensor: """ Args: image_features: [num_patches, vision_hidden_size] Returns: [num_merged, text_hidden_size] """ image_features = self._rms_norm(image_features) image_features = self.patch_merger(image_features) hidden = self.linear_1(image_features) hidden = F.gelu(hidden) hidden = self.linear_2(hidden) return hidden class VisionEncoderFixed(nn.Module): """ Standalone vision encoder for ExecuTorch export. Wraps PixtralVisionModel + MultiModalProjector with fixed-size input. Input: pixel_values [1, 3, 1120, 1540] Output: image_features [1, 2200, 1024] """ def __init__(self, vision_encoder, projector): super().__init__() # Vision encoder components self.patch_conv = vision_encoder.patch_conv # Conv2d self.ln_pre_weight = nn.Parameter(vision_encoder.ln_pre.weight.clone()) self.ln_pre_eps = vision_encoder.ln_pre.variance_epsilon self.transformer = vision_encoder.transformer # PixtralTransformer self.rope = vision_encoder.patch_positional_embedding # PixtralRotaryEmbedding # Fixed projector self.projector = projector # Pre-compute position IDs for fixed resolution max_width = vision_encoder.config.image_size // PATCH_SIZE self.register_buffer( "position_ids", self._compute_fixed_position_ids(PATCHES_H, PATCHES_W, max_width) ) @staticmethod def _compute_fixed_position_ids(h: int, w: int, max_width: int) -> torch.Tensor: """Pre-compute position IDs for fixed-size image grid.""" mesh = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) ids = h_grid * max_width + v_grid return ids[:, 0].unsqueeze(0) # [1, num_patches] def _rms_norm_pre(self, x: torch.Tensor) -> torch.Tensor: """Inline RMSNorm for ln_pre.""" input_dtype = x.dtype x = x.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.ln_pre_eps) return self.ln_pre_weight * x.to(input_dtype) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Args: pixel_values: [1, 3, 1120, 1540] Returns: image_features: [1, 2200, 1024] """ # Step 1: Patch convolution # [1, 3, 1120, 1540] -> [1, 1024, 80, 110] patch_embeds = self.patch_conv(pixel_values) # Step 2: Flatten to sequence # [1, 1024, 80, 110] -> [1, 8800, 1024] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) # Step 3: Pre-normalization patch_embeds = self._rms_norm_pre(patch_embeds) # Step 4: Compute RoPE position embeddings position_embeddings = self.rope(patch_embeds, self.position_ids) # Step 5: Run through transformer (no attention mask needed for single image) # The block attention mask is identity for single image (all patches attend to all) outputs = self.transformer( patch_embeds, attention_mask=None, position_embeddings=position_embeddings, output_hidden_states=True, output_attentions=False, return_dict=True, ) # Step 6: Get last hidden state # Use last hidden layer (vision_feature_layer=-1) hidden_states = outputs.hidden_states[-1].squeeze(0) # [8800, 1024] # Step 7: Project through multimodal projector image_features = self.projector(hidden_states) # [2200, 1024] return image_features.unsqueeze(0) # [1, 2200, 1024] def load_original_model(): """Load the original model with proper weight remapping.""" from transformers import AutoModelForImageTextToText from safetensors.torch import load_file print("Loading original model...") model = AutoModelForImageTextToText.from_pretrained( MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu", ) # Remap checkpoint keys (LightOnOCR uses different naming) state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors")) remapped = {} for k, v in state_dict.items(): new_k = k.replace("model.vision_encoder.", "model.vision_tower.") new_k = new_k.replace("model.vision_projection.", "model.multi_modal_projector.") remapped[new_k] = v model.load_state_dict(remapped, strict=False) return model def build_vision_module(original_model): """Build the fixed-size vision module from the original model.""" config = original_model.config vision_encoder = original_model.model.vision_tower orig_projector = original_model.model.multi_modal_projector # Create fixed projector with weights from original projector = FixedMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, spatial_merge_size=config.spatial_merge_size, rms_eps=config.text_config.rms_norm_eps, ) # Copy weights projector.norm_weight.data.copy_(orig_projector.norm.weight.data) projector.patch_merger.merging_layer.weight.data.copy_( orig_projector.patch_merger.merging_layer.weight.data ) projector.linear_1.weight.data.copy_(orig_projector.linear_1.weight.data) projector.linear_2.weight.data.copy_(orig_projector.linear_2.weight.data) # Build the fixed vision module vision_module = VisionEncoderFixed(vision_encoder, projector) vision_module.eval() return vision_module def test_vision_module(vision_module, original_model): """Test that the fixed module produces similar output to the original.""" print("\nTesting vision module output consistency...") device = "cuda" if torch.cuda.is_available() else "cpu" vision_module = vision_module.to(device).to(torch.bfloat16) # Create test input pixel_values = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.bfloat16, device=device) with torch.no_grad(): # Run through fixed module fixed_output = vision_module(pixel_values) print(f" Fixed module output shape: {fixed_output.shape}") print(f" Expected: [1, {NUM_MERGED}, {original_model.config.text_config.hidden_size}]") # Run through original model's vision pipeline for comparison original_model = original_model.to(device) image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device=device) orig_features = original_model.model.get_image_features( pixel_values=pixel_values, image_sizes=image_sizes, vision_feature_layer=-1, return_dict=True, ) orig_output = torch.cat(orig_features.pooler_output, dim=0).unsqueeze(0) print(f" Original model output shape: {orig_output.shape}") # Compare if fixed_output.shape == orig_output.shape: diff = (fixed_output - orig_output).abs() print(f" Max absolute difference: {diff.max().item():.6f}") print(f" Mean absolute difference: {diff.mean().item():.6f}") print(f" Cosine similarity: {F.cosine_similarity(fixed_output.flatten(), orig_output.flatten(), dim=0).item():.6f}") else: print(f" Shape mismatch! Fixed: {fixed_output.shape}, Original: {orig_output.shape}") return fixed_output def try_torch_export(vision_module): """Attempt torch.export.export() on the vision module.""" print("\n" + "=" * 60) print("ATTEMPTING torch.export.export()") print("=" * 60) # Export on CPU with float32 for XNNPACK compatibility # XNNPACK doesn't support bfloat16 or CUDA SDPA vision_module = vision_module.to("cpu").to(torch.float32) vision_module.eval() example_input = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.float32) try: print(" Running torch.export.export() on CPU/float32...") exported = torch.export.export( vision_module, (example_input,), strict=False, # Allow some Python control flow ) print(" SUCCESS! torch.export completed!") return exported except Exception as e: print(f" FAILED: {type(e).__name__}: {e}") import traceback traceback.print_exc() return None def export_to_pte(exported_model, vision_module, example_input): """Convert exported model to .pte using XNNPACK backend.""" print("\n" + "=" * 60) print("EXPORTING TO .pte (XNNPACK)") print("=" * 60) try: from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner if not hasattr(exported_model, 'graph_module'): print(" Cannot export non-torch.export model to .pte directly") return None print(" Running to_edge_transform_and_lower...") edge = to_edge_transform_and_lower( exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False), partitioner=[XnnpackPartitioner()], ) print(" Running to_executorch()...") pte = edge.to_executorch() output_path = "vision_encoder.pte" with open(output_path, "wb") as f: f.write(pte.buffer) file_size = os.path.getsize(output_path) / (1024 * 1024) print(f" Saved to {output_path} ({file_size:.1f} MB)") return output_path except ImportError as e: print(f" ExecuTorch import failed: {e}") print(" Make sure executorch is properly installed") return None except Exception as e: print(f" Export failed: {type(e).__name__}: {e}") import traceback traceback.print_exc() return None def main(): print("=" * 60) print("Vision Encoder Export for ExecuTorch") print(f"Fixed resolution: {FIXED_H}x{FIXED_W}") print(f"Patches: {PATCHES_H}x{PATCHES_W} = {NUM_PATCHES}") print(f"After merge: {MERGED_H}x{MERGED_W} = {NUM_MERGED}") print("=" * 60) # Load original model original_model = load_original_model() # Build fixed vision module print("\nBuilding fixed-size vision module...") vision_module = build_vision_module(original_model) print(f" Vision module parameters: {sum(p.numel() for p in vision_module.parameters())/1e6:.2f}M") # Test consistency test_vision_module(vision_module, original_model) # Free original model memory del original_model torch.cuda.empty_cache() if torch.cuda.is_available() else None # Try torch.export exported = try_torch_export(vision_module) if exported is not None: # Try to save as .pte device = "cuda" if torch.cuda.is_available() else "cpu" example_input = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.bfloat16, device=device) export_to_pte(exported, vision_module, example_input) # Save the PyTorch module for later use torch.save(vision_module.state_dict(), "vision_encoder_fixed.pt") print(f"\nSaved fixed vision module state dict to vision_encoder_fixed.pt") print("Export script complete!") if __name__ == "__main__": main()