"""
SauerkrautLM-ColPali Demo Space
Visual Document Retrieval with Similarity Heat Maps + VLM Answer Generation
Multi-document indexing for realistic retrieval scenarios
"""
import os
import gradio as gr
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from huggingface_hub import login
import spaces
import math
from pathlib import Path
from typing import List, Tuple, Optional, Dict
# Import model classes at startup
from sauerkrautlm_colpali.models.lfm2.collfm2.modeling_collfm2 import ColLFM2
from sauerkrautlm_colpali.models.lfm2.collfm2.processing_collfm2 import ColLFM2Processor
from sauerkrautlm_colpali.models.qwen3.colqwen3.modeling_colqwen3 import ColQwen3
from sauerkrautlm_colpali.models.qwen3.colqwen3.processing_colqwen3 import ColQwen3Processor
print("All model imports successful!")
EPSILON = 1e-10
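# Small constant to guard the min-max normalizations below against division by zero.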
def install_fa2():
print("Installing Flash Attention 2...")
os.system("pip install flash-attn --no-build-isolation")
# HF Token for private models
hf_token = os.getenv("HF_KEY")
if hf_token:
login(token=hf_token)
# ColPali Model options
COLPALI_MODELS = {
"SauerkrautLM-ColLFM2-450M (Fastest, 0.9GB)": "VAGOsolutions/SauerkrautLM-ColLFM2-450M-v0.1",
"SauerkrautLM-ColQwen3-1.7B-Turbo (Fast, 3.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-1.7b-Turbo-v0.1",
"SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-2b-v0.1",
"SauerkrautLM-ColQwen3-4B (Quality, 8GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-4b-v0.1",
"SauerkrautLM-ColQwen3-8B (Best, 16GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-8b-v0.1",
}
# Global model cache
loaded_colpali_model = None
loaded_colpali_processor = None
loaded_colpali_model_name = None
loaded_vlm_model = None
loaded_vlm_processor = None
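# Models are cached in module globals so repeated @spaces.GPU calls reuse the
# already-loaded weights instead of re-downloading and re-initializing them.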
# =============================================================================
# EXAMPLE CONFIGURATION - Organized by language and use case
# =============================================================================
EXAMPLE_CONFIG = [
# ==================== GERMAN (DE) ====================
# Annual Reports
{
"file": "deutsch/2024-infineon-geschaeftsbericht-v01-00-de_p4.png",
"query": "Wie hoch ist der Umsatz von Infineon?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Infineon GeschΓ€ftsbericht 2024 - Kennzahlen"
},
{
"file": "deutsch/BASF_Bericht_2024_p90.png",
"query": "Wie ist das Risikomanagement bei BASF organisiert?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "BASF Bericht 2024 - Risikomanagement"
},
{
"file": "deutsch/entire-dtag-gb24_p351.png",
"query": "Wie setzt sich die VorstandsvergΓΌtung bei der Deutschen Telekom zusammen?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Deutsche Telekom GB 2024 - VergΓΌtung"
},
{
"file": "deutsch/entire-dtag-gb24_p77.png",
"query": "Wie hoch sind die Nettofinanzverbindlichkeiten der Deutschen Telekom?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Deutsche Telekom GB 2024 - Finanzen"
},
# Hydrogen / Energy
{
"file": "deutsch/Bildschirmfoto 2025-12-14 um 01.31.45.png",
"query": "Wo verlΓ€uft die geplante Wasserstofftrasse in Leipzig?",
"category": "⚑ Hydrogen/Energy",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Wasserstoff-Infrastruktur Leipzig"
},
{
"file": "deutsch/Bildschirmfoto 2025-12-14 um 01.36.01.png",
"query": "Wie wird die GewΓ€sserquerung der Wasserstoffleitung realisiert?",
"category": "⚑ Hydrogen/Energy",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Wasserstoff-Technische Zeichnung"
},
# Tax/Forms
{
"file": "deutsch/ESt_1_A_2022_p2.png",
"query": "Wo trage ich meine Bankverbindung in der SteuererklΓ€rung ein?",
"category": "πŸ“ Tax Form",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "EinkommensteuererklΓ€rung ESt 1A"
},
# Economic Reports
{
"file": "deutsch/Monatsbericht---Oktober-2025_p152.png",
"query": "Wie hoch sind die aktuellen ZinssΓ€tze fΓΌr Wohnungsbaukredite?",
"category": "πŸ’° Financial Report",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Bundesbank Monatsbericht - ZinssΓ€tze"
},
{
"file": "deutsch/sd-2025-digital-07-wollmershaeuser-etal-ifo-konjunkturprognose-sommer-2025_p23.png",
"query": "Was sind die Annahmen fΓΌr den Γ–lpreis in der ifo-Prognose?",
"category": "πŸ“ˆ Economic Forecast",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "ifo Konjunkturprognose 2025"
},
# Environmental
{
"file": "deutsch/rep0913_p197.png",
"query": "Wie hat sich der Fleischkonsum in Γ–sterreich entwickelt?",
"category": "🌱 Environmental",
"lang": "πŸ‡©πŸ‡ͺ DE",
"description": "Klimaschutzbericht - Umweltbundesamt"
},
# ==================== ENGLISH (EN) ====================
# ESG/Sustainability
{
"file": "englisch/2025051910270996484_p19.png",
"query": "What is the waste recycling rate at CRRC?",
"category": "🌱 ESG Report",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "CRRC ESG Report 2024 - Waste Management"
},
# Scientific
{
"file": "englisch/6e81bb8284357ea1773e99832d21c65b_new_myosinlecture_ag_p5.png",
"query": "What are the different classes of myosin?",
"category": "πŸ”¬ Scientific Paper",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "Myosin Phylogenetic Tree"
},
# Historical/Vintage
{
"file": "englisch/ADVE_0004.png",
"query": "Which cigarette brand advertised 'Call for Philip Morris'?",
"category": "πŸ“œ Historical Document",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "Vintage Philip Morris Advertisement"
},
# Business Forms
{
"file": "englisch/Form_0033.png",
"query": "How long should vendor audit records be retained?",
"category": "πŸ“ Business Form",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "Records Retention Schedule"
},
{
"file": "englisch/Letter_0061.png",
"query": "Who requested copies of the 'Helping Youth Decide' booklet?",
"category": "βœ‰οΈ Business Letter",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "Girl Scouts Correspondence"
},
# Financial Reports
{
"file": "englisch/NASDAQ_DDD_2024_p76.png",
"query": "What is the total stockholders' equity of 3D Systems?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "3D Systems Financial Statement"
},
{
"file": "englisch/TMUS-2024-Annual-Report_p143.png",
"query": "Who is the CEO of T-Mobile US?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "T-Mobile US Annual Report 2024"
},
{
"file": "englisch/pwc-transparency-report-2023-2024_p33.png",
"query": "What are the rotation periods for audit partners at PwC?",
"category": "πŸ“‹ Transparency Report",
"lang": "πŸ‡¬πŸ‡§ EN",
"description": "PwC Transparency Report"
},
# ==================== FRENCH (FR) ====================
{
"file": "franzΓΆsisch/194000315_0_p178.png",
"query": "Quel est le coΓ»t du travail au SMIC en France?",
"category": "πŸ’° Labor Statistics",
"lang": "πŸ‡«πŸ‡· FR",
"description": "Statistiques du travail - SMIC"
},
{
"file": "franzΓΆsisch/194000315_0_p21.png",
"query": "Quelle est la prΓ©vision de croissance du PIB en zone euro?",
"category": "πŸ“ˆ Economic Forecast",
"lang": "πŸ‡«πŸ‡· FR",
"description": "PrΓ©visions Γ©conomiques Zone Euro"
},
{
"file": "franzΓΆsisch/Cours-de-physique-1v2_p44.png",
"query": "Comment fonctionne la vision des couleurs?",
"category": "πŸŽ“ Educational",
"lang": "πŸ‡«πŸ‡· FR",
"description": "Cours de Physique - Vision"
},
{
"file": "franzΓΆsisch/CSSF_RA_2024_FR_p14.png",
"query": "Quelle est la rΓ©partition des employΓ©s de la CSSF par nationalitΓ©?",
"category": "πŸ“Š Annual Report",
"lang": "πŸ‡«πŸ‡· FR",
"description": "CSSF Luxembourg - Rapport Annuel"
},
{
"file": "franzΓΆsisch/ICN_Definition-Nursing_Report_FR_Web_p47.png",
"query": "Combien d'associations nationales d'infirmières participent au CII?",
"category": "πŸ₯ Healthcare Report",
"lang": "πŸ‡«πŸ‡· FR",
"description": "ICN Rapport Infirmières"
},
{
"file": "franzΓΆsisch/rapport-cns-2024-internet_p14.png",
"query": "Quelles sont les principales missions de la CNS?",
"category": "πŸ₯ Healthcare Report",
"lang": "πŸ‡«πŸ‡· FR",
"description": "CNS Rapport Annuel"
},
{
"file": "franzΓΆsisch/rapport-esg-2024.pdf.coredownload.inline_p29.png",
"query": "Quels sont les objectifs ESG pour 2024?",
"category": "🌱 ESG Report",
"lang": "πŸ‡«πŸ‡· FR",
"description": "Rapport ESG 2024"
},
# ==================== SPANISH (ES) ====================
{
"file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p31.png",
"query": "ΒΏCuΓ‘les son las emisiones de gases de efecto invernadero de Coeur Mining?",
"category": "🌱 ESG Report",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Coeur Mining ESG - Emisiones"
},
{
"file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p39.png",
"query": "ΒΏQuΓ© medidas de seguridad implementa Coeur Mining?",
"category": "🌱 ESG Report",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Coeur Mining ESG - Seguridad"
},
{
"file": "spanisch/Informe-Economico-Regional-2022-2023_p112.png",
"query": "ΒΏCuΓ‘l es la situaciΓ³n econΓ³mica regional en 2023?",
"category": "πŸ“ˆ Economic Report",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Informe EconΓ³mico Regional"
},
{
"file": "spanisch/Informe-Sostenibilidad-ESG-2024_p16.png",
"query": "ΒΏCuΓ‘les son los objetivos de sostenibilidad para 2024?",
"category": "🌱 ESG Report",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Informe Sostenibilidad ESG"
},
{
"file": "spanisch/MAPs_PLAN_DESARROLLO_p14.png",
"query": "ΒΏCuΓ‘ntas propuestas de renovables se recibieron en EspaΓ±a?",
"category": "⚑ Energy Infrastructure",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Plan Desarrollo Red ElΓ©ctrica"
},
{
"file": "spanisch/Presupuestos_p32.png",
"query": "ΒΏCuΓ‘l es el presupuesto total asignado?",
"category": "πŸ’° Budget Document",
"lang": "πŸ‡ͺπŸ‡Έ ES",
"description": "Presupuestos Generales"
},
]
def get_all_example_images() -> List[Tuple[str, Image.Image]]:
"""Load all example images for multi-document indexing."""
examples_dir = Path(__file__).parent / "demopics"
images = []
for example in EXAMPLE_CONFIG:
filepath = examples_dir / example["file"]
if filepath.exists():
try:
img = Image.open(filepath).convert("RGB")
images.append((str(filepath), img))
except Exception as e:
print(f"Error loading {filepath}: {e}")
return images
def get_available_examples():
"""Load examples for the Gradio Examples component (shuffled)."""
import random
examples_dir = Path(__file__).parent / "demopics"
available = []
for example in EXAMPLE_CONFIG:
filepath = examples_dir / example["file"]
if filepath.exists():
available.append([str(filepath), example["query"]])
    # Shuffle with a fixed seed so languages are mixed in a consistent order,
    # without reseeding the global RNG as a side effect.
    random.Random(42).shuffle(available)
return available if available else None
def get_example_gallery_data():
"""Get data for the example gallery with categories."""
examples_dir = Path(__file__).parent / "demopics"
gallery_data = []
for example in EXAMPLE_CONFIG:
filepath = examples_dir / example["file"]
if filepath.exists():
gallery_data.append({
"path": str(filepath),
"query": example["query"],
"label": f"{example['lang']} {example['category']}: {example['description']}",
"category": example["category"],
"lang": example["lang"],
})
return gallery_data
@spaces.GPU
def load_colpali_model(model_choice: str):
"""Load the selected ColPali model with proper device placement."""
global loaded_colpali_model, loaded_colpali_processor, loaded_colpali_model_name
model_name = COLPALI_MODELS[model_choice]
if loaded_colpali_model_name == model_name and loaded_colpali_model is not None:
gr.Info(f"βœ… {model_choice} ready!")
return loaded_colpali_model, loaded_colpali_processor
gr.Info(f"⏳ Loading {model_choice}... Please wait.")
if loaded_colpali_model is not None:
gr.Info("πŸ”„ Unloading previous model...")
try:
del loaded_colpali_model
del loaded_colpali_processor
except Exception:
pass
torch.cuda.empty_cache()
try:
import flash_attn
attn_impl = "flash_attention_2"
gr.Info("⚑ Using Flash Attention 2")
except ImportError:
attn_impl = "sdpa"
gr.Info("πŸ”§ Using SDPA attention")
print(f"Loading {model_name} with attention: {attn_impl}")
if "ColLFM2" in model_name:
gr.Info("πŸ“₯ Downloading model weights...")
loaded_colpali_model = ColLFM2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
token=hf_token,
).eval().to("cuda")
gr.Info("πŸ“₯ Downloading processor...")
loaded_colpali_processor = ColLFM2Processor.from_pretrained(model_name, token=hf_token)
elif "ColQwen3" in model_name:
gr.Info("πŸ“₯ Downloading model weights...")
loaded_colpali_model = ColQwen3.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
device_map="cuda",
token=hf_token,
).eval()
gr.Info("πŸ“₯ Downloading processor...")
loaded_colpali_processor = ColQwen3Processor.from_pretrained(model_name, token=hf_token)
else:
raise ValueError(f"Unknown model type: {model_name}")
loaded_colpali_model_name = model_name
gr.Info(f"βœ… {model_choice} loaded and ready!")
return loaded_colpali_model, loaded_colpali_processor
@spaces.GPU
def load_vlm_model():
"""Load Qwen3-VL-4B for answer generation."""
global loaded_vlm_model, loaded_vlm_processor
if loaded_vlm_model is not None:
gr.Info("βœ… Qwen3-VL-4B ready!")
return loaded_vlm_model, loaded_vlm_processor
gr.Info("⏳ Loading Qwen3-VL-4B-Instruct... Please wait.")
from transformers import AutoModelForImageTextToText, AutoProcessor
vlm_model_name = "Qwen/Qwen3-VL-4B-Instruct"
print(f"Loading VLM: {vlm_model_name}")
gr.Info("πŸ“₯ Downloading VLM model weights (8GB)...")
loaded_vlm_model = AutoModelForImageTextToText.from_pretrained(
vlm_model_name,
torch_dtype=torch.bfloat16,
device_map="cuda",
token=hf_token,
).eval()
gr.Info("πŸ“₯ Downloading VLM processor...")
loaded_vlm_processor = AutoProcessor.from_pretrained(vlm_model_name, token=hf_token)
gr.Info("βœ… Qwen3-VL-4B loaded and ready!")
return loaded_vlm_model, loaded_vlm_processor
def on_vlm_toggle(enabled):
"""Show hint when VLM is enabled."""
if enabled:
gr.Info("ℹ️ VLM (Qwen3-VL-4B) will be loaded on first analysis. This adds ~30-60 seconds.")
return enabled
def get_similarity_maps_from_embeddings(
image_embeddings: torch.Tensor,
query_embeddings: torch.Tensor,
n_patches: Tuple[int, int],
image_mask: torch.Tensor,
query_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""EXACT ColPali implementation of similarity map computation."""
idx = 0
n_patches_x, n_patches_y = n_patches[0], n_patches[1]
n_image_tokens = int(image_mask[idx].sum().item())
expected_tokens = n_patches_x * n_patches_y
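    # If the model merged or padded patch tokens, the advertised grid no longer
    # matches the actual token count; fall back to the most square factorization.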
if n_image_tokens != expected_tokens:
n = n_image_tokens
sqrt_n = int(math.sqrt(n))
for i in range(sqrt_n, 0, -1):
if n % i == 0:
n_patches_x, n_patches_y = n // i, i
break
image_embedding_grid = rearrange(
image_embeddings[idx][image_mask[idx]],
"(h w) c -> w h c",
w=n_patches_x,
h=n_patches_y,
)
query_emb = query_embeddings[idx]
if query_mask is not None:
query_emb = query_emb[query_mask[idx]]
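    # einsum "nk,ijk->nij": n query tokens x (w, h) patch grid -> n per-token maps.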
similarity_map = torch.einsum(
"nk,ijk->nij",
query_emb,
image_embedding_grid,
)
return similarity_map
def create_heatmap_overlay(
image: Image.Image,
similarity_map: torch.Tensor,
alpha: float = 0.5,
skip_normalize: bool = False,
) -> Image.Image:
"""Create heatmap overlay following EXACT ColPali visualization."""
import seaborn as sns
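    # Min-max normalize the map to [0, 1] (unless the caller pre-normalized),
    # upscale it to the image size, colorize it with the mako palette, and
    # alpha-blend it over the original document image.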
sim_float = similarity_map.float()
if skip_normalize:
sim_array = sim_float.cpu().numpy()
else:
min_val = sim_float.min()
max_val = sim_float.max()
        sim_normalized = (sim_float - min_val) / (max_val - min_val + EPSILON)
sim_array = sim_normalized.cpu().numpy()
sim_array = rearrange(sim_array, "h w -> w h")
sim_image = Image.fromarray((sim_array * 255).astype(np.uint8))
sim_image = sim_image.resize(image.size, Image.Resampling.BICUBIC)
sim_resized = np.array(sim_image) / 255.0
cmap = sns.color_palette("mako", as_cmap=True)
heatmap_rgba = cmap(sim_resized)
heatmap = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
img_array = np.array(image.convert("RGB")).astype(np.float32)
heatmap_float = heatmap.astype(np.float32)
blended = img_array * (1 - alpha) + heatmap_float * alpha
blended = np.clip(blended, 0, 255).astype(np.uint8)
return Image.fromarray(blended)
def get_collfm2_heatmap(model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries):
"""Generate heatmap for ColLFM2 models (simplified version)."""
try:
if "input_ids" not in batch_images:
return None
input_ids = batch_images["input_ids"][0]
tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor.processor.tokenizer
image_token_id = tokenizer.convert_tokens_to_ids('<image>')
image_mask = input_ids == image_token_id
n_image_tokens = image_mask.sum().item()
# Find best grid that matches the token count and image aspect ratio
img_width, img_height = image.size
img_ratio = img_width / img_height
best_diff = float('inf')
n_patches_x, n_patches_y = int(math.sqrt(n_image_tokens)), int(math.sqrt(n_image_tokens))
for i in range(1, int(math.sqrt(n_image_tokens)) + 1):
if n_image_tokens % i == 0:
j = n_image_tokens // i
ratio1 = j / i
ratio2 = i / j
diff1 = abs(img_ratio - ratio1)
diff2 = abs(img_ratio - ratio2)
if diff1 < best_diff:
best_diff = diff1
n_patches_x, n_patches_y = j, i
if diff2 < best_diff:
best_diff = diff2
n_patches_x, n_patches_y = i, j
query_emb = query_embeddings[0]
# Filter padding
pad_token_id = getattr(tokenizer, "pad_token_id", 0) or 0
query_mask = batch_queries["input_ids"][0] != pad_token_id
query_emb = query_emb[query_mask]
# Get image embeddings
image_emb = image_embeddings[0][image_mask][:n_patches_x * n_patches_y]
image_grid = rearrange(image_emb, "(h w) c -> w h c", w=n_patches_x, h=n_patches_y)
# Compute similarity
similarity_map = torch.einsum("nk,ijk->nij", query_emb, image_grid)
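        # Aggregate the per-token maps by taking, for each patch, the
        # best-matching query token (max over the token dimension).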
aggregated = similarity_map.max(dim=0).values
# Aggressive normalization for ColLFM2
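        # Keep only the top 10% of patches (0.90 quantile) and rescale them to
        # [0.5, 1.0]; everything below the threshold stays at 0, so the overlay
        # highlights regions despite the small score spread of these embeddings.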
agg_float = aggregated.float()
threshold = torch.quantile(agg_float.flatten(), 0.90)
hot_mask = agg_float > threshold
min_hot = agg_float[hot_mask].min() if hot_mask.sum() > 0 else threshold
max_hot = agg_float.max()
normalized = torch.zeros_like(agg_float)
if hot_mask.sum() > 0:
            normalized[hot_mask] = 0.5 + 0.5 * (agg_float[hot_mask] - min_hot) / (max_hot - min_hot + EPSILON)
return create_heatmap_overlay(image, normalized, skip_normalize=True)
    except Exception:
import traceback
print(f"ColLFM2 Heatmap error: {traceback.format_exc()}")
return None
@spaces.GPU
def generate_vlm_answer(image: Image.Image, query: str) -> str:
"""Generate an answer using Qwen3-VL-4B-Instruct."""
try:
vlm_model, vlm_processor = load_vlm_model()
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": f"Based on this document image, please answer the following question:\n\n{query}\n\nProvide a clear and concise answer based only on the information visible in the document."},
],
}
]
text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
from qwen_vl_utils import process_vision_info
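        # process_vision_info extracts the image (and video) inputs from the
        # chat messages in the format the Qwen VL processor expects.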
image_inputs, video_inputs = process_vision_info(messages)
inputs = vlm_processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to("cuda")
with torch.no_grad():
generated_ids = vlm_model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
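        # Strip the prompt tokens from each sequence so only the newly
        # generated answer tokens are decoded.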
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
answer = vlm_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
return answer
    except Exception as e:
        import traceback
        print(f"VLM error: {traceback.format_exc()}")
        return f"Error generating answer: {str(e)}"
@spaces.GPU
def process_query_with_corpus(model_choice: str, image: Image.Image, query: str, enable_vlm: bool, enable_corpus: bool):
"""
Process a query against an image and optionally against all corpus documents.
Returns similarity score, ranking info, heatmap, and optional VLM answer.
"""
if image is None:
return None, "⚠️ Please upload an image.", None, "", None
if not query.strip():
return None, "⚠️ Please enter a search query.", None, "", None
try:
model, processor = load_colpali_model(model_choice)
device = next(model.parameters()).device
if image.mode != "RGB":
image = image.convert("RGB")
# Process query
batch_queries = processor.process_queries([query]).to(device)
with torch.no_grad():
query_embeddings = model(**batch_queries)
# Process main image
batch_images = processor.process_images([image]).to(device)
with torch.no_grad():
image_embeddings = model(**batch_images)
gr.Info("πŸ“Š Computing similarity scores...")
scores = processor.score(query_embeddings, image_embeddings)
main_score = scores[0][0].item()
# Initialize for VLM
ranking_info = None
top_document_image = image
top_document_label = "Your Document"
if enable_corpus:
gr.Info("πŸ“š Indexing corpus documents...")
# Index all corpus documents
corpus_images = get_all_example_images()
if corpus_images:
# Check if user's image matches any corpus image by comparing pixels
user_img_array = np.array(image.resize((64, 64)))
user_is_example = False
user_example_idx = -1
for idx, (path, img) in enumerate(corpus_images):
corpus_img_array = np.array(img.resize((64, 64)))
if np.allclose(user_img_array, corpus_img_array, atol=10):
user_is_example = True
user_example_idx = idx
break
all_scores = []
# Process ALL corpus images
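                # Small batches keep peak GPU memory bounded while embedding
                # the corpus; each batch is scored and then discarded.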
batch_size = 4
for i in range(0, len(corpus_images), batch_size):
batch_imgs = [img for _, img in corpus_images[i:i+batch_size]]
batch_corpus = processor.process_images(batch_imgs).to(device)
with torch.no_grad():
corpus_embeddings = model(**batch_corpus)
corpus_scores = processor.score(query_embeddings, corpus_embeddings)
for j, (path, img) in enumerate(corpus_images[i:i+batch_size]):
score = corpus_scores[0][j].item()
example_name = Path(path).name
is_user_doc = (i + j == user_example_idx)
label = example_name
for ex in EXAMPLE_CONFIG:
if ex["file"].endswith(example_name):
label = f"{ex['lang']} {ex['description']}"
break
if is_user_doc:
label = f"πŸ“„ {label} (Selected)"
all_scores.append((score, label, path, img))
# If user uploaded custom document (not from corpus), add it
if not user_is_example:
all_scores.append((main_score, "πŸ“„ Your Document", None, image))
# Sort by score descending
all_scores.sort(key=lambda x: x[0], reverse=True)
# Get top document for VLM
top_score, top_label, top_path, top_img = all_scores[0]
top_document_image = top_img
top_document_label = top_label
# Find rank of user's document
user_rank = next((i for i, (_, label, _, _) in enumerate(all_scores) if "πŸ“„" in label), 0) + 1
# Build ranking display
total_docs = len(all_scores)
ranking_lines = [f"### πŸ“Š Ranking (out of {total_docs} documents)"]
ranking_lines.append(f"**Your document ranks #{user_rank}**\n")
ranking_lines.append("| Rank | Score | Document |")
ranking_lines.append("|------|-------|----------|")
for rank, (score, label, _, _) in enumerate(all_scores[:10], 1):
marker = "πŸ‘‰ " if "πŸ“„" in label else ""
ranking_lines.append(f"| {rank} | {score:.4f} | {marker}{label} |")
if len(all_scores) > 10:
ranking_lines.append(f"| ... | ... | *{len(all_scores) - 10} more documents* |")
ranking_info = "\n".join(ranking_lines)
# Generate heatmap
gr.Info("πŸ”₯ Generating similarity heatmap...")
heatmap_image = None
heatmap_available = False
if "ColQwen3" in loaded_colpali_model_name:
try:
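                # spatial_merge_size controls how many raw vision patches are
                # merged into one token; 2 is assumed as the default for the
                # Qwen-VL-style patch merger if the model does not expose it.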
spatial_merge_size = getattr(model, "spatial_merge_size", 2)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=spatial_merge_size,
)
image_mask = batch_images["input_ids"] == processor.image_token_id
pad_token_id = getattr(processor.tokenizer, "pad_token_id", 0)
query_mask = batch_queries["input_ids"] != pad_token_id
similarity_map = get_similarity_maps_from_embeddings(
image_embeddings,
query_embeddings,
n_patches,
image_mask,
query_mask,
)
aggregated = similarity_map.max(dim=0).values
heatmap_image = create_heatmap_overlay(image, aggregated)
heatmap_available = True
except Exception as e:
import traceback
print(f"Heatmap error: {traceback.format_exc()}")
elif "ColLFM2" in loaded_colpali_model_name:
heatmap_image = get_collfm2_heatmap(
model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries
)
if heatmap_image is not None:
heatmap_available = True
# Build result text
if heatmap_available:
if "ColLFM2" in loaded_colpali_model_name:
result_text = f"""## πŸ“Š Similarity Score: **{main_score:.4f}**
πŸ”΅ **Dark blue** = low relevance | 🟒 **Cyan/Green** = high relevance
⚠️ **ColLFM2 Heatmap Note:** This model uses a SigLIP2 vision encoder with pixel unshuffle, producing "holistic" embeddings. The heatmap shows **region-level relevance** rather than precise word-level localization. This is expected behavior - ColLFM2 excels at determining *if* a document is relevant. For precise heatmaps, try ColQwen3 models."""
else:
result_text = f"""## πŸ“Š Similarity Score: **{main_score:.4f}**
The heatmap shows which areas of the document are most relevant to your query.
πŸ”΅ **Dark blue** = low relevance | 🟒 **Cyan/Green** = high relevance"""
else:
result_text = f"""## πŸ“Š Similarity Score: **{main_score:.4f}**
*Heatmap visualization is not available for this model configuration.*"""
# Generate VLM answer using top-ranked document
vlm_answer = ""
if enable_vlm:
vlm_answer = generate_vlm_answer(top_document_image, query)
            if enable_corpus and "πŸ“„" not in top_document_label:
vlm_answer = f"*[Answer based on top-ranked document: {top_document_label}]*\n\n{vlm_answer}"
gr.Info("βœ… Analysis complete!")
return main_score, result_text, heatmap_image, vlm_answer, ranking_info
except Exception as e:
import traceback
error_msg = f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
return None, error_msg, None, "", None
def create_demo():
available_examples = get_available_examples()
with gr.Blocks(title="SauerkrautLM-ColPali Demo") as demo:
# Header with logo
gr.HTML("""
<div style="text-align: center; padding: 20px 0;">
<div style="margin: 0 auto 20px auto; max-width: 800px;">
<img src="https://vago-solutions.ai/wp-content/uploads/2025/12/Sauerkrautlm-colpali-scaled.png"
alt="SauerkrautLM-ColPali"
style="width: 75%; border-radius: 12px;"/>
</div>
<p style="color: #888; font-size: 1.2rem; margin: 0 0 16px 0;">
Visual Document Retrieval with Multi-Vector Embeddings + VLM Answer Generation
</p>
<div style="margin-top: 16px;">
<a href="https://huggingface.co/VAGOsolutions" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">πŸ€— Models</a>
<a href="https://github.com/VAGOsolutions/sauerkrautlm-colpali" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">πŸ“– GitHub</a>
<a href="https://vago-solutions.ai" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">🌐 VAGO Solutions</a>
</div>
</div>
""")
with gr.Row():
# Left Column - Inputs
with gr.Column(scale=1):
gr.HTML('<h3 style="color: #b0b0b0; margin-bottom: 16px;">βš™οΈ Configuration</h3>')
model_dropdown = gr.Dropdown(
choices=list(COLPALI_MODELS.keys()),
value="SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)",
label="πŸ” Retrieval Model",
info="ColPali-based model for document retrieval and heatmap",
)
with gr.Row():
enable_vlm = gr.Checkbox(
label="πŸ€– Enable VLM",
value=False,
info="Use Qwen3-VL-4B for answers (adds ~30-60s on first use)",
)
enable_corpus = gr.Checkbox(
label="πŸ“š Compare with Corpus",
value=True,
info="Rank against 31 example documents",
)
gr.HTML('<h3 style="color: #b0b0b0; margin: 24px 0 16px 0;">πŸ“„ Document</h3>')
image_input = gr.Image(
label="Upload Document Image",
type="pil",
height=350,
)
query_input = gr.Textbox(
label="πŸ” Search Query",
placeholder="e.g., What is the total revenue? / Wie hoch ist der Umsatz?",
lines=2,
)
submit_btn = gr.Button(
"πŸš€ Analyze Document",
variant="primary",
size="lg",
)
# Right Column - Results
with gr.Column(scale=1):
gr.HTML('<h3 style="color: #b0b0b0; margin-bottom: 16px;">πŸ“Š Results</h3>')
with gr.Group():
score_output = gr.Number(
label="Similarity Score",
precision=4,
)
result_markdown = gr.Markdown(
value="*Upload an image and enter a query to get started*",
)
gr.HTML('<h3 style="color: #b0b0b0; margin: 24px 0 16px 0;">πŸ”₯ Similarity Heatmap</h3>')
heatmap_output = gr.Image(
label="Heatmap Visualization",
type="pil",
height=400,
)
gr.HTML("""
<div style="display: flex; align-items: center; gap: 12px; padding: 12px; background: rgba(255,255,255,0.03); border-radius: 8px; margin-top: 8px;">
<div style="width: 150px; height: 20px; background: linear-gradient(90deg, #0b0924 0%, #1f1147 20%, #3b1c6c 35%, #4a3880 50%, #3e7a8c 70%, #5ec5c0 85%, #c3f0e4 100%); border-radius: 4px;"></div>
<span style="color: #888; font-size: 0.9rem;">Low β†’ High Relevance (mako colormap)</span>
</div>
""")
with gr.Accordion("πŸ“š Corpus Ranking", open=True):
ranking_output = gr.Markdown(
value="*Enable 'Compare with Corpus' to see how your document ranks*",
)
with gr.Accordion("πŸ€– VLM Answer", open=True, visible=True):
vlm_answer_output = gr.Textbox(
label="Answer from Qwen3-VL-4B",
lines=6,
interactive=False,
placeholder="Enable VLM and analyze to get an AI-generated answer...",
)
# Examples section
if available_examples:
gr.HTML('<h3 style="color: #b0b0b0; margin: 32px 0 16px 0;">πŸ“š Example Documents (31 multilingual documents)</h3>')
gr.HTML("""
<div style="padding: 12px; background: rgba(102, 126, 234, 0.1); border-radius: 8px; margin-bottom: 16px;">
<p style="color: #a0a0a0; margin: 0; font-size: 0.9rem;">
🌍 <strong>Languages:</strong> German (DE), English (EN), French (FR), Spanish (ES)<br>
πŸ“‚ <strong>Categories:</strong> Annual Reports, ESG, Tax Forms, Scientific Papers, Energy/Hydrogen, Healthcare, Economic Forecasts
</p>
</div>
""")
gr.Examples(
examples=available_examples,
inputs=[image_input, query_input],
label="Click an example to load it",
)
# Info section
gr.HTML("""
<details style="margin-top: 12px; padding: 12px; background: rgba(255,200,100,0.08); border: 1px solid rgba(255,200,100,0.2); border-radius: 8px;">
<summary style="cursor: pointer; color: #e0c080; font-weight: 500;">ℹ️ About Heatmap Differences: ColQwen3 vs ColLFM2</summary>
<div style="margin-top: 12px; color: #a0a0a0; font-size: 0.85rem; line-height: 1.6;">
<p><strong style="color: #80c0ff;">ColQwen3 (Qwen-based):</strong> Uses Qwen3-VL's vision encoder which preserves strong spatial locality in patch embeddings. Each image patch maintains distinct features, allowing query tokens to differentiate between regions. Result: <em>precise, localized heatmaps</em>.</p>
<p style="margin-top: 8px;"><strong style="color: #ffa080;">ColLFM2 (LFM2-based):</strong> Uses SigLIP2 NaFlex vision encoder with <em>pixel unshuffle</em> for efficient token reduction. This merges spatial information across patches, producing more "holistic" embeddings. Query tokens show high correlation (~0.97) across all patches. Result: <em>region-level relevance</em> rather than word-level precision.</p>
<p style="margin-top: 8px;"><strong style="color: #80ffa0;">Why ColLFM2 still performs well for retrieval:</strong> The subtle similarity differences (e.g., 0.88 vs 0.94) are sufficient for ranking documents correctly. ColLFM2 excels at determining <em>if</em> a document is relevant, while ColQwen3 better shows <em>where</em> the relevance is.</p>
</div>
</details>
""")
# Footer
gr.HTML("""
<div style="text-align: center; padding: 32px 0 16px 0; border-top: 1px solid rgba(255,255,255,0.1); margin-top: 32px;">
<p style="color: #666; font-size: 0.9rem;">
πŸ’‘ <b>Tip:</b> ColQwen3 models provide the best heatmap visualization. Enable VLM (Qwen3-VL-4B) for AI-generated answers.
</p>
<p style="color: #555; font-size: 0.85rem; margin-top: 8px;">
Made with ❀️ by <a href="https://vago-solutions.ai" target="_blank" style="color: #667eea;">VAGO Solutions</a>
</p>
</div>
""")
# Event handlers
submit_btn.click(
fn=process_query_with_corpus,
inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus],
outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output],
)
query_input.submit(
fn=process_query_with_corpus,
inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus],
outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output],
)
# Show hint when VLM is enabled
enable_vlm.change(
fn=on_vlm_toggle,
inputs=[enable_vlm],
outputs=[enable_vlm],
)
return demo
if __name__ == "__main__":
install_fa2()
demo = create_demo()
demo.queue(max_size=10).launch(debug=True)