# LightOnOCR-2-1B-ExecuTorch / scripts/export_vision.py
# Uploaded by acul3 via huggingface_hub (commit 846eac7, verified)
# NOTE: these header lines were scraped page text; kept as comments so the
# file is valid Python (the shebang below is no longer line 1, so run with
# `python export_vision.py` rather than relying on the exec bit).
#!/usr/bin/env python3
"""
Phase 3a: Vision Encoder Export for ExecuTorch
Extracts vision_encoder + vision_projection into a standalone nn.Module
with fixed-size input for torch.export compatibility.
Fixed resolution: 1120x1540 (snapped to patch_size=14 multiples)
-> patch grid: 80 x 110 = 8800 patches
-> after PatchMerger (2x2): 40 x 55 = 2200 tokens
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
# Fixed image dimensions (must be multiples of patch_size=14)
FIXED_H = 1120 # 1120 / 14 = 80 patches
FIXED_W = 1540 # 1540 / 14 = 110 patches
PATCH_SIZE = 14
SPATIAL_MERGE = 2
# Derived constants
PATCHES_H = FIXED_H // PATCH_SIZE # 80
PATCHES_W = FIXED_W // PATCH_SIZE # 110
NUM_PATCHES = PATCHES_H * PATCHES_W # 8800
MERGED_H = PATCHES_H // SPATIAL_MERGE # 40
MERGED_W = PATCHES_W // SPATIAL_MERGE # 55
NUM_MERGED = MERGED_H * MERGED_W # 2200
MODEL_DIR = "./models/LightOnOCR-2-1B"
class FixedPatchMerger(nn.Module):
    """
    Rewritten PatchMerger that works with fixed single-image input.
    No Python loops, no dynamic shapes.

    Original: loops over variable-size images, dynamic unfold.
    This: single fixed-size image, vectorized unfold.

    Generalized: the patch-grid dimensions are constructor parameters
    (defaulting to the module-level PATCHES_H / PATCHES_W constants, which
    preserves the original behavior) instead of being hard-coded in forward().
    """

    def __init__(self, hidden_size: int, spatial_merge_size: int = 2,
                 patches_h=None, patches_w=None):
        """
        Args:
            hidden_size: channel dimension of the incoming patch features.
            spatial_merge_size: side of the merge window (2 -> 2x2 merge).
            patches_h: patch-grid height; None -> module constant PATCHES_H.
            patches_w: patch-grid width; None -> module constant PATCHES_W.
        """
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.patches_h = patches_h
        self.patches_w = patches_w
        # Each merged token is a linear mix of spatial_merge_size**2 patches.
        self.merging_layer = nn.Linear(
            hidden_size * spatial_merge_size ** 2, hidden_size, bias=False
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: [num_patches, hidden_size] where
                num_patches = patches_h * patches_w.
        Returns:
            [num_merged, hidden_size] where
            num_merged = (patches_h // s) * (patches_w // s).
        """
        # Resolve grid size lazily so the defaults track the module constants.
        h = self.patches_h if self.patches_h is not None else PATCHES_H
        w = self.patches_w if self.patches_w is not None else PATCHES_W
        d = image_features.shape[-1]
        # Reshape flat patches into a spatial grid: [1, d, h, w]
        image_grid = image_features.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
        # Merge s x s neighborhoods with a strided unfold:
        # [1, d, h, w] -> [1, d*s*s, (h//s)*(w//s)]
        grid = F.unfold(
            image_grid,
            kernel_size=self.spatial_merge_size,
            stride=self.spatial_merge_size,
        )
        # [1, d*s*s, L] -> [L, d*s*s]
        grid = grid.squeeze(0).t()
        # Apply merging linear: [L, d*s*s] -> [L, d]
        return self.merging_layer(grid)
