#!/usr/bin/env python3
"""
E2E Validation v2: Fixed resolution pipeline.

Forces all images to 1120x1540 so token count matches fixed vision encoder output.
Compares HF model output vs fixed PyTorch modules output.
"""
import os, sys, time, torch, torch.nn.functional as F
from pathlib import Path
from PIL import Image

# ── Constants ──
MODEL_DIR = "./models/LightOnOCR-2-1B"
FIXED_H, FIXED_W = 1120, 1540
PATCH_SIZE = 14
SPATIAL_MERGE = 2
MERGED_H = FIXED_H // PATCH_SIZE // SPATIAL_MERGE  # 40
MERGED_W = FIXED_W // PATCH_SIZE // SPATIAL_MERGE  # 55
NUM_IMG_TOKENS = MERGED_H * MERGED_W  # 2200
# NOTE(review): presumably one vision-pad token between merged rows; not used below.
NUM_VPAD_TOKENS = MERGED_H - 1  # 39

# Token IDs
IMAGE_TOKEN_ID = 151655
VISION_PAD_ID = 151654
VISION_END_ID = 151653
IM_START_ID = 151644
IM_END_ID = 151645
EOS_TOKEN_ID = 151645

# Decoder constants
NUM_LAYERS = 28
NUM_KV_HEADS = 8
NUM_HEADS = 16
HEAD_DIM = 128
HIDDEN_SIZE = 1024
MAX_SEQ_LEN = 4096  # Large enough for 2256 input + generation

# Generation budget shared by the HF run and the fixed-module run so the two
# outputs are produced under the same stopping limit.
MAX_NEW_TOKENS = 512


def download_test_images():
    """Collect the test images: downloaded real documents plus a synthetic invoice.

    Downloads are cached under ``test_images/``; a failed download is printed
    and skipped instead of aborting the run.

    Returns:
        dict mapping image name -> RGB ``PIL.Image.Image``.
    """
    import requests
    from PIL import ImageDraw

    os.makedirs("test_images", exist_ok=True)
    images = {}
    sources = {
        "receipt": "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/SROIE-receipt.jpeg",
    }
    for name, url in sources.items():
        # The payload is written verbatim, so keep the source's real extension
        # (previously a JPEG was stored under a .png name).
        ext = os.path.splitext(url)[1]
        path = f"test_images/{name}{ext}"
        if not os.path.exists(path):
            print(f" Downloading {name}...")
            try:
                resp = requests.get(url, timeout=30)
                resp.raise_for_status()
                with open(path, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                print(f" FAILED: {e}")
                continue
        img = Image.open(path).convert("RGB")
        images[name] = img
        print(f" {name}: {img.size}")

    # Synthetic doc
    img = Image.new("RGB", (800, 600), "white")
    draw = ImageDraw.Draw(img)
    draw.text((50, 50), "Invoice #12345", fill="black")
    draw.text((50, 100), "Date: 2024-01-15", fill="black")
    draw.text((50, 150), "Item 1: Widget x5 @ $10.00 = $50.00", fill="black")
    draw.text((50, 200), "Item 2: Gadget x2 @ $24.99 = $49.98", fill="black")
    draw.text((50, 250), "Total: $99.98", fill="black")
    img.save("test_images/synthetic.png")
    images["synthetic"] = img
    print(f" synthetic: {img.size}")
    return images


def build_fixed_input_ids(processor, text_prompt="OCR this document. Extract all text."):
    """Build prompt ``input_ids`` containing exactly NUM_IMG_TOKENS image tokens.

    A white dummy image at the fixed FIXED_W x FIXED_H size is run through the
    processor's chat template, so the number of image placeholders matches the
    fixed-resolution vision encoder output. The result does not depend on
    the real image, so callers can build it once and reuse it.

    Args:
        processor: HF processor exposing ``apply_chat_template`` and ``__call__``.
        text_prompt: user instruction placed after the image.

    Returns:
        LongTensor of shape (1, seq_len).

    Raises:
        ValueError: if the template does not contain NUM_IMG_TOKENS image tokens.
    """
    dummy_img = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": text_prompt},
    ]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text, images=[dummy_img], return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Verify token counts (explicit raise so the check survives `python -O`).
    ids_list = input_ids[0].tolist()
    n_img = ids_list.count(IMAGE_TOKEN_ID)
    n_pad = ids_list.count(VISION_PAD_ID)
    if n_img != NUM_IMG_TOKENS:
        raise ValueError(f"Expected {NUM_IMG_TOKENS} IMG tokens, got {n_img}")
    print(f" Input template: {input_ids.shape[1]} tokens ({n_img} IMG, {n_pad} VPAD)")
    return input_ids


def preprocess_image_fixed(img, processor):
    """Resize ``img`` to exactly FIXED_W x FIXED_H and return the processor's pixel_values."""
    # Plain resize to the target size: the aspect ratio is NOT preserved and no
    # padding is applied, so content may be stretched.
    img_resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    # Run through the processor only to reuse its image normalization.
    dummy_messages = [{"role": "user", "content": [{"type": "image"}]}]
    text = processor.apply_chat_template(dummy_messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text, images=[img_resized], return_tensors="pt")
    return inputs["pixel_values"]


