| |
| """ |
| E2E Validation v2: Fixed resolution pipeline. |
| Forces all images to 1120x1540 so token count matches fixed vision encoder output. |
| Compares HF model output vs fixed PyTorch modules output. |
| """ |
|
|
| import os, sys, time, torch, torch.nn.functional as F |
| from pathlib import Path |
| from PIL import Image |
|
|
| |
# Local path to the model checkpoint directory.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input resolution forced on every image so the vision encoder always
# emits the same number of tokens.
FIXED_H, FIXED_W = 1120, 1540
PATCH_SIZE = 14       # vision transformer patch size (pixels per patch side)
SPATIAL_MERGE = 2     # post-encoder spatial merge factor (2x2 patches -> 1 token)
MERGED_H = FIXED_H // PATCH_SIZE // SPATIAL_MERGE  # merged token grid height
MERGED_W = FIXED_W // PATCH_SIZE // SPATIAL_MERGE  # merged token grid width
NUM_IMG_TOKENS = MERGED_H * MERGED_W               # total image tokens per picture
NUM_VPAD_TOKENS = MERGED_H - 1                     # vision-pad tokens (one per row break)


# Special token ids used by the tokenizer / chat template.
# NOTE(review): values mirror the Qwen2-VL special-token range — confirm
# against the model's tokenizer config.
IMAGE_TOKEN_ID = 151655
VISION_PAD_ID = 151654
VISION_END_ID = 151653
IM_START_ID = 151644
IM_END_ID = 151645
EOS_TOKEN_ID = 151645  # same id as IM_END: generation stops at end-of-turn