class FixedMultiModalProjector(nn.Module):
    """Fixed-size multimodal projector: RMSNorm -> PatchMerger -> 2-layer MLP."""

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 spatial_merge_size: int = 2, rms_eps: float = 1e-6):
        super().__init__()
        self.norm_weight = nn.Parameter(torch.ones(vision_hidden_size))
        self.norm_eps = rms_eps
        self.patch_merger = FixedPatchMerger(vision_hidden_size, spatial_merge_size)
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size, bias=False)
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size, bias=False)

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Inline RMSNorm — avoids @use_kernel_forward_from_hub decorator."""
        orig_dtype = x.dtype
        # Compute in float32 for numerical stability, cast back afterwards.
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.norm_eps)
        return self.norm_weight * (x32 * inv_rms).to(orig_dtype)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: [num_patches, vision_hidden_size]
        Returns:
            [num_merged, text_hidden_size]
        """
        normed = self._rms_norm(image_features)
        merged = self.patch_merger(normed)
        # MLP head: linear -> GELU -> linear.
        return self.linear_2(F.gelu(self.linear_1(merged)))
class VisionEncoderFixed(nn.Module):
    """
    Standalone vision encoder for ExecuTorch export.

    Wraps PixtralVisionModel + MultiModalProjector with fixed-size input.
    Input:  pixel_values   [1, 3, 1120, 1540]
    Output: image_features [1, 2200, 1024]
    """

    def __init__(self, vision_encoder, projector):
        super().__init__()
        # Borrow the relevant submodules from the loaded vision encoder.
        self.patch_conv = vision_encoder.patch_conv              # Conv2d patchifier
        self.ln_pre_weight = nn.Parameter(vision_encoder.ln_pre.weight.clone())
        self.ln_pre_eps = vision_encoder.ln_pre.variance_epsilon
        self.transformer = vision_encoder.transformer            # PixtralTransformer
        self.rope = vision_encoder.patch_positional_embedding    # PixtralRotaryEmbedding
        # Fixed-size projector (RMSNorm + PatchMerger + MLP)
        self.projector = projector
        # Position IDs are constant for the fixed resolution, so bake them in
        # as a buffer instead of recomputing per forward pass.
        max_width = vision_encoder.config.image_size // PATCH_SIZE
        self.register_buffer(
            "position_ids",
            self._compute_fixed_position_ids(PATCHES_H, PATCHES_W, max_width),
        )

    @staticmethod
    def _compute_fixed_position_ids(h: int, w: int, max_width: int) -> torch.Tensor:
        """Pre-compute position IDs for a fixed h x w patch grid.

        Patch (row, col), flattened row-major, gets id row * max_width + col.
        """
        rows = torch.arange(h).unsqueeze(1)           # [h, 1]
        cols = torch.arange(w).unsqueeze(0)           # [1, w]
        ids = (rows * max_width + cols).reshape(-1)   # [h * w], row-major
        return ids.unsqueeze(0)                       # [1, num_patches]

    def _rms_norm_pre(self, x: torch.Tensor) -> torch.Tensor:
        """Inline RMSNorm for ln_pre (avoids the hub kernel decorator)."""
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.ln_pre_eps)
        return self.ln_pre_weight * (x32 * inv_rms).to(orig_dtype)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values: [1, 3, 1120, 1540]
        Returns:
            image_features: [1, 2200, 1024]
        """
        # Patchify: [1, 3, 1120, 1540] -> [1, 1024, 80, 110]
        embeds = self.patch_conv(pixel_values)
        # Spatial grid -> token sequence: [1, 8800, 1024]
        embeds = embeds.flatten(2).transpose(1, 2)
        # Pre-normalize before the transformer stack.
        embeds = self._rms_norm_pre(embeds)
        # RoPE embeddings from the pre-baked position ids.
        pos_emb = self.rope(embeds, self.position_ids)
        # Single image => every patch may attend to every patch, so no mask.
        outputs = self.transformer(
            embeds,
            attention_mask=None,
            position_embeddings=pos_emb,
            output_hidden_states=True,
            output_attentions=False,
            return_dict=True,
        )
        # vision_feature_layer = -1: take the last hidden state. [8800, 1024]
        last_hidden = outputs.hidden_states[-1].squeeze(0)
        # Merge + project into the text embedding space: [2200, 1024]
        features = self.projector(last_hidden)
        return features.unsqueeze(0)  # [1, 2200, 1024]
def load_original_model():
    """Load the original model and remap LightOnOCR checkpoint keys.

    The LightOnOCR checkpoint names the vision stack ``model.vision_encoder``
    / ``model.vision_projection`` while the transformers class expects
    ``model.vision_tower`` / ``model.multi_modal_projector``, so the raw
    safetensors state dict is remapped and loaded on top of the model.

    Returns:
        The AutoModelForImageTextToText model on CPU in bfloat16.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("Loading original model...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )
    # Remap checkpoint keys (LightOnOCR uses different naming).
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        k.replace("model.vision_encoder.", "model.vision_tower.")
         .replace("model.vision_projection.", "model.multi_modal_projector."): v
        for k, v in state_dict.items()
    }
    # strict=False is needed because the remapped dict may not cover every
    # parameter, but don't let key mismatches pass silently — report them so
    # a wrong/renamed prefix is visible instead of producing garbage weights.
    missing, unexpected = model.load_state_dict(remapped, strict=False)
    if unexpected:
        print(f"  WARNING: {len(unexpected)} unexpected keys ignored (e.g. {unexpected[0]})")
    if missing:
        print(f"  NOTE: {len(missing)} model keys not found in checkpoint")
    return model
