# Uploaded by acul3 via huggingface_hub: scripts/quantize_wo.py, commit f7962bb (verified)
#!/usr/bin/env python3
"""
Weight-only INT8 quantization -- no calibration, no forward passes needed.

Uses torchao's int8_weight_only transform, which packs weights instantly,
then re-exports both sub-modules (vision encoder and text decoder) to
ExecuTorch XNNPACK .pte files.
"""
import os, sys, time, gc, torch
# Make sibling helper scripts (export_vision.py, export_decoder.py) importable
# when this script is run from the repository root.
sys.path.insert(0, ".")
# NOTE(review): MODEL_DIR is not referenced anywhere in this file -- presumably
# consumed by the export_* helpers or kept as documentation; confirm before use.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input image size (height, width) used to trace the vision encoder.
# Assumes the encoder expects exactly this resolution -- TODO confirm upstream.
FIXED_H, FIXED_W = 1120, 1540
def quantize_vision(orig):
    """Quantize the vision encoder to INT8 (weight-only) and export a .pte.

    Builds the vision module from the already-loaded model *orig*, applies
    torchao's int8_weight_only transform (no calibration data required),
    traces it with torch.export at the fixed input resolution, lowers the
    program to the XNNPACK backend, and writes ``vision_encoder_int8.pte``.

    Returns the path of the written .pte file.
    """
    from export_vision import build_vision_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    encoder = build_vision_module(orig).to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in encoder.parameters())/1e6:.1f}M")

    # Weight-only quantization: weights are repacked in place, no forward pass.
    print(" Applying int8_weight_only...")
    tick = time.time()
    quantize_(encoder, int8_weight_only())
    print(f" Quantization took {time.time()-tick:.1f}s")

    # Trace the module at the fixed input resolution.
    print(" torch.export...")
    sample_inputs = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    tick = time.time()
    program = export(encoder, sample_inputs)
    print(f" Export took {time.time()-tick:.1f}s")

    # Lower the exported program to the XNNPACK backend.
    print(" XNNPACK lowering...")
    tick = time.time()
    lowered = to_edge_transform_and_lower(
        program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    runtime_program = lowered.to_executorch()
    print(f" Lowering took {time.time()-tick:.1f}s")

    path = "vision_encoder_int8.pte"
    with open(path, "wb") as f:
        f.write(runtime_program.buffer)
    print(f" βœ… {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Drop the large intermediates before the decoder pass starts.
    del encoder, program, lowered, runtime_program
    gc.collect()
    return path
def quantize_decoder(orig):
    """Quantize the text decoder to INT8 (weight-only) and export a .pte.

    Builds the decoder module from the already-loaded model *orig*, applies
    torchao's int8_weight_only transform, traces it with torch.export using
    an 8-token example prefill plus empty KV caches, lowers to the XNNPACK
    backend, and writes ``text_decoder_int8.pte``.

    Returns the path of the written .pte file.
    """
    import export_decoder as ed
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    decoder_module = build_decoder_module(orig).to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in decoder_module.parameters())/1e6:.1f}M")

    # Weight-only quantization: instant, no forward pass needed.
    print(" Applying int8_weight_only...")
    tick = time.time()
    quantize_(decoder_module, int8_weight_only())
    print(f" Quantization took {time.time()-tick:.1f}s")

    # Example inputs: token ids, causal mask, position ids, cache positions,
    # followed by the (empty) KV cache tensors provided by export_decoder.
    print(" torch.export...")
    kv_caches = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    sample_inputs = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *kv_caches,
    )
    tick = time.time()
    program = export(decoder_module, sample_inputs)
    print(f" Export took {time.time()-tick:.1f}s")

    # Lower the exported program to the XNNPACK backend.
    print(" XNNPACK lowering...")
    tick = time.time()
    lowered = to_edge_transform_and_lower(
        program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    runtime_program = lowered.to_executorch()
    print(f" Lowering took {time.time()-tick:.1f}s")

    path = "text_decoder_int8.pte"
    with open(path, "wb") as f:
        f.write(runtime_program.buffer)
    print(f" βœ… {path}: {os.path.getsize(path)/1024/1024:.1f} MB")

    # Release the large intermediates before returning.
    del decoder_module, program, lowered, runtime_program
    gc.collect()
    return path
def main():
    """Load the original model, quantize both sub-modules, and report sizes.

    Compares each INT8 artifact against its FP32 counterpart on disk when
    both files are present; otherwise the comparison line is skipped.
    """
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed β€” weights quantized instantly\n")
    print("Loading model...")
    original = load_original_model()
    vision_pte = quantize_vision(original)
    decoder_pte = quantize_decoder(original)
    # The original model is no longer needed; free it before reporting.
    del original
    gc.collect()

    print("\n=== RESULTS ===")
    comparisons = [
        ("vision_encoder.pte", vision_pte),
        ("text_decoder_4096.pte", decoder_pte),
    ]
    for baseline, quantized in comparisons:
        # Guard clause: only report when both files actually exist on disk.
        if not (os.path.exists(baseline) and os.path.exists(quantized)):
            continue
        base_mb = os.path.getsize(baseline) / 1024 / 1024
        quant_mb = os.path.getsize(quantized) / 1024 / 1024
        percent = quant_mb / base_mb * 100
        print(f" {baseline}: {base_mb:.0f} MB β†’ {quantized}: {quant_mb:.0f} MB ({percent:.0f}%)")

if __name__ == "__main__":
    main()