def run_hf_model(images, processor):
    """Run the reference HF model with fixed-resolution preprocessing.

    Greedy-decodes up to MAX_NEW_TOKENS tokens per image on CUDA, then frees the
    model.

    Returns:
        dict name -> {"text": str, "tokens": int, "time": float seconds}.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model] Loading...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="cpu",
    )
    # Reload the raw checkpoint, renaming vision_encoder/vision_projection keys
    # to the vision_tower/multi_modal_projector names the HF class uses
    # (presumably from_pretrained does not map these itself — confirm).
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        k.replace("model.vision_encoder.", "model.vision_tower.")
         .replace("model.vision_projection.", "model.multi_modal_projector."): v
        for k, v in state_dict.items()
    }
    load_result = model.load_state_dict(remapped, strict=False)
    # strict=False silently skips mismatches; surface them so a broken remap is visible.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f" WARNING: load_state_dict: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys")
    del state_dict, remapped  # drop the CPU copies before moving to the GPU
    model = model.to("cuda").eval()

    # The prompt is image-independent at fixed resolution: build it once.
    input_ids = build_fixed_input_ids(processor).to("cuda")
    attention_mask = torch.ones_like(input_ids)
    input_len = input_ids.shape[1]
    image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device="cuda")

    results = {}
    for name, img in images.items():
        print(f"\n [{name}] HF generate (fixed {FIXED_H}x{FIXED_W})...")
        pixel_values = preprocess_image_fixed(img, processor).to("cuda")

        t0 = time.time()
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                image_sizes=image_sizes,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                temperature=None,
                top_p=None,
            )
        elapsed = time.time() - t0

        new_ids = output_ids[0, input_len:]
        text = processor.tokenizer.decode(new_ids, skip_special_tokens=True)
        n_tok = len(new_ids)
        print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
        print(f" Output: {text[:200]}...")
        results[name] = {"text": text, "tokens": n_tok, "time": elapsed}

    del model
    torch.cuda.empty_cache()
    return results


class _PrefillEmbed(torch.nn.Module):
    """Stand-in for ``embed_tokens`` that ignores its input and returns fixed embeddings.

    Used during prefill so the decoder consumes the text+vision embeddings
    that were merged beforehand, instead of re-embedding token IDs.
    """

    def __init__(self, embeds):
        super().__init__()
        self.e = embeds

    def forward(self, x):
        return self.e


def _new_kv_caches(dtype, device):
    """Allocate zeroed static caches ordered [k_0, v_0, k_1, v_1, ...] for all layers."""
    return [
        torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
        for _ in range(2 * NUM_LAYERS)
    ]


def _causal_prefill_mask(seq_len, dtype, device):
    """Additive causal mask of shape (1, 1, seq_len, MAX_SEQ_LEN).

    Row i is 0 on cache columns 0..i and -inf elsewhere, so query i attends only
    to itself and earlier positions, never to unused cache slots. Built in one
    vectorized op instead of a per-row Python loop.
    """
    rows = torch.arange(seq_len, device=device).unsqueeze(1)
    cols = torch.arange(MAX_SEQ_LEN, device=device).unsqueeze(0)
    mask = torch.zeros(seq_len, MAX_SEQ_LEN, dtype=dtype, device=device)
    mask = mask.masked_fill(cols > rows, float("-inf"))
    return mask[None, None]


def _decode_mask(cur_pos, dtype, device):
    """Additive (1, 1, 1, MAX_SEQ_LEN) mask letting one query attend to positions 0..cur_pos."""
    mask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
    mask[..., cur_pos + 1:] = float("-inf")
    return mask


def run_fixed_modules(images, processor):
    """Run the exported fixed PyTorch vision + decoder modules end to end.

    Steps per image:
      1. fixed-resolution pixel values;
      2. vision features;
      3. scatter of the features into the image-token slots of the prompt
         embeddings;
      4. one prefill pass;
      5. greedy decode of up to MAX_NEW_TOKENS tokens total with a static
         KV cache.

    A per-image failure is recorded as an "ERROR: ..." entry, and the other
    images still run.

    Returns:
        dict name -> {"text": str, "tokens": int, "time": float seconds}.
    """
    sys.path.insert(0, ".")
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module

    print("\n[Fixed Modules] Loading...")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens

    device = "cuda"
    dtype = torch.bfloat16
    vision = vision.to(device).to(dtype).eval()
    decoder = decoder.to(device).to(dtype).eval()
    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig
    torch.cuda.empty_cache()

    # Prompt and image-slot positions are identical for every image: build once.
    input_ids = build_fixed_input_ids(processor).to(device)
    ids_list = input_ids[0].tolist()
    img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]
    img_indices = torch.tensor(img_positions, device=device)
    seq_len = input_ids.shape[1]

    results = {}
    for name, img in images.items():
        print(f"\n [{name}] Fixed modules E2E...")
        try:
            with torch.no_grad():
                # Step 1: pixel values at fixed resolution
                pixel_values = preprocess_image_fixed(img, processor).to(device).to(dtype)

                # Step 2: vision encoder
                image_features = vision(pixel_values)
                print(f" Vision output: {image_features.shape}")

                # Step 3: merge vision features into the prompt embeddings
                if len(img_positions) != image_features.shape[1]:
                    raise ValueError(
                        f"Token/feature mismatch: {len(img_positions)} slots vs "
                        f"{image_features.shape[1]} features"
                    )
                combined = embed_tokens(input_ids).clone()
                combined[0, img_indices] = image_features[0]
                print(f" Combined seq: {seq_len}, scattering {len(img_positions)} features")

                # Step 4: prefill with the pre-merged embeddings
                kv_caches = _new_kv_caches(dtype, device)
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
                cache_position = torch.arange(seq_len, device=device)
                mask = _causal_prefill_mask(seq_len, dtype, device)

                t0 = time.time()
                orig_embed = decoder.embed_tokens
                decoder.embed_tokens = _PrefillEmbed(combined)
                try:
                    result = decoder(input_ids, mask, position_ids, cache_position, *kv_caches)
                finally:
                    # Always restore: a leaked patch would feed this image's
                    # embeddings to every later decode step / image.
                    decoder.embed_tokens = orig_embed

                kv_caches = list(result[1:])
                next_token = result[0][0, -1].argmax().item()
                generated = [next_token]
                cur_pos = seq_len

                # Step 5: greedy autoregressive decode (first token came from prefill)
                for _ in range(MAX_NEW_TOKENS - 1):
                    if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                        break
                    token_input = torch.tensor([[next_token]], device=device)
                    pos_ids = torch.tensor([[cur_pos]], device=device)
                    cache_pos = torch.tensor([cur_pos], device=device)
                    dmask = _decode_mask(cur_pos, dtype, device)
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
                    kv_caches = list(result[1:])
                    next_token = result[0][0, -1].argmax().item()
                    generated.append(next_token)
                    cur_pos += 1
                elapsed = time.time() - t0

            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n_tok = len(generated)
            print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
            print(f" Output: {text[:200]}...")
            results[name] = {"text": text, "tokens": n_tok, "time": elapsed}
        except Exception as e:
            import traceback
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}

    return results


def levenshtein(s1, s2):
    """Return the unit-cost Levenshtein edit distance between sequences ``s1`` and ``s2``.

    Uses a single rolling row sized to the shorter input.
    """
    if len(s1) < len(s2):
        return levenshtein(s2, s1)
    if len(s2) == 0:
        return len(s1)
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        curr = [i + 1]
        for j, c2 in enumerate(s2):
            # deletion, insertion, substitution (free when characters match)
            curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (c1 != c2)))
        prev = curr
    return prev[-1]


def compare(hf_results, fx_results, images):
    """Print per-image agreement between the HF and fixed-module outputs.

    Metrics:
      - exact: stripped-text equality;
      - char acc: 1 - edit_distance / length of the longer text;
      - word acc: Jaccard overlap of the sets of unique, lower-cased,
        whitespace-split words.
    """
    print("\n" + "=" * 70)
    print("E2E COMPARISON (Fixed 1120x1540 resolution)")
    print("=" * 70)
    for name in images:
        hf = hf_results[name]
        fx = fx_results[name]
        hf_t, fx_t = hf["text"], fx["text"]

        exact = hf_t.strip() == fx_t.strip()
        ed = levenshtein(hf_t, fx_t)
        max_len = max(len(hf_t), len(fx_t), 1)
        char_acc = 1.0 - ed / max_len
        ref_words = set(hf_t.lower().split())
        hyp_words = set(fx_t.lower().split())
        union = ref_words | hyp_words
        word_acc = len(ref_words & hyp_words) / len(union) if union else 1.0

        print(f"\n{'─'*70}")
        print(f" [{name}]")
        print(f" HF ({hf['tokens']} tok): {hf_t[:250]}")
        print(f" FIX ({fx['tokens']} tok): {fx_t[:250]}")
        print(f" Exact: {'✅ YES' if exact else '❌ NO'}")
        print(f" Edit dist: {ed}, Char acc: {char_acc:.4f}, Word acc: {word_acc:.4f}")
    print("\n" + "=" * 70)


def main():
    """Run the HF reference and the fixed modules on the test images, then compare."""
    print("LightOnOCR-2-1B E2E Validation v2")
    print(f"Fixed resolution: {FIXED_H}x{FIXED_W} → {NUM_IMG_TOKENS} vision features")
    print(f"Device: cuda, Max seq: {MAX_SEQ_LEN}")
    print("=" * 70)

    images = download_test_images()

    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)

    hf_results = run_hf_model(images, processor)
    # Free GPU for fixed modules
    torch.cuda.empty_cache()
    fx_results = run_fixed_modules(images, processor)
    compare(hf_results, fx_results, images)


if __name__ == "__main__":
    main()