def build_vision_module(original_model):
    """Build the fixed-size vision module from the original model."""
    cfg = original_model.config
    encoder = original_model.model.vision_tower
    src_proj = original_model.model.multi_modal_projector

    # Fresh projector sized from the config; weights are copied over below.
    projector = FixedMultiModalProjector(
        vision_hidden_size=cfg.vision_config.hidden_size,
        text_hidden_size=cfg.text_config.hidden_size,
        spatial_merge_size=cfg.spatial_merge_size,
        rms_eps=cfg.text_config.rms_norm_eps,
    )
    # Transfer weights from the original projector into the fixed one.
    weight_pairs = [
        (projector.norm_weight, src_proj.norm.weight),
        (projector.patch_merger.merging_layer.weight,
         src_proj.patch_merger.merging_layer.weight),
        (projector.linear_1.weight, src_proj.linear_1.weight),
        (projector.linear_2.weight, src_proj.linear_2.weight),
    ]
    for dst, src in weight_pairs:
        dst.data.copy_(src.data)

    # Assemble the export wrapper and freeze it in eval mode.
    module = VisionEncoderFixed(encoder, projector)
    module.eval()
    return module
def test_vision_module(vision_module, original_model):
    """Test that the fixed module produces similar output to the original."""
    print("\nTesting vision module output consistency...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vision_module = vision_module.to(device).to(torch.bfloat16)

    # Random image at the fixed export resolution.
    pixel_values = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.bfloat16, device=device)

    with torch.no_grad():
        # Path 1: the fixed export wrapper.
        fixed_output = vision_module(pixel_values)
        print(f" Fixed module output shape: {fixed_output.shape}")
        print(f" Expected: [1, {NUM_MERGED}, {original_model.config.text_config.hidden_size}]")

        # Path 2: the original model's own vision pipeline, as reference.
        original_model = original_model.to(device)
        image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device=device)
        orig_features = original_model.model.get_image_features(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=-1,
            return_dict=True,
        )
        orig_output = torch.cat(orig_features.pooler_output, dim=0).unsqueeze(0)
        print(f" Original model output shape: {orig_output.shape}")

        # Numeric comparison only makes sense when the shapes line up.
        if fixed_output.shape != orig_output.shape:
            print(f" Shape mismatch! Fixed: {fixed_output.shape}, Original: {orig_output.shape}")
        else:
            diff = (fixed_output - orig_output).abs()
            print(f" Max absolute difference: {diff.max().item():.6f}")
            print(f" Mean absolute difference: {diff.mean().item():.6f}")
            print(f" Cosine similarity: {F.cosine_similarity(fixed_output.flatten(), orig_output.flatten(), dim=0).item():.6f}")

    return fixed_output
