# NOTE: Hugging Face upload metadata (was non-Python page residue):
# uploaded by acul3 via huggingface_hub, commit 02acc80 (verified)
#!/usr/bin/env python3
"""
E2E Validation v2: Fixed resolution pipeline.
Forces all images to 1120x1540 so token count matches fixed vision encoder output.
Compares HF model output vs fixed PyTorch modules output.
"""
import os, sys, time, torch, torch.nn.functional as F
from pathlib import Path
from PIL import Image
# ── Constants ──
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input resolution: every image is resized to exactly 1120x1540 so the
# vision encoder emits a constant, precomputable number of feature tokens.
FIXED_H, FIXED_W = 1120, 1540
PATCH_SIZE = 14
SPATIAL_MERGE = 2
# Post-merge feature grid: (1120/14/2) x (1540/14/2) = 40 x 55 = 2200 tokens.
MERGED_H = FIXED_H // PATCH_SIZE // SPATIAL_MERGE # 40
MERGED_W = FIXED_W // PATCH_SIZE // SPATIAL_MERGE # 55
NUM_IMG_TOKENS = MERGED_H * MERGED_W # 2200
NUM_VPAD_TOKENS = MERGED_H - 1 # 39  (presumably one vision-pad separator per merged row after the first — TODO confirm against processor output)
# Token IDs — presumably from the model's tokenizer config; verify there.
IMAGE_TOKEN_ID = 151655
VISION_PAD_ID = 151654
VISION_END_ID = 151653
IM_START_ID = 151644
IM_END_ID = 151645
EOS_TOKEN_ID = 151645  # same id as IM_END_ID: generation stops at <|im_end|>
# Decoder constants (architecture hyperparameters of the 1B text decoder)
NUM_LAYERS = 28
NUM_KV_HEADS = 8
NUM_HEADS = 16
HEAD_DIM = 128
HIDDEN_SIZE = 1024
MAX_SEQ_LEN = 4096 # Large enough for 2256 input + generation
def download_test_images():
    """Download real OCR test images and build one synthetic document.

    Downloads are cached under ./test_images. A failed download is skipped;
    a stale/corrupt cached file is evicted instead of crashing the run
    (FIX: previously ``Image.open`` on a corrupt cache aborted the script).

    Returns:
        dict[str, PIL.Image.Image]: name -> RGB image.
    """
    import requests
    from PIL import ImageDraw  # hoisted from mid-function

    os.makedirs("test_images", exist_ok=True)
    images = {}
    sources = {
        "receipt": "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/SROIE-receipt.jpeg",
    }
    for name, url in sources.items():
        path = f"test_images/{name}.png"
        if not os.path.exists(path):
            print(f" Downloading {name}...")
            try:
                resp = requests.get(url, timeout=30)
                resp.raise_for_status()
                with open(path, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                print(f" FAILED: {e}")
                continue
        try:
            img = Image.open(path).convert("RGB")
        except Exception as e:
            # A corrupt cached file (e.g. from an interrupted earlier run)
            # would otherwise kill the whole script; evict it and move on so
            # the next invocation re-downloads.
            print(f" FAILED: {e}")
            os.remove(path)
            continue
        images[name] = img
        print(f" {name}: {img.size}")
    # Synthetic invoice-style document: always available, no network needed.
    img = Image.new("RGB", (800, 600), "white")
    draw = ImageDraw.Draw(img)
    draw.text((50, 50), "Invoice #12345", fill="black")
    draw.text((50, 100), "Date: 2024-01-15", fill="black")
    draw.text((50, 150), "Item 1: Widget x5 @ $10.00 = $50.00", fill="black")
    draw.text((50, 200), "Item 2: Gadget x2 @ $24.99 = $49.98", fill="black")
    draw.text((50, 250), "Total: $99.98", fill="black")
    img.save("test_images/synthetic.png")
    images["synthetic"] = img
    print(f" synthetic: {img.size}")
    return images
def build_fixed_input_ids(processor, text_prompt="OCR this document. Extract all text."):
    """Build the prompt token sequence for a fixed 1120x1540 image.

    Renders the processor's chat template against a blank image of exactly
    the target size, so the number of IMG placeholder tokens matches the
    fixed vision encoder output (NUM_IMG_TOKENS).

    Returns:
        torch.LongTensor of shape [1, seq_len].
    """
    # A blank image at the exact target size yields the correct token layout.
    blank = Image.new("RGB", (FIXED_W, FIXED_H), "white")
    conversation = [{"role": "user", "content": [
        {"type": "image"}, {"type": "text", "text": text_prompt}
    ]}]
    rendered = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    batch = processor(text=rendered, images=[blank], return_tensors="pt")
    input_ids = batch["input_ids"]
    # Sanity-check the placeholder counts against our fixed-grid math.
    tokens = input_ids[0].tolist()
    n_img = tokens.count(IMAGE_TOKEN_ID)
    n_pad = tokens.count(VISION_PAD_ID)
    assert n_img == NUM_IMG_TOKENS, f"Expected {NUM_IMG_TOKENS} IMG tokens, got {n_img}"
    print(f" Input template: {input_ids.shape[1]} tokens ({n_img} IMG, {n_pad} VPAD)")
    return input_ids
def preprocess_image_fixed(img, processor):
    """Force *img* to exactly FIXED_W x FIXED_H and return pixel_values.

    Uses the processor's own normalization pipeline; only the geometry is
    overridden (LANCZOS resize to the fixed resolution).
    """
    scaled = img.resize((FIXED_W, FIXED_H), Image.LANCZOS)
    conversation = [{"role": "user", "content": [{"type": "image"}]}]
    rendered = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    batch = processor(text=rendered, images=[scaled], return_tensors="pt")
    return batch["pixel_values"]
def run_hf_model(images, processor):
    """Run original HF model with FIXED resolution preprocessing.

    Args:
        images: dict name -> PIL image (from download_test_images).
        processor: the model's AutoProcessor.

    Returns:
        dict name -> {"text": decoded output, "tokens": new-token count,
        "time": seconds for generate}.
    """
    from transformers import AutoModelForImageTextToText
    from safetensors.torch import load_file
    print("\n[HF Model] Loading...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cpu",
    )
    # Re-load the raw checkpoint and remap legacy key names
    # (vision_encoder/vision_projection -> vision_tower/multi_modal_projector);
    # strict=False so keys that still don't match are skipped silently.
    state_dict = load_file(os.path.join(MODEL_DIR, "model.safetensors"))
    remapped = {k.replace("model.vision_encoder.", "model.vision_tower.")
                 .replace("model.vision_projection.", "model.multi_modal_projector."): v
                for k, v in state_dict.items()}
    model.load_state_dict(remapped, strict=False)
    model = model.to("cuda").eval()
    results = {}
    for name, img in images.items():
        print(f"\n [{name}] HF generate (fixed {FIXED_H}x{FIXED_W})...")
        pixel_values = preprocess_image_fixed(img, processor).to("cuda")
        input_ids = build_fixed_input_ids(processor).to("cuda")
        attention_mask = torch.ones_like(input_ids)  # no padding: all-ones mask
        input_len = input_ids.shape[1]
        t0 = time.time()
        with torch.no_grad():
            # Greedy decoding: do_sample=False with temperature/top_p
            # explicitly None (overriding any sampling defaults).
            output_ids = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                image_sizes=torch.tensor([[FIXED_H, FIXED_W]], device="cuda"),
                max_new_tokens=512, do_sample=False, temperature=None, top_p=None,
            )
        elapsed = time.time() - t0
        # Strip the prompt; keep only newly generated tokens.
        new_ids = output_ids[0, input_len:]
        text = processor.tokenizer.decode(new_ids, skip_special_tokens=True)
        n_tok = len(new_ids)
        print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
        print(f" Output: {text[:200]}...")
        results[name] = {"text": text, "tokens": n_tok, "time": elapsed}
    # Free the HF model before the fixed modules are loaded by the caller.
    del model
    torch.cuda.empty_cache()
    return results
