#!/usr/bin/env python3
"""
E2E Validation for INT8 weight-only quantized models.

Compares: HF original vs INT8 quantized fixed modules.
"""
import os
import sys
import time

import torch

sys.path.insert(0, ".")

MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input resolution the exported modules were traced with.
FIXED_H, FIXED_W = 1120, 1540
IMAGE_TOKEN_ID = 151655
EOS_TOKEN_ID = 151645
# Decoder KV-cache geometry (must match the exported decoder module).
NUM_LAYERS = 28
NUM_KV_HEADS = 8
HEAD_DIM = 128
MAX_SEQ_LEN = 4096


def get_test_images():
    """Return a dict of name -> PIL.Image test documents.

    Always includes a drawn synthetic invoice; also includes 'receipt'
    when test_images/receipt.png exists on disk.
    """
    # Lazy import: keeps this module importable without PIL installed,
    # consistent with the function-local imports used elsewhere in the file.
    from PIL import Image, ImageDraw

    images = {}
    if os.path.exists("test_images/receipt.png"):
        images["receipt"] = Image.open("test_images/receipt.png").convert("RGB")
    img = Image.new("RGB", (800, 600), "white")
    draw = ImageDraw.Draw(img)
    draw.text((50, 50), "Invoice #12345", fill="black")
    draw.text((50, 100), "Date: 2024-01-15", fill="black")
    draw.text((50, 150), "Item 1: Widget x5 @ $10.00 = $50.00", fill="black")
    draw.text((50, 200), "Item 2: Gadget x2 @ $24.99 = $49.98", fill="black")
    draw.text((50, 250), "Total: $99.98", fill="black")
    images["synthetic"] = img
    return images


def preprocess_image_fixed(img, processor):
    """Resize `img` to the fixed (FIXED_W, FIXED_H) resolution and return
    the processor's `pixel_values` tensor for it."""
    from PIL import Image

    img_resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    # A minimal image-only message so apply_chat_template emits the right
    # placeholder structure for a single image.
    dummy_msg = [{"role": "user", "content": [{"type": "image"}]}]
    text = processor.apply_chat_template(dummy_msg, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text, images=[img_resized], return_tensors="pt")
    return inputs["pixel_values"]


def build_fixed_input_ids(processor):
    """Build the fixed OCR prompt `input_ids` (image placeholder tokens included).

    A blank dummy image at the fixed resolution is passed only so the
    processor expands the image placeholder to the correct token count.
    """
    from PIL import Image

    dummy_img = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    messages = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "OCR this document. Extract all text."}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text, images=[dummy_img], return_tensors="pt")
    return inputs["input_ids"]


