#!/usr/bin/env python3
"""
Weight-only INT8 quantization — no calibration, no forward passes needed.

Uses torchao int8_weight_only which packs weights instantly.
Then re-exports to ExecuTorch XNNPACK .pte.
"""
import gc
import os
import sys
import time

import torch

sys.path.insert(0, ".")

MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed image size (H, W) the vision encoder is exported for.
FIXED_H, FIXED_W = 1120, 1540


def _quantize_int8(module):
    """Move *module* to CPU fp32 eval mode and quantize weights to INT8 in place.

    Uses torchao's weight-only scheme, so no calibration/forward pass is
    needed. Returns the (same) quantized module.
    """
    from torchao.quantization import quantize_, int8_weight_only

    module = module.to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in module.parameters())/1e6:.1f}M")

    print(" Applying int8_weight_only...")
    t0 = time.time()
    quantize_(module, int8_weight_only())
    print(f" Quantization took {time.time()-t0:.1f}s")
    return module


def _export_program(module, example):
    """Run torch.export on *module* with *example* inputs; return the ExportedProgram."""
    from torch.export import export

    print(" torch.export...")
    t0 = time.time()
    ep = export(module, example)
    print(f" Export took {time.time()-t0:.1f}s")
    return ep


def _lower_and_save(ep, path):
    """Lower an exported program to the XNNPACK backend and write *path* (.pte).

    Prints timing and output size; returns *path*.
    """
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print(" XNNPACK lowering...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        ep,
        # IR validity check is disabled — presumably the quantized graph
        # contains ops the strict checker rejects; TODO confirm.
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et = edge.to_executorch()
    print(f" Lowering took {time.time()-t0:.1f}s")

    with open(path, "wb") as f:
        f.write(et.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    del edge, et
    gc.collect()
    return path


def quantize_vision(orig):
    """Quantize the vision encoder of *orig* to INT8 and export a .pte.

    Returns the output file path ("vision_encoder_int8.pte").
    """
    from export_vision import build_vision_module

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    vision = _quantize_int8(build_vision_module(orig))

    # Single fixed-size image input.
    example = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    ep = _export_program(vision, example)
    path = _lower_and_save(ep, "vision_encoder_int8.pte")

    del vision, ep
    gc.collect()
    return path


def quantize_decoder(orig):
    """Quantize the text decoder of *orig* to INT8 and export a .pte.

    Returns the output file path ("text_decoder_int8.pte").
    """
    import export_decoder as ed
    from export_decoder import build_decoder_module

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    decoder = _quantize_int8(build_decoder_module(orig))

    # Example inputs: 8-token prompt with causal mask, positions, and
    # empty KV caches, mirroring export_decoder's expected signature.
    kv = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    example = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *kv,
    )
    ep = _export_program(decoder, example)
    path = _lower_and_save(ep, "text_decoder_int8.pte")

    del decoder, ep
    gc.collect()
    return path


def main():
    """Load the model, quantize both sub-modules, and report size savings."""
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed — weights quantized instantly\n")
    print("Loading model...")
    orig = load_original_model()

    vis_path = quantize_vision(orig)
    dec_path = quantize_decoder(orig)
    del orig
    gc.collect()

    # Compare against previously exported fp32 .pte files, if present.
    print("\n=== RESULTS ===")
    for fp32, int8 in [
        ("vision_encoder.pte", vis_path),
        ("text_decoder_4096.pte", dec_path),
    ]:
        if os.path.exists(fp32) and os.path.exists(int8):
            orig_mb = os.path.getsize(fp32) / 1024 / 1024
            quant_mb = os.path.getsize(int8) / 1024 / 1024
            ratio = quant_mb / orig_mb * 100
            print(f" {fp32}: {orig_mb:.0f} MB → {int8}: {quant_mb:.0f} MB ({ratio:.0f}%)")


if __name__ == "__main__":
    main()