def run_fixed_modules(images, processor):
    """Run fixed PyTorch modules E2E with proper token matching.

    Replays the model by hand: exported vision encoder -> scatter vision
    features into the IMG-token slots of the text embeddings -> one prefill
    pass through the exported decoder -> greedy token-by-token decode against
    a preallocated (static-shape) KV cache.

    Args:
        images: dict name -> PIL image.
        processor: the model's AutoProcessor (tokenizer + preprocessing).

    Returns:
        dict name -> {"text": ..., "tokens": ..., "time": ...}; a per-image
        failure stores "ERROR: ..." in "text" instead of aborting the loop.
    """
    sys.path.insert(0, ".")
    from export_vision import build_vision_module, load_original_model
    from export_decoder import TextDecoderFixed, build_decoder_module
    print("\n[Fixed Modules] Loading...")
    orig = load_original_model()
    vision = build_vision_module(orig)
    decoder = build_decoder_module(orig)
    embed_tokens = orig.model.language_model.embed_tokens
    device = "cuda"
    dtype = torch.bfloat16
    vision = vision.to(device).to(dtype).eval()
    decoder = decoder.to(device).to(dtype).eval()
    embed_tokens = embed_tokens.to(device).to(dtype)
    # Drop the full original model; only the extracted pieces are used below.
    del orig
    torch.cuda.empty_cache()
    results = {}
    for name, img in images.items():
        print(f"\n [{name}] Fixed modules E2E...")
        try:
            # Step 1: Get pixel values at fixed resolution
            pixel_values = preprocess_image_fixed(img, processor).to(device).to(dtype)
            input_ids = build_fixed_input_ids(processor).to(device)
            # Step 2: Vision encoder
            with torch.no_grad():
                image_features = vision(pixel_values) # [1, 2200, 1024]
            print(f" Vision output: {image_features.shape}")
            # Step 3: Build combined embeddings — replace each IMG placeholder
            # embedding with the corresponding vision feature, in order.
            with torch.no_grad():
                text_embeds = embed_tokens(input_ids) # [1, seq_len, 1024]
            ids_list = input_ids[0].tolist()
            img_positions = [i for i, t in enumerate(ids_list) if t == IMAGE_TOKEN_ID]
            assert len(img_positions) == image_features.shape[1], \
                f"Token/feature mismatch: {len(img_positions)} slots vs {image_features.shape[1]} features"
            combined = text_embeds.clone()
            # Scatter vision features into IMG token positions
            indices = torch.tensor(img_positions, device=device)
            combined[0, indices] = image_features[0]
            seq_len = combined.shape[1]
            print(f" Combined seq: {seq_len}, scattering {len(img_positions)} features")
            # Step 4: Prefill — feed combined embeddings through decoder.
            # Static KV cache: one (k, v) pair per layer, preallocated to
            # MAX_SEQ_LEN so tensor shapes never change during decode.
            kv_caches = []
            for _ in range(NUM_LAYERS):
                k = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                v = torch.zeros(1, NUM_KV_HEADS, MAX_SEQ_LEN, HEAD_DIM, dtype=dtype, device=device)
                kv_caches.extend([k, v])
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            cache_position = torch.arange(seq_len, device=device)
            # Causal mask for prefill: row i may attend to cache slots 0..i;
            # everything else (including the unused cache tail) stays -inf.
            mask = torch.full((1, 1, seq_len, MAX_SEQ_LEN), float("-inf"), dtype=dtype, device=device)
            for i in range(seq_len):
                mask[0, 0, i, :i+1] = 0.0
            # Monkey-patch embed_tokens for prefill (we already have embeddings)
            orig_embed = decoder.embed_tokens
            class PrefillEmbed(torch.nn.Module):
                # Ignores its input and returns the precomputed combined
                # embeddings, so the input_ids fed to the decoder are unused.
                def __init__(self, embeds):
                    super().__init__()
                    self.e = embeds
                def forward(self, x):
                    return self.e
            decoder.embed_tokens = PrefillEmbed(combined)
            t0 = time.time()
            with torch.no_grad():
                result = decoder(input_ids[:, :seq_len], mask, position_ids, cache_position, *kv_caches)
            # Restore the real embedding table before token-by-token decode.
            decoder.embed_tokens = orig_embed
            # Decoder returns (logits, *updated_kv_caches).
            logits = result[0]
            kv_caches = list(result[1:])
            next_token = logits[0, -1].argmax().item()
            generated = [next_token]
            cur_pos = seq_len
            # Step 5: Autoregressive greedy decode. 511 more tokens gives a
            # total budget of 512, matching run_hf_model's max_new_tokens.
            for step in range(511):
                if next_token == EOS_TOKEN_ID or cur_pos >= MAX_SEQ_LEN:
                    break
                token_input = torch.tensor([[next_token]], device=device)
                pos_ids = torch.tensor([[cur_pos]], device=device)
                cache_pos = torch.tensor([cur_pos], device=device)
                # Decode mask: attend to all positions up to cur_pos
                # (cur_pos itself is the slot this token's KV is written to).
                dmask = torch.zeros(1, 1, 1, MAX_SEQ_LEN, dtype=dtype, device=device)
                dmask[0, 0, 0, cur_pos+1:] = float("-inf")
                with torch.no_grad():
                    result = decoder(token_input, dmask, pos_ids, cache_pos, *kv_caches)
                logits = result[0]
                kv_caches = list(result[1:])
                next_token = logits[0, -1].argmax().item()
                generated.append(next_token)
                cur_pos += 1
            elapsed = time.time() - t0
            text = processor.tokenizer.decode(generated, skip_special_tokens=True)
            n_tok = len(generated)
            print(f" {n_tok} tokens, {elapsed:.1f}s ({n_tok/elapsed:.1f} tok/s)")
            print(f" Output: {text[:200]}...")
            results[name] = {"text": text, "tokens": n_tok, "time": elapsed}
        except Exception as e:
            import traceback
            traceback.print_exc()
            results[name] = {"text": f"ERROR: {e}", "tokens": 0, "time": 0}
    return results