def run_hf_model(images, processor):
    """Run the original HF model (bf16) on every test image.

    Returns {name: {"text", "tokens", "time"}} with greedy-decoded output.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model]")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu")
    # The checkpoint names vision modules vision_encoder/vision_projection,
    # while the HF class expects vision_tower/multi_modal_projector — remap
    # keys before loading (strict=False tolerates any leftover mismatches).
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {k.replace("model.vision_encoder.", "model.vision_tower.")
                 .replace("model.vision_projection.", "model.multi_modal_projector."): v
                for k, v in state_dict.items()}
    model.load_state_dict(remapped, strict=False)
    model = model.to("cuda").eval()

    results = {}
    for name, img in images.items():
        print(f"  [{name}] HF generate...")
        pv = preprocess_image_fixed(img, processor).to("cuda")
        input_ids = build_fixed_input_ids(processor).to("cuda")
        input_len = input_ids.shape[1]
        t0 = time.time()
        with torch.no_grad():
            # temperature/top_p set to None to silence warnings under greedy decode.
            out = model.generate(
                input_ids=input_ids,
                pixel_values=pv,
                attention_mask=torch.ones_like(input_ids),
                image_sizes=torch.tensor([[FIXED_H, FIXED_W]], device="cuda"),
                max_new_tokens=512,
                do_sample=False,
                temperature=None,
                top_p=None)
        elapsed = time.time() - t0
        text = processor.tokenizer.decode(out[0, input_len:], skip_special_tokens=True)
        n = len(out[0]) - input_len
        print(f"    {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
        print(f"    {text[:150]}...")
        results[name] = {"text": text, "tokens": n, "time": elapsed}
    del model
    torch.cuda.empty_cache()
    return results


def run_int8_modules(images, processor):
    """Run INT8 weight-only quantized fixed modules E2E.

    Vision encoder and decoder are quantized with torchao int8_weight_only
    (mirroring what was exported to .pte), then driven manually:
    prefill with spliced text+vision embeddings, followed by greedy decode
    against pre-allocated static KV caches.
    """
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only

    print("\n[INT8 Quantized Modules]")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens
    device = "cuda"
    dtype = torch.bfloat16

    # Apply INT8 weight-only quantization (same as what we exported to .pte).
    # quantize_ mutates in place and wants fp32 weights on CPU first.
    print("  Applying int8_weight_only to vision...")
    vision = vision.to("cpu").to(torch.float32)
    quantize_(vision, int8_weight_only())
    vision = vision.to(device).to(dtype).eval()
    print("  Applying int8_weight_only to decoder...")
    decoder = decoder.to("cpu").to(torch.float32)
    quantize_(decoder, int8_weight_only())
    decoder = decoder.to(device).to(dtype).eval()
    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig
    torch.cuda.empty_cache()

    class PrefillEmbed(torch.nn.Module):
        """Embedding stand-in: ignores its token-id input and returns a
        precomputed embedding tensor (so prefill can consume spliced
        text+vision embeddings through the decoder's normal entry point)."""

        def __init__(self, e):
            super().__init__()
            self.e = e

        def forward(self, x):
            return self.e

    results = {}
    for name, img in images.items():
        print(f"  [{name}] INT8 E2E...")
        try:
            pv = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)
            with torch.no_grad():
                image_features = vision(pv)
            print(f"    Vision: {image_features.shape}")
            with torch.no_grad():
                text_embeds = embed_tokens(input_ids)

            # Splice vision features into the text embeddings at every
            # image-placeholder token position.
            ids_list = input_ids[0].tolist()
            img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]
            combined = text_embeds.clone()
            indices = torch.tensor(img_positions, device=device)
            combined[0, indices] = image_features[0]
            seq_len = combined.shape[1]

            # Pre-allocate static KV caches: one (K, V) pair per layer.
            kv_caches = []
            for _ in range(NUM_LAYERS):
                k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM,
                                dtype=dtype, device=device)
                v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM,
                                dtype=dtype, device=device)
                kv_caches.extend([k, v])

            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            cache_position = torch.arange(seq_len, device=device)
            # Causal prefill mask over the full static cache width:
            # row i may attend to cache slots 0..i, everything else -inf.
            mask = torch.full((1, 1, seq_len, MAX_SEQ_LEN), float("-inf"),
                              dtype=dtype, device=device)
            for i in range(seq_len):
                mask[0, 0, i, :i + 1] = 0.0

            # Temporarily swap in the prefill embedding. BUGFIX: restore in
            # a finally block — previously an exception inside the decoder
            # left PrefillEmbed installed for the next image.
            orig_embed = decoder.embed_tokens
            decoder.embed_tokens = PrefillEmbed(combined)
            t0 = time.time()
            try:
                with torch.no_grad():
                    result = decoder(input_ids[:, :seq_len], mask, position_ids,
                                     cache_position, *kv_caches)
            finally:
                decoder.embed_tokens = orig_embed
            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated = [next_token]
            cur_pos = seq_len

            # Greedy decode loop: up to 512 tokens total (1 from prefill + 511).
            for step in range(511):
                if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                    break
                token_input = torch.tensor([[next_token]], device=device)
                pos_ids = torch.tensor([[cur_pos]], device=device)
                cache_pos = torch.tensor([cur_pos], device=device)
                # Single-row decode mask: attend to slots 0..cur_pos only.
                dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
                dmask[0, 0, 0, cur_pos + 1:] = float("-inf")
                with torch.no_grad():
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
                logits = result[0]
                kv_caches = list(result[1:])
                next_token = logits[0, -1].argmax().item()
                generated.append(next_token)
                cur_pos += 1

            elapsed = time.time() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n = len(generated)
            print(f"    {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
            print(f"    {text[:150]}...")
            results[name] = {"text": text, "tokens": n, "time": elapsed}
        except Exception as e:
            # Best-effort: record the failure and continue with the next image.
            import traceback
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}
    return results


def levenshtein(s1, s2):
    """Return the Levenshtein edit distance between two strings.

    Classic two-row DP, O(len(s1) * len(s2)) time, O(min) space.
    """
    if len(s1) < len(s2):
        s1, s2 = s2, s1  # keep s2 as the shorter string (smaller DP row)
    if len(s2) == 0:
        return len(s1)
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        curr = [i + 1]
        for j, c2 in enumerate(s2):
            # deletion, insertion, substitution (free when chars match)
            curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (c1 != c2)))
        prev = curr
    return prev[-1]


def main():
    """Run both pipelines on the test images and print a per-image comparison."""
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(MODEL_DIR)
    print("=" * 60)
    print("LightOnOCR E2E: HF vs INT8 Quantized")
    print("=" * 60)
    images = get_test_images()
    hf = run_hf_model(images, processor)
    torch.cuda.empty_cache()
    q8 = run_int8_modules(images, processor)

    print("\n" + "=" * 60)
    # BUGFIX: label was "FP32" but the HF reference runs in bfloat16.
    print("COMPARISON: HF (bf16) vs INT8 Weight-Only")
    print("=" * 60)
    for name in images:
        hf_t = hf[name]["text"]
        q8_t = q8[name]["text"]
        exact = hf_t.strip() == q8_t.strip()
        ed = levenshtein(hf_t, q8_t)
        max_len = max(len(hf_t), len(q8_t), 1)
        char_acc = 1.0 - ed / max_len
        # Word accuracy = Jaccard similarity of lowercased word sets.
        ref_w = set(hf_t.lower().split())
        hyp_w = set(q8_t.lower().split())
        word_acc = len(ref_w & hyp_w) / len(ref_w | hyp_w) if ref_w | hyp_w else 1.0
        print(f"\n{'─' * 60}")
        print(f" [{name}]")
        print(f"  HF   ({hf[name]['tokens']} tok): {hf_t[:200]}")
        print(f"  INT8 ({q8[name]['tokens']} tok): {q8_t[:200]}")
        print(f"  Exact: {'✅' if exact else '❌'} | Edit dist: {ed} | "
              f"Char acc: {char_acc:.4f} | Word acc: {word_acc:.4f}")


if __name__ == "__main__":
    main()