| |
| """ |
| E2E Validation for INT8 weight-only quantized models. |
| Compares: HF original vs INT8 quantized fixed modules. |
| """ |
|
|
| import os, sys, time, torch, torch.nn.functional as F |
| from PIL import Image |
| sys.path.insert(0, ".") |
|
|
# Local checkpoint directory for the model under test.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed preprocessing resolution (height, width) used for every image.
FIXED_H, FIXED_W = 1120, 1540
# Token id treated as the <image> placeholder when splicing vision features.
IMAGE_TOKEN_ID = 151655
# Token id that terminates greedy decoding.
EOS_TOKEN_ID = 151645
# Decoder geometry used to pre-allocate the static KV cache:
NUM_LAYERS = 28      # transformer layers (one k/v pair each)
NUM_KV_HEADS = 8     # key/value heads per layer
HEAD_DIM = 128       # per-head dimension
MAX_SEQ_LEN = 4096   # static cache length; also the hard decode limit
|
|
|
|
def get_test_images():
    """Collect named test images: a real receipt if present on disk, plus a synthetic invoice."""
    from PIL import ImageDraw

    result = {}
    receipt_path = "test_images/receipt.png"
    if os.path.exists(receipt_path):
        result["receipt"] = Image.open(receipt_path).convert("RGB")

    # Render a simple white-background invoice so the test never depends on fixtures.
    canvas = Image.new("RGB", (800, 600), "white")
    painter = ImageDraw.Draw(canvas)
    rows = [
        (50, "Invoice #12345"),
        (100, "Date: 2024-01-15"),
        (150, "Item 1: Widget x5 @ $10.00 = $50.00"),
        (200, "Item 2: Gadget x2 @ $24.99 = $49.98"),
        (250, "Total: $99.98"),
    ]
    for y, line in rows:
        painter.text((50, y), line, fill="black")
    result["synthetic"] = canvas
    return result
|
|
|
|
def preprocess_image_fixed(img, processor):
    """Resize *img* to the fixed (FIXED_W, FIXED_H) resolution and return its pixel_values tensor."""
    resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    # A minimal image-only chat turn; we only need the processor's image pipeline here.
    placeholder_chat = [{"role": "user", "content": [{"type": "image"}]}]
    prompt = processor.apply_chat_template(placeholder_chat, add_generation_prompt=True, tokenize=False)
    batch = processor(text=prompt, images=[resized], return_tensors="pt")
    return batch["pixel_values"]
|
|
|
|
def build_fixed_input_ids(processor):
    """Tokenize the fixed OCR prompt (image + instruction) and return its input_ids."""
    # A blank image at the fixed resolution stands in for the real one; only the
    # token layout matters here, not the pixels.
    blank = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    chat = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "OCR this document. Extract all text."},
    ]}]
    prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    batch = processor(text=prompt, images=[blank], return_tensors="pt")
    return batch["input_ids"]