def levenshtein(s1, s2):
    """Edit distance between two strings (insert/delete/substitute, cost 1)."""
    # Keep the longer string in s1 so the DP row is as short as possible.
    if len(s2) > len(s1):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    row = list(range(len(s2) + 1))
    for i, a in enumerate(s1, start=1):
        new_row = [i]
        for j, b in enumerate(s2):
            substitution = row[j] + (a != b)
            deletion = row[j + 1] + 1
            insertion = new_row[j] + 1
            new_row.append(min(deletion, insertion, substitution))
        row = new_row
    return row[-1]
def compare(hf_results, fx_results, images):
    """Print a per-image accuracy report: HF reference vs fixed modules.

    For each image prints both outputs (truncated), exact-match status,
    character accuracy (normalized edit distance) and word accuracy
    (Jaccard overlap of lowercase word sets).
    """
    print("\n" + "="*70)
    print("E2E COMPARISON (Fixed 1120x1540 resolution)")
    print("="*70)
    for name in images:
        hf = hf_results[name]
        fx = fx_results[name]
        hf_t, fx_t = hf["text"], fx["text"]
        exact = hf_t.strip() == fx_t.strip()
        ed = levenshtein(hf_t, fx_t)
        # Guard against division by zero when both outputs are empty.
        denom = max(len(hf_t), len(fx_t), 1)
        char_acc = 1.0 - ed / denom
        ref_words = set(hf_t.lower().split())
        hyp_words = set(fx_t.lower().split())
        union = ref_words | hyp_words
        overlap = ref_words & hyp_words
        word_acc = len(overlap) / len(union) if union else 1.0
        print(f"\n{'─'*70}")
        print(f" [{name}]")
        print(f" HF ({hf['tokens']} tok): {hf_t[:250]}")
        print(f" FIX ({fx['tokens']} tok): {fx_t[:250]}")
        print(f" Exact: {'✅ YES' if exact else '❌ NO'}")
        print(f" Edit dist: {ed}, Char acc: {char_acc:.4f}, Word acc: {word_acc:.4f}")
    print("\n" + "="*70)
def main():
    """Entry point: fetch images, run HF reference and fixed modules, compare.

    Requires a CUDA device and the model checkpoint under MODEL_DIR.
    """
    print("LightOnOCR-2-1B E2E Validation v2")
    # FIX: the separator between resolution and token count was missing,
    # printing the garbled "1120x15402200 vision features".
    print(f"Fixed resolution: {FIXED_H}x{FIXED_W} -> {NUM_IMG_TOKENS} vision features")
    print(f"Device: cuda, Max seq: {MAX_SEQ_LEN}")
    print("="*70)
    images = download_test_images()
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(MODEL_DIR)
    hf_results = run_hf_model(images, processor)
    # Free GPU memory before loading the fixed modules.
    torch.cuda.empty_cache()
    fx_results = run_fixed_modules(images, processor)
    compare(hf_results, fx_results, images)


if __name__ == "__main__":
    main()