# Text decoder architecture constants (must match the exported decoder module).
NUM_LAYERS = 28
NUM_KV_HEADS = 8
NUM_HEADS = 16
HEAD_DIM = 128
HIDDEN_SIZE = 1024
MAX_SEQ_LEN = 4096  # static KV-cache capacity
|
|
|
|
def download_test_images():
    """Collect test images; return a dict of name -> PIL.Image (RGB).

    Remote sources are best-effort: a failed download OR an unreadable
    (truncated/corrupt) cached file is skipped — and a bad cache file is
    removed so the next run re-downloads it — instead of crashing the
    whole validation. A locally rendered "synthetic" invoice is always
    included, so the returned dict is never empty.
    """
    import requests
    os.makedirs("test_images", exist_ok=True)
    images = {}

    sources = {
        "receipt": "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/SROIE-receipt.jpeg",
    }
    for name, url in sources.items():
        path = f"test_images/{name}.png"
        if not os.path.exists(path):
            print(f" Downloading {name}...")
            try:
                resp = requests.get(url, timeout=30)
                resp.raise_for_status()
                with open(path, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                print(f" FAILED: {e}")
                continue
        # The file may exist but be truncated/corrupt (e.g. interrupted
        # earlier run); skip it rather than aborting the whole script.
        try:
            img = Image.open(path).convert("RGB")
        except Exception as e:
            print(f" FAILED to decode {name}: {e}")
            try:
                os.remove(path)  # purge bad cache so a later run retries
            except OSError:
                pass
            continue
        images[name] = img
        print(f" {name}: {img.size}")

    # Synthetic invoice rendered locally — guarantees at least one image.
    img = Image.new("RGB", (800, 600), "white")
    from PIL import ImageDraw
    draw = ImageDraw.Draw(img)
    draw.text((50, 50), "Invoice #12345", fill="black")
    draw.text((50, 100), "Date: 2024-01-15", fill="black")
    draw.text((50, 150), "Item 1: Widget x5 @ $10.00 = $50.00", fill="black")
    draw.text((50, 200), "Item 2: Gadget x2 @ $24.99 = $49.98", fill="black")
    draw.text((50, 250), "Total: $99.98", fill="black")
    img.save("test_images/synthetic.png")
    images["synthetic"] = img
    print(f" synthetic: {img.size}")

    return images
|
|
|
|
def build_fixed_input_ids(processor, text_prompt="OCR this document. Extract all text."):
    """Return input_ids holding exactly NUM_IMG_TOKENS image-placeholder tokens.

    Runs the processor's chat template over a blank image of the fixed
    1120x1540 size, so the token layout matches the fixed vision encoder
    output exactly. Asserts the image-token count before returning.
    """
    placeholder = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    chat = [{"role": "user", "content": [
        {"type": "image"}, {"type": "text", "text": text_prompt}
    ]}]
    prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    encoded = processor(text=prompt, images=[placeholder], return_tensors="pt")

    input_ids = encoded["input_ids"]
    token_list = input_ids[0].tolist()
    img_count = token_list.count(IMAGE_TOKEN_ID)
    pad_count = token_list.count(VISION_PAD_ID)
    assert img_count == NUM_IMG_TOKENS, f"Expected {NUM_IMG_TOKENS} IMG tokens, got {img_count}"
    print(f" Input template: {input_ids.shape[1]} tokens ({img_count} IMG, {pad_count} VPAD)")
    return input_ids
|
|
|
|
def preprocess_image_fixed(img, processor):
    """Resize *img* to exactly FIXED_H x FIXED_W and return its pixel_values."""
    resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    # A minimal single-image chat prompt; only pixel_values are consumed.
    chat = [{"role": "user", "content": [{"type": "image"}]}]
    prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    encoded = processor(text=prompt, images=[resized], return_tensors="pt")
    return encoded["pixel_values"]
|
|
|
|
def run_hf_model(images, processor):
    """Run the original HF model with the FIXED-resolution preprocessing.

    Returns a dict name -> {"text", "tokens", "time"} per input image.
    Weights are reloaded from model.safetensors with vision keys remapped
    to the HF module names; the model is freed before returning.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model] Loading...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu",
    )
    # Remap checkpoint keys (vision_encoder / vision_projection) onto the
    # HF naming scheme (vision_tower / multi_modal_projector).
    raw_weights = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    renamed = {}
    for key, tensor in raw_weights.items():
        new_key = key.replace("model.vision_encoder.", "model.vision_tower.")
        new_key = new_key.replace("model.vision_projection.", "model.multi_modal_projector.")
        renamed[new_key] = tensor
    model.load_state_dict(renamed, strict=False)
    model = model.to("cuda").eval()

    results = {}
    for name, img in images.items():
        print(f"\n [{name}] HF generate (fixed {FIXED_H}x{FIXED_W})...")
        pixel_values = preprocess_image_fixed(img, processor).to("cuda")
        input_ids = build_fixed_input_ids(processor).to("cuda")
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]

        start = time.time()
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                image_sizes=torch.tensor([[FIXED_H, FIXED_W]], device="cuda"),
                max_new_tokens=512, do_sample=False, temperature=None, top_p=None,
            )
        elapsed = time.time() - start

        new_ids = output_ids[0, prompt_len:]  # keep only generated tokens
        text = processor.tokenizer.decode(new_ids, skip_special_tokens=True)
        n_tok = len(new_ids)
        print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
        print(f" Output: {text[:200]}...")
        results[name] = {"text": text, "tokens": n_tok, "time": elapsed}

    # Release the HF model before the fixed-module run.
    del model
    torch.cuda.empty_cache()
    return results
|
|
|
|
def run_fixed_modules(images, processor):
    """Run the fixed PyTorch modules E2E with proper token matching.

    For each image: encode with the fixed vision module, scatter the vision
    features into the text embedding at IMAGE_TOKEN positions, prefill the
    decoder via a temporary embedding swap, then greedy-decode token by
    token against a preallocated static KV cache.

    Returns a dict name -> {"text", "tokens", "time"}; failures are caught
    per-image and reported as {"text": "ERROR: ...", ...}.
    """
    sys.path.insert(0, ".")
    from export_vision import build_vision_module, load_original_model
    # TextDecoderFixed is imported for parity with the export scripts;
    # presumably needed for module resolution — TODO confirm.
    from export_decoder import TextDecoderFixed, build_decoder_module

    print("\n[Fixed Modules] Loading...")
    orig = load_original_model()

    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens

    device = "cuda"
    dtype = torch.bfloat16
    vision = vision.to(device).to(dtype).eval()
    decoder = decoder.to(device).to(dtype).eval()
    embed_tokens = embed_tokens.to(device).to(dtype)

    # Free the original wrapper model; we keep only the pieces above.
    del orig
    torch.cuda.empty_cache()

    class PrefillEmbed(torch.nn.Module):
        """Stand-in embedding that returns precomputed text+vision embeds."""
        def __init__(self, embeds):
            super().__init__()
            self.e = embeds
        def forward(self, x):
            return self.e

    results = {}
    for name, img in images.items():
        print(f"\n [{name}] Fixed modules E2E...")
        try:
            pixel_values = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)

            with torch.no_grad():
                image_features = vision(pixel_values)
            print(f" Vision output: {image_features.shape}")

            with torch.no_grad():
                text_embeds = embed_tokens(input_ids)

            # Locate every IMAGE_TOKEN slot; must match vision feature count.
            ids_list = input_ids[0].tolist()
            img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]
            assert len(img_positions) == image_features.shape[1], \
                f"Token/feature mismatch: {len(img_positions)} slots vs {image_features.shape[1]} features"

            # Scatter vision features into the text embedding sequence.
            combined = text_embeds.clone()
            indices = torch.tensor(img_positions, device=device)
            combined[0, indices] = image_features[0]

            seq_len = combined.shape[1]
            print(f" Combined seq: {seq_len}, scattering {len(img_positions)} features")

            # Preallocated static KV cache: one (k, v) pair per layer.
            kv_caches = []
            for _ in range(NUM_LAYERS):
                k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                kv_caches.extend([k, v])

            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            cache_position = torch.arange(seq_len, device=device)

            # Causal prefill mask, built vectorized (previously an
            # O(seq_len^2) Python loop): row i attends to cache slots 0..i,
            # everything else is -inf.
            rows = torch.arange(seq_len, device=device).unsqueeze(1)
            cols = torch.arange(MAX_SEQ_LEN, device=device).unsqueeze(0)
            mask = torch.where(cols <= rows, 0.0, float("-inf")).to(dtype)
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, MAX_SEQ_LEN)

            # Temporarily swap in the fused embeds for the prefill pass.
            # Restore in `finally`: previously a prefill failure left the
            # decoder permanently patched, corrupting later iterations.
            orig_embed = decoder.embed_tokens
            decoder.embed_tokens = PrefillEmbed(combined)
            try:
                t0 = time.time()
                with torch.no_grad():
                    result = decoder(input_ids[:, :seq_len], mask, position_ids, cache_position, *kv_caches)
            finally:
                decoder.embed_tokens = orig_embed

            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()

            generated = [next_token]
            cur_pos = seq_len

            # Greedy decode loop: up to 512 total new tokens (1 from prefill).
            for _ in range(511):
                if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                    break

                token_input = torch.tensor([[next_token]], device=device)
                pos_ids = torch.tensor([[cur_pos]], device=device)
                cache_pos = torch.tensor([cur_pos], device=device)

                # Decode-step mask: attend to cache slots 0..cur_pos only.
                dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
                dmask[0, 0, 0, cur_pos+1:] = float("-inf")

                with torch.no_grad():
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)

                logits = result[0]
                kv_caches = list(result[1:])
                next_token = logits[0, -1].argmax().item()
                generated.append(next_token)
                cur_pos += 1

            elapsed = time.time() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n_tok = len(generated)
            print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
            print(f" Output: {text[:200]}...")
            results[name] = {"text": text, "tokens": n_tok, "time": elapsed}

        except Exception as e:
            import traceback
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}

    return results
|
|
|
|
def levenshtein(s1, s2):
    """Return the Levenshtein (edit) distance between two strings."""
    # Keep the shorter string as the column dimension to minimize memory.
    if len(s2) > len(s1):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    previous_row = list(range(len(s2) + 1))
    for row_idx, ch_a in enumerate(s1, start=1):
        current_row = [row_idx]
        for col_idx, ch_b in enumerate(s2, start=1):
            delete_cost = previous_row[col_idx] + 1
            insert_cost = current_row[col_idx - 1] + 1
            replace_cost = previous_row[col_idx - 1] + (ch_a != ch_b)
            current_row.append(min(delete_cost, insert_cost, replace_cost))
        previous_row = current_row
    return previous_row[-1]
|
|
|
|
def compare(hf_results, fx_results, images):
    """Print per-image exact-match, character- and word-accuracy of HF vs fixed."""
    print("\n" + "="*70)
    print("E2E COMPARISON (Fixed 1120x1540 resolution)")
    print("="*70)

    for name in images:
        hf_entry = hf_results[name]
        fx_entry = fx_results[name]
        ref_text = hf_entry["text"]
        hyp_text = fx_entry["text"]

        exact = ref_text.strip() == hyp_text.strip()
        dist = levenshtein(ref_text, hyp_text)
        denom = max(len(ref_text), len(hyp_text), 1)
        char_acc = 1.0 - dist / denom

        # Word accuracy = Jaccard overlap of lowercase word sets.
        ref_words = set(ref_text.lower().split())
        hyp_words = set(hyp_text.lower().split())
        union = ref_words | hyp_words
        overlap = ref_words & hyp_words
        word_acc = len(overlap) / len(union) if union else 1.0

        print(f"\n{'─'*70}")
        print(f" [{name}]")
        print(f" HF ({hf_entry['tokens']} tok): {ref_text[:250]}")
        print(f" FIX ({fx_entry['tokens']} tok): {hyp_text[:250]}")
        print(f" Exact: {'✅ YES' if exact else '❌ NO'}")
        print(f" Edit dist: {dist}, Char acc: {char_acc:.4f}, Word acc: {word_acc:.4f}")

    print("\n" + "="*70)
|
|
|
|
def main():
    """Run the full E2E validation: download images, run both pipelines, compare."""
    print("LightOnOCR-2-1B E2E Validation v2")
    print(f"Fixed resolution: {FIXED_H}x{FIXED_W} → {NUM_IMG_TOKENS} vision features")
    print(f"Device: cuda, Max seq: {MAX_SEQ_LEN}")
    print("="*70)

    images = download_test_images()

    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)

    baseline = run_hf_model(images, processor)
    torch.cuda.empty_cache()  # free HF model memory before the fixed-module run
    candidate = run_fixed_modules(images, processor)

    compare(baseline, candidate, images)
|
|
|
|
# Script entry point: run the full validation when executed directly.
if __name__ == "__main__":
    main()
|
|