LightOnOCR-2-1B-ExecuTorch / scripts /test_e2e_int8.py
acul3's picture
Upload scripts/test_e2e_int8.py with huggingface_hub
0aebce7 verified
#!/usr/bin/env python3
"""
E2E Validation for INT8 weight-only quantized models.
Compares: HF original vs INT8 quantized fixed modules.
"""
import os, sys, time, torch, torch.nn.functional as F
from PIL import Image
sys.path.insert(0, ".")
# Local directory holding the LightOnOCR-2-1B weights and processor files.
MODEL_DIR: str = "./models/LightOnOCR-2-1B"

# Fixed (height, width) that every test image is resized to before preprocessing.
FIXED_H: int = 1120
FIXED_W: int = 1540

# Special token ids: embeddings at IMAGE_TOKEN_ID positions are replaced by
# vision features; EOS_TOKEN_ID ends the greedy decode loop.
IMAGE_TOKEN_ID: int = 151655
EOS_TOKEN_ID: int = 151645

# Static KV-cache geometry: one (1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM) key
# tensor and one value tensor per decoder layer.
# NOTE(review): presumably these must match the exported decoder's config — confirm.
NUM_LAYERS: int = 28
NUM_KV_HEADS: int = 8
HEAD_DIM: int = 128
MAX_SEQ_LEN: int = 4096
def get_test_images(receipt_path="test_images/receipt.png"):
    """Return a dict mapping a name to an RGB PIL test image.

    ``"receipt"`` is included (first) only when *receipt_path* exists.
    ``"synthetic"`` is always included: an 800x600 white canvas with five
    lines of invoice-like text drawn in black.

    Args:
        receipt_path: optional real-document image to include (default
            ``"test_images/receipt.png"``, relative to the working directory).
    """
    from PIL import ImageDraw

    images = {}
    if os.path.exists(receipt_path):
        # Image.open keeps the underlying file open; convert() makes an
        # in-memory copy, so the source handle can be closed right away.
        with Image.open(receipt_path) as src:
            images["receipt"] = src.convert("RGB")

    img = Image.new("RGB", (800, 600), "white")
    draw = ImageDraw.Draw(img)
    lines = [
        "Invoice #12345",
        "Date: 2024-01-15",
        "Item 1: Widget x5 @ $10.00 = $50.00",
        "Item 2: Gadget x2 @ $24.99 = $49.98",
        "Total: $99.98",
    ]
    # One line every 50 px, starting at y=50.
    for row, line in enumerate(lines):
        draw.text((50, 50 + 50 * row), line, fill="black")
    images["synthetic"] = img
    return images
def preprocess_image_fixed(img, processor):
    """Resize *img* to FIXED_W x FIXED_H (Lanczos) and return its ``pixel_values``.

    The prompt given to the processor is a minimal image-only user turn
    rendered through the chat template; only the pixel tensor is returned.
    """
    resized = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    prompt = processor.apply_chat_template(
        [{"role": "user", "content": [{"type": "image"}]}],
        add_generation_prompt=True,
        tokenize=False,
    )
    encoded = processor(text=prompt, images=[resized], return_tensors="pt")
    return encoded["pixel_values"]
def build_fixed_input_ids(processor):
    """Tokenize the fixed OCR prompt for one FIXED_W x FIXED_H image.

    A blank white image of the fixed size accompanies the prompt. Presumably
    the processor derives the number of image placeholder tokens from that
    size, so these ids line up with any image resized by
    ``preprocess_image_fixed`` — confirm against the processor.

    Returns:
        The ``input_ids`` tensor produced with ``return_tensors="pt"``.
    """
    placeholder = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "OCR this document. Extract all text."},
            ],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    return processor(text=prompt, images=[placeholder], return_tensors="pt")["input_ids"]