def try_torch_export(vision_module):
    """Attempt torch.export.export() on the vision module."""
    banner = "=" * 60
    print("\n" + banner)
    print("ATTEMPTING torch.export.export()")
    print(banner)

    # Export on CPU with float32 for XNNPACK compatibility —
    # XNNPACK doesn't support bfloat16 or CUDA SDPA.
    vision_module = vision_module.to("cpu").to(torch.float32)
    vision_module.eval()
    sample = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.float32)

    try:
        print(" Running torch.export.export() on CPU/float32...")
        exported = torch.export.export(
            vision_module,
            (sample,),
            strict=False,  # Allow some Python control flow
        )
    except Exception as e:
        # Export failures are expected during bring-up; dump the traceback
        # and let the caller decide what to do with None.
        print(f" FAILED: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None

    print(" SUCCESS! torch.export completed!")
    return exported
def export_to_pte(exported_model, vision_module, example_input):
    """Convert exported model to .pte using XNNPACK backend."""
    banner = "=" * 60
    print("\n" + banner)
    print("EXPORTING TO .pte (XNNPACK)")
    print(banner)
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        # Only a real torch.export ExportedProgram can be lowered.
        if not hasattr(exported_model, 'graph_module'):
            print(" Cannot export non-torch.export model to .pte directly")
            return None

        print(" Running to_edge_transform_and_lower...")
        edge_program = to_edge_transform_and_lower(
            exported_model,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        print(" Running to_executorch()...")
        et_program = edge_program.to_executorch()

        output_path = "vision_encoder.pte"
        with open(output_path, "wb") as f:
            f.write(et_program.buffer)
        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        print(f" Saved to {output_path} ({size_mb:.1f} MB)")
        return output_path
    except ImportError as e:
        # Best-effort: executorch may not be installed in this environment.
        print(f" ExecuTorch import failed: {e}")
        print(" Make sure executorch is properly installed")
        return None
    except Exception as e:
        print(f" Export failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None
def main():
    """Run the full pipeline: load, build, test, export, save."""
    banner = "=" * 60
    print(banner)
    print("Vision Encoder Export for ExecuTorch")
    print(f"Fixed resolution: {FIXED_H}x{FIXED_W}")
    print(f"Patches: {PATCHES_H}x{PATCHES_W} = {NUM_PATCHES}")
    print(f"After merge: {MERGED_H}x{MERGED_W} = {NUM_MERGED}")
    print(banner)

    # Load original model
    original_model = load_original_model()

    # Build fixed vision module
    print("\nBuilding fixed-size vision module...")
    vision_module = build_vision_module(original_model)
    print(f" Vision module parameters: {sum(p.numel() for p in vision_module.parameters())/1e6:.2f}M")

    # Test consistency against the original pipeline
    test_vision_module(vision_module, original_model)

    # Free original model memory before export.
    # (Was a conditional *expression* used for its side effect; a plain
    # if-statement is the idiomatic form.)
    del original_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Try torch.export
    exported = try_torch_export(vision_module)
    if exported is not None:
        # try_torch_export moved the module to CPU/float32, so the example
        # input must match. (Previously this was built as bfloat16 on CUDA,
        # which disagreed with the exported CPU/float32 graph.)
        example_input = torch.randn(1, 3, FIXED_H, FIXED_W, dtype=torch.float32)
        export_to_pte(exported, vision_module, example_input)

    # Save the PyTorch module for later use
    torch.save(vision_module.state_dict(), "vision_encoder_fixed.pt")
    print("\nSaved fixed vision module state dict to vision_encoder_fixed.pt")
    print("Export script complete!")
# Script entry point — only run the export pipeline when executed directly.
if __name__ == "__main__":
    main()