|
|
|
|
def run_hf_model(images, processor):
    """Generate OCR text with the original HF model for every test image.

    Returns a dict {image_name: {"text", "tokens", "time"}}.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model]")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu")
    # The checkpoint uses different submodule names than the HF class; remap
    # keys before loading (strict=False tolerates any leftover mismatches).
    raw_weights = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    renamed = {}
    for key, tensor in raw_weights.items():
        new_key = key.replace("model.vision_encoder.", "model.vision_tower.")
        new_key = new_key.replace("model.vision_projection.", "model.multi_modal_projector.")
        renamed[new_key] = tensor
    model.load_state_dict(renamed, strict=False)
    model = model.to("cuda").eval()

    results = {}
    for name, img in images.items():
        print(f" [{name}] HF generate...")
        pixel_values = preprocess_image_fixed(img, processor).to("cuda")
        ids = build_fixed_input_ids(processor).to("cuda")
        prompt_len = ids.shape[1]
        start = time.time()
        with torch.no_grad():
            out = model.generate(
                input_ids=ids, pixel_values=pixel_values,
                attention_mask=torch.ones_like(ids),
                image_sizes=torch.tensor([[FIXED_H, FIXED_W]], device="cuda"),
                max_new_tokens=512, do_sample=False, temperature=None, top_p=None)
        elapsed = time.time() - start
        text = processor.tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True)
        n = len(out[0]) - prompt_len
        print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
        print(f" {text[:150]}...")
        results[name] = {"text": text, "tokens": n, "time": elapsed}
    del model
    torch.cuda.empty_cache()
    return results
|
|
|
|
def run_int8_modules(images, processor):
    """Run INT8 weight-only quantized fixed modules E2E.

    Quantizes the exported vision/decoder modules with torchao int8
    weight-only quantization, then runs a manual prefill + greedy decode
    loop per image against a pre-allocated static KV cache.

    Returns a dict {image_name: {"text", "tokens", "time"}}.
    """
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only

    print("\n[INT8 Quantized Modules]")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens

    device = "cuda"
    dtype = torch.bfloat16

    # Quantize on CPU in fp32, then move the quantized module to GPU bf16.
    print(" Applying int8_weight_only to vision...")
    vision = vision.to("cpu").to(torch.float32)
    quantize_(vision, int8_weight_only())
    vision = vision.to(device).to(dtype).eval()

    print(" Applying int8_weight_only to decoder...")
    decoder = decoder.to("cpu").to(torch.float32)
    quantize_(decoder, int8_weight_only())
    decoder = decoder.to(device).to(dtype).eval()

    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig
    torch.cuda.empty_cache()

    results = {}
    for name, img in images.items():
        print(f" [{name}] INT8 E2E...")
        try:
            pv = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)

            with torch.no_grad():
                image_features = vision(pv)
            print(f" Vision: {image_features.shape}")

            with torch.no_grad():
                text_embeds = embed_tokens(input_ids)

            # Splice the vision features into the <image> placeholder positions.
            ids_list = input_ids[0].tolist()
            img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]
            combined = text_embeds.clone()
            indices = torch.tensor(img_positions, device=device)
            combined[0, indices] = image_features[0]

            seq_len = combined.shape[1]

            # Static KV cache: one zeroed (k, v) pair per decoder layer.
            kv_caches = []
            for _ in range(NUM_LAYERS):
                k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                kv_caches.extend([k, v])

            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            cache_position = torch.arange(seq_len, device=device)
            # Causal prefill mask, built vectorized instead of a Python loop:
            # row i may attend to cache columns j <= i; everything else is -inf.
            mask = torch.full((1, 1, seq_len, MAX_SEQ_LEN), float("-inf"), dtype=dtype, device=device)
            row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
            col_idx = torch.arange(MAX_SEQ_LEN, device=device).unsqueeze(0)
            mask[0, 0].masked_fill_(col_idx <= row_idx, 0.0)

            # The exported decoder embeds its input_ids internally; temporarily
            # swap in a stub that returns the precomputed multimodal embeddings.
            orig_embed = decoder.embed_tokens

            class PrefillEmbed(torch.nn.Module):
                def __init__(self, e):
                    super().__init__()
                    self.e = e

                def forward(self, x):
                    return self.e

            decoder.embed_tokens = PrefillEmbed(combined)
            t0 = time.time()
            try:
                with torch.no_grad():
                    result = decoder(input_ids[:, :seq_len], mask, position_ids, cache_position, *kv_caches)
            finally:
                # BUGFIX: restore the real embedding even if prefill raises;
                # otherwise subsequent images (the except branch swallows the
                # error) would silently decode with the stale PrefillEmbed.
                decoder.embed_tokens = orig_embed

            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated = [next_token]
            cur_pos = seq_len

            # Greedy decode, one token per step, up to 512 tokens total.
            for step in range(511):
                if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                    break
                token_input = torch.tensor([[next_token]], device=device)
                pos_ids = torch.tensor([[cur_pos]], device=device)
                cache_pos = torch.tensor([cur_pos], device=device)
                # Single-row mask: attend to all cache entries up to cur_pos.
                dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
                dmask[0, 0, 0, cur_pos+1:] = float("-inf")
                with torch.no_grad():
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
                logits = result[0]
                kv_caches = list(result[1:])
                next_token = logits[0, -1].argmax().item()
                generated.append(next_token)
                cur_pos += 1

            elapsed = time.time() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n = len(generated)
            print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
            print(f" {text[:150]}...")
            results[name] = {"text": text, "tokens": n, "time": elapsed}

        except Exception as e:
            # Best-effort harness: record the failure and continue with the
            # next image rather than aborting the whole comparison.
            import traceback
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}

    return results
|
|
|
|
def levenshtein(s1, s2):
    """Edit distance between *s1* and *s2* using single-row dynamic programming."""
    # Keep the longer string in s1 so the DP row tracks the shorter one.
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    row = list(range(len(s2) + 1))
    for i, ch1 in enumerate(s1, start=1):
        nxt = [i]
        for j, ch2 in enumerate(s2, start=1):
            substitute = row[j - 1] + (ch1 != ch2)
            delete = row[j] + 1
            insert = nxt[j - 1] + 1
            nxt.append(min(substitute, delete, insert))
        row = nxt
    return row[-1]
|
|
|
|
def main():
    """Driver: run the HF baseline and the INT8 pipeline, then print per-image accuracy metrics."""
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)

    banner = "=" * 60
    print(banner)
    print("LightOnOCR E2E: HF vs INT8 Quantized")
    print(banner)

    images = get_test_images()
    hf = run_hf_model(images, processor)
    torch.cuda.empty_cache()
    q8 = run_int8_modules(images, processor)

    print("\n" + banner)
    print("COMPARISON: HF (FP32) vs INT8 Weight-Only")
    print(banner)

    for name in images:
        hf_t = hf[name]["text"]
        q8_t = q8[name]["text"]
        exact = hf_t.strip() == q8_t.strip()
        # Character accuracy from normalized edit distance.
        ed = levenshtein(hf_t, q8_t)
        longest = max(len(hf_t), len(q8_t), 1)
        char_acc = 1.0 - ed / longest
        # Word accuracy as Jaccard overlap of lowercased word sets.
        ref_w = set(hf_t.lower().split())
        hyp_w = set(q8_t.lower().split())
        union = ref_w | hyp_w
        word_acc = len(ref_w & hyp_w) / len(union) if union else 1.0

        print(f"\n{'─'*60}")
        print(f" [{name}]")
        print(f" HF ({hf[name]['tokens']} tok): {hf_t[:200]}")
        print(f" INT8 ({q8[name]['tokens']} tok): {q8_t[:200]}")
        print(f" Exact: {'✅' if exact else '❌'} | Edit dist: {ed} | Char acc: {char_acc:.4f} | Word acc: {word_acc:.4f}")


if __name__ == "__main__":
    main()
|
|