def run_hf_model(images, processor, device="cuda", max_new_tokens=512):
    """Greedy-decode every image with the reference Hugging Face model.

    The model is loaded in bfloat16 on CPU. The raw checkpoint is then
    re-loaded with keys renamed from ``model.vision_encoder.`` /
    ``model.vision_projection.`` to ``model.vision_tower.`` /
    ``model.multi_modal_projector.``, and the model is moved to *device*.

    Args:
        images: mapping of name -> PIL image.
        processor: the matching ``AutoProcessor``.
        device: torch device used for generation (default ``"cuda"``).
        max_new_tokens: generation budget per image (default 512).

    Returns:
        dict name -> {"text": str, "tokens": int, "time": float}. ``time`` is
        the wall-clock seconds spent inside ``model.generate``.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file

    print("\n[HF Model]")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu")
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {
        k.replace("model.vision_encoder.", "model.vision_tower.")
         .replace("model.vision_projection.", "model.multi_modal_projector."): v
        for k, v in state_dict.items()
    }
    # strict=False tolerates name mismatches silently; report them so a broken
    # key remap shows up instead of leaving weights at their initial values.
    incompatible = model.load_state_dict(remapped, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f" load_state_dict: {len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys")
    # Drop the extra CPU copy of the weights before moving the model.
    del state_dict, remapped
    model = model.to(device).eval()

    # The prompt does not depend on the image, so build it (and its mask) once.
    input_ids = build_fixed_input_ids(processor).to(device)
    input_len = input_ids.shape[1]
    attention_mask = torch.ones_like(input_ids)
    image_sizes = torch.tensor([[FIXED_H, FIXED_W]], device=device)

    results = {}
    for name, img in images.items():
        print(f" [{name}] HF generate...")
        pv = preprocess_image_fixed(img, processor).to(device)
        t0 = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                input_ids=input_ids, pixel_values=pv,
                attention_mask=attention_mask,
                image_sizes=image_sizes,
                max_new_tokens=max_new_tokens, do_sample=False,
                temperature=None, top_p=None)
        elapsed = time.perf_counter() - t0
        new_tokens = out[0, input_len:]
        text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
        n = new_tokens.shape[0]
        print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
        print(f" {text[:150]}...")
        results[name] = {"text": text, "tokens": n, "time": elapsed}
    del model; torch.cuda.empty_cache()
    return results
class _PrefillEmbed(torch.nn.Module):
    """Stand-in for ``decoder.embed_tokens`` that ignores its input and returns fixed embeddings.

    Used only for the prefill call, whose embeddings already contain the
    spliced-in image features.
    """

    def __init__(self, embeds):
        super().__init__()
        self.embeds = embeds

    def forward(self, input_ids):
        return self.embeds


def _causal_prefill_mask(seq_len, max_len, dtype, device):
    """Return a (1, 1, seq_len, max_len) additive mask.

    Entry [i, j] is 0 where j <= i and -inf elsewhere, so query i attends to
    cache slots 0..i.
    """
    allowed = torch.ones(seq_len, max_len, dtype=torch.bool, device=device).tril()
    mask = torch.zeros(1, 1, seq_len, max_len, dtype=dtype, device=device)
    return mask.masked_fill_(~allowed, float("-inf"))


def _int8_greedy_generate(vision, decoder, embed_tokens, pixel_values, input_ids,
                          max_new_tokens, dtype, device):
    """Greedy-decode one image with the fixed modules; return generated token ids.

    At least one token (from the prefill logits) is always returned. EOS is
    included when produced. Decoding also stops when the static KV cache
    (MAX_SEQ_LEN slots) is full.
    """
    with torch.no_grad():
        image_features = vision(pixel_values)
        print(f" Vision: {image_features.shape}")
        text_embeds = embed_tokens(input_ids)

        # Replace the embeddings at image-placeholder positions with vision features.
        positions = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
        features = image_features[0]
        if positions.numel() != features.shape[0]:
            raise ValueError(
                f"prompt has {positions.numel()} image tokens but vision produced "
                f"{features.shape[0]} feature rows")
        combined = text_embeds.clone()
        combined[0, positions] = features

        seq_len = combined.shape[1]
        # Writing past the static cache would fail inside the decoder (on CUDA,
        # possibly as a device-side assert that poisons later images).
        if seq_len > MAX_SEQ_LEN:
            raise ValueError(f"prompt length {seq_len} exceeds MAX_SEQ_LEN={MAX_SEQ_LEN}")

        # One zero-initialised key tensor and one value tensor per layer.
        kv_caches = [
            torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
            for _ in range(2 * NUM_LAYERS)
        ]
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        cache_position = torch.arange(seq_len, device=device)
        mask = _causal_prefill_mask(seq_len, MAX_SEQ_LEN, dtype, device)

        # Temporarily make the decoder's embedding lookup return the merged
        # embeddings. This presumes the decoder calls self.embed_tokens(input_ids);
        # verify against export_decoder. Restore in `finally` so a failure cannot
        # leave the decoder patched for the next image.
        orig_embed = decoder.embed_tokens
        decoder.embed_tokens = _PrefillEmbed(combined)
        try:
            result = decoder(input_ids[:, :seq_len], mask, position_ids, cache_position, *kv_caches)
        finally:
            decoder.embed_tokens = orig_embed

        logits, kv_caches = result[0], list(result[1:])
        next_token = logits[0, -1].argmax().item()
        generated = [next_token]
        cur_pos = seq_len
        for _ in range(max_new_tokens - 1):
            if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                break
            token_input = torch.tensor([[next_token]], device=device)
            pos_ids = torch.tensor([[cur_pos]], device=device)
            cache_pos = torch.tensor([cur_pos], device=device)
            # Attend to cache slots 0..cur_pos inclusive; mask the unused tail.
            dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
            dmask[0, 0, 0, cur_pos + 1:] = float("-inf")
            result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
            logits, kv_caches = result[0], list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated.append(next_token)
            cur_pos += 1
    return generated


def run_int8_modules(images, processor, max_new_tokens=512):
    """Run INT8 weight-only quantized fixed modules E2E.

    Builds the vision and decoder modules from the original checkpoint and
    applies torchao ``int8_weight_only`` to each (quantized in float32 on CPU,
    then moved to CUDA in bfloat16). It then greedily decodes each image with
    an explicit static KV cache.

    Args:
        images: mapping of name -> PIL image.
        processor: the matching ``AutoProcessor``.
        max_new_tokens: generation budget per image (default 512).

    Returns:
        dict name -> {"text": str, "tokens": int, "time": float}. ``time``
        covers vision encoding, prefill and decode, so it is comparable to
        ``model.generate`` in ``run_hf_model``. A failure on one image is
        printed with its traceback and recorded as ``"ERROR: ..."`` text, so
        the remaining images still run.
    """
    from export_vision import build_vision_module, load_original_model
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only

    print("\n[INT8 Quantized Modules]")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens
    device = "cuda"
    dtype = torch.bfloat16

    # Apply INT8 weight-only quantization (same as what we exported to .pte).
    print(" Applying int8_weight_only to vision...")
    vision = vision.to("cpu").to(torch.float32)
    quantize_(vision, int8_weight_only())
    vision = vision.to(device).to(dtype).eval()
    print(" Applying int8_weight_only to decoder...")
    decoder = decoder.to("cpu").to(torch.float32)
    quantize_(decoder, int8_weight_only())
    decoder = decoder.to(device).to(dtype).eval()
    embed_tokens = embed_tokens.to(device).to(dtype)
    del orig; torch.cuda.empty_cache()

    results = {}
    for name, img in images.items():
        print(f" [{name}] INT8 E2E...")
        try:
            pv = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)
            t0 = time.perf_counter()
            generated = _int8_greedy_generate(
                vision, decoder, embed_tokens, pv, input_ids, max_new_tokens, dtype, device)
            elapsed = time.perf_counter() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n = len(generated)
            print(f" {n} tok, {elapsed:.1f}s ({n/elapsed:.1f} tok/s)")
            print(f" {text[:150]}...")
            results[name] = {"text": text, "tokens": n, "time": elapsed}
        except Exception as e:
            # Deliberate best-effort per image: log and keep going.
            import traceback; traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}
    return results
def levenshtein(s1, s2):
    """Return the edit distance between sequences *s1* and *s2*.

    This is the minimum number of unit-cost single-element insertions,
    deletions and substitutions. It keeps one rolling row of the DP table,
    sized by the shorter input, and updates it in place.
    """
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if len(shorter) == 0:
        return len(longer)
    row = list(range(len(shorter) + 1))
    for i, a in enumerate(longer, start=1):
        # `diag` carries the previous row's value at column j-1.
        diag, row[0] = row[0], i
        for j, b in enumerate(shorter, start=1):
            above = row[j]
            row[j] = min(above + 1, row[j - 1] + 1, diag + (a != b))
            diag = above
    return row[-1]
def main():
    """Run the HF reference and the INT8 modules on the test images and print agreement metrics.

    The INT8 output is scored against the HF output, which serves as the reference:
      * Exact: whitespace-stripped texts are identical.
      * Edit dist / Char acc: character Levenshtein distance, and
        1 - dist / max(len(hf), len(int8), 1).
      * Word acc: Jaccard overlap of the lower-cased word *sets*. Word order
        and repetition are ignored; two empty texts score 1.0.
    """
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)
    print("="*60)
    print("LightOnOCR E2E: HF vs INT8 Quantized")
    print("="*60)
    images = get_test_images()
    hf = run_hf_model(images, processor)
    torch.cuda.empty_cache()
    q8 = run_int8_modules(images, processor)
    print("\n" + "="*60)
    # run_hf_model loads the reference model in bfloat16, not FP32.
    print("COMPARISON: HF (BF16) vs INT8 Weight-Only")
    print("="*60)
    for name in images:
        hf_t = hf[name]["text"]
        q8_t = q8[name]["text"]
        exact = hf_t.strip() == q8_t.strip()
        ed = levenshtein(hf_t, q8_t)
        char_acc = 1.0 - ed / max(len(hf_t), len(q8_t), 1)
        ref_w = set(hf_t.lower().split())
        hyp_w = set(q8_t.lower().split())
        union = ref_w | hyp_w
        word_acc = len(ref_w & hyp_w) / len(union) if union else 1.0
        print(f"\n{'─'*60}")
        print(f" [{name}]")
        print(f" HF ({hf[name]['tokens']} tok): {hf_t[:200]}")
        print(f" INT8 ({q8[name]['tokens']} tok): {q8_t[:200]}")
        print(f" Exact: {'✅' if exact else '❌'} | Edit dist: {ed} | Char acc: {char_acc:.4f} | Word acc: {word_acc:.4f}")


if __name__ == "__main__":
    main()