| """ | |
| SauerkrautLM-ColPali Demo Space | |
| Visual Document Retrieval with Similarity Heat Maps + VLM Answer Generation | |
| Multi-document indexing for realistic retrieval scenarios | |
| """ | |
| import os | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from einops import rearrange | |
| from huggingface_hub import login | |
| import spaces | |
| import math | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional, Dict | |
| # Import model classes at startup | |
| from sauerkrautlm_colpali.models.lfm2.collfm2.modeling_collfm2 import ColLFM2 | |
| from sauerkrautlm_colpali.models.lfm2.collfm2.processing_collfm2 import ColLFM2Processor | |
| from sauerkrautlm_colpali.models.qwen3.colqwen3.modeling_colqwen3 import ColQwen3 | |
| from sauerkrautlm_colpali.models.qwen3.colqwen3.processing_colqwen3 import ColQwen3Processor | |
| print("All model imports successful!") | |
# Small constant that guards against division by zero when normalizing similarity maps
EPSILON = 1e-10


def install_fa2():
    print("Installing Flash Attention 2...")
    os.system("pip install flash-attn --no-build-isolation")
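
# Note: unless a prebuilt wheel matches the local CUDA/PyTorch combination,
# flash-attn compiles from source, so this step can take several minutes on a
# fresh Space.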

# HF Token for private models
hf_token = os.getenv("HF_KEY")
if hf_token:
    login(token=hf_token)

# ColPali model options
COLPALI_MODELS = {
    "SauerkrautLM-ColLFM2-450M (Fastest, 0.9GB)": "VAGOsolutions/SauerkrautLM-ColLFM2-450M-v0.1",
    "SauerkrautLM-ColQwen3-1.7B-Turbo (Fast, 3.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-1.7b-Turbo-v0.1",
    "SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-2b-v0.1",
    "SauerkrautLM-ColQwen3-4B (Quality, 8GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-4b-v0.1",
    "SauerkrautLM-ColQwen3-8B (Best, 16GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-8b-v0.1",
}

# Global model cache
loaded_colpali_model = None
loaded_colpali_processor = None
loaded_colpali_model_name = None
loaded_vlm_model = None
loaded_vlm_processor = None
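
# Only one retrieval model is kept resident at a time; switching models in the
# UI unloads the previous one first, so even the larger checkpoints plus the
# optional Qwen3-VL-4B answer model stay within GPU memory.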

# =============================================================================
# EXAMPLE CONFIGURATION - Organized by language and use case
# =============================================================================
EXAMPLE_CONFIG = [
    # ==================== GERMAN (DE) ====================
    # Annual Reports
    {
        "file": "deutsch/2024-infineon-geschaeftsbericht-v01-00-de_p4.png",
        "query": "Wie hoch ist der Umsatz von Infineon?",
        "category": "📊 Annual Report",
        "lang": "🇩🇪 DE",
        "description": "Infineon Geschäftsbericht 2024 - Kennzahlen"
    },
    {
        "file": "deutsch/BASF_Bericht_2024_p90.png",
        "query": "Wie ist das Risikomanagement bei BASF organisiert?",
        "category": "📊 Annual Report",
        "lang": "🇩🇪 DE",
        "description": "BASF Bericht 2024 - Risikomanagement"
    },
    {
        "file": "deutsch/entire-dtag-gb24_p351.png",
        "query": "Wie setzt sich die Vorstandsvergütung bei der Deutschen Telekom zusammen?",
        "category": "📊 Annual Report",
        "lang": "🇩🇪 DE",
        "description": "Deutsche Telekom GB 2024 - Vergütung"
    },
    {
        "file": "deutsch/entire-dtag-gb24_p77.png",
        "query": "Wie hoch sind die Nettofinanzverbindlichkeiten der Deutschen Telekom?",
        "category": "📊 Annual Report",
        "lang": "🇩🇪 DE",
        "description": "Deutsche Telekom GB 2024 - Finanzen"
    },
    # Hydrogen / Energy
    {
        "file": "deutsch/Bildschirmfoto 2025-12-14 um 01.31.45.png",
        "query": "Wo verläuft die geplante Wasserstofftrasse in Leipzig?",
        "category": "⚡ Hydrogen/Energy",
        "lang": "🇩🇪 DE",
        "description": "Wasserstoff-Infrastruktur Leipzig"
    },
    {
        "file": "deutsch/Bildschirmfoto 2025-12-14 um 01.36.01.png",
        "query": "Wie wird die Gewässerquerung der Wasserstoffleitung realisiert?",
        "category": "⚡ Hydrogen/Energy",
        "lang": "🇩🇪 DE",
        "description": "Wasserstoff-Technische Zeichnung"
    },
    # Tax/Forms
    {
        "file": "deutsch/ESt_1_A_2022_p2.png",
        "query": "Wo trage ich meine Bankverbindung in der Steuererklärung ein?",
        "category": "📋 Tax Form",
        "lang": "🇩🇪 DE",
        "description": "Einkommensteuererklärung ESt 1A"
    },
    # Economic Reports
    {
        "file": "deutsch/Monatsbericht---Oktober-2025_p152.png",
        "query": "Wie hoch sind die aktuellen Zinssätze für Wohnungsbaukredite?",
        "category": "💰 Financial Report",
        "lang": "🇩🇪 DE",
        "description": "Bundesbank Monatsbericht - Zinssätze"
    },
    {
        "file": "deutsch/sd-2025-digital-07-wollmershaeuser-etal-ifo-konjunkturprognose-sommer-2025_p23.png",
        "query": "Was sind die Annahmen für den Ölpreis in der ifo-Prognose?",
        "category": "📈 Economic Forecast",
        "lang": "🇩🇪 DE",
        "description": "ifo Konjunkturprognose 2025"
    },
    # Environmental
    {
        "file": "deutsch/rep0913_p197.png",
        "query": "Wie hat sich der Fleischkonsum in Österreich entwickelt?",
        "category": "🌱 Environmental",
        "lang": "🇩🇪 DE",
        "description": "Klimaschutzbericht - Umweltbundesamt"
    },
    # ==================== ENGLISH (EN) ====================
    # ESG/Sustainability
    {
        "file": "englisch/2025051910270996484_p19.png",
        "query": "What is the waste recycling rate at CRRC?",
        "category": "🌱 ESG Report",
        "lang": "🇬🇧 EN",
        "description": "CRRC ESG Report 2024 - Waste Management"
    },
    # Scientific
    {
        "file": "englisch/6e81bb8284357ea1773e99832d21c65b_new_myosinlecture_ag_p5.png",
        "query": "What are the different classes of myosin?",
        "category": "🔬 Scientific Paper",
        "lang": "🇬🇧 EN",
        "description": "Myosin Phylogenetic Tree"
    },
    # Historical/Vintage
    {
        "file": "englisch/ADVE_0004.png",
        "query": "Which cigarette brand advertised 'Call for Philip Morris'?",
        "category": "📜 Historical Document",
        "lang": "🇬🇧 EN",
        "description": "Vintage Philip Morris Advertisement"
    },
    # Business Forms
    {
        "file": "englisch/Form_0033.png",
        "query": "How long should vendor audit records be retained?",
        "category": "📋 Business Form",
        "lang": "🇬🇧 EN",
        "description": "Records Retention Schedule"
    },
    {
        "file": "englisch/Letter_0061.png",
        "query": "Who requested copies of the 'Helping Youth Decide' booklet?",
        "category": "✉️ Business Letter",
        "lang": "🇬🇧 EN",
        "description": "Girl Scouts Correspondence"
    },
    # Financial Reports
    {
        "file": "englisch/NASDAQ_DDD_2024_p76.png",
        "query": "What is the total stockholders' equity of 3D Systems?",
        "category": "📊 Annual Report",
        "lang": "🇬🇧 EN",
        "description": "3D Systems Financial Statement"
    },
    {
        "file": "englisch/TMUS-2024-Annual-Report_p143.png",
        "query": "Who is the CEO of T-Mobile US?",
        "category": "📊 Annual Report",
        "lang": "🇬🇧 EN",
        "description": "T-Mobile US Annual Report 2024"
    },
    {
        "file": "englisch/pwc-transparency-report-2023-2024_p33.png",
        "query": "What are the rotation periods for audit partners at PwC?",
        "category": "📄 Transparency Report",
        "lang": "🇬🇧 EN",
        "description": "PwC Transparency Report"
    },
    # ==================== FRENCH (FR) ====================
    {
        "file": "französisch/194000315_0_p178.png",
        "query": "Quel est le coût du travail au SMIC en France?",
        "category": "💰 Labor Statistics",
        "lang": "🇫🇷 FR",
        "description": "Statistiques du travail - SMIC"
    },
    {
        "file": "französisch/194000315_0_p21.png",
        "query": "Quelle est la prévision de croissance du PIB en zone euro?",
        "category": "📈 Economic Forecast",
        "lang": "🇫🇷 FR",
        "description": "Prévisions économiques Zone Euro"
    },
    {
        "file": "französisch/Cours-de-physique-1v2_p44.png",
        "query": "Comment fonctionne la vision des couleurs?",
        "category": "📚 Educational",
        "lang": "🇫🇷 FR",
        "description": "Cours de Physique - Vision"
    },
    {
        "file": "französisch/CSSF_RA_2024_FR_p14.png",
        "query": "Quelle est la répartition des employés de la CSSF par nationalité?",
        "category": "📊 Annual Report",
        "lang": "🇫🇷 FR",
        "description": "CSSF Luxembourg - Rapport Annuel"
    },
    {
        "file": "französisch/ICN_Definition-Nursing_Report_FR_Web_p47.png",
        "query": "Combien d'associations nationales d'infirmières participent au CII?",
        "category": "🏥 Healthcare Report",
        "lang": "🇫🇷 FR",
        "description": "ICN Rapport Infirmières"
    },
    {
        "file": "französisch/rapport-cns-2024-internet_p14.png",
        "query": "Quelles sont les principales missions de la CNS?",
        "category": "🏥 Healthcare Report",
        "lang": "🇫🇷 FR",
        "description": "CNS Rapport Annuel"
    },
    {
        "file": "französisch/rapport-esg-2024.pdf.coredownload.inline_p29.png",
        "query": "Quels sont les objectifs ESG pour 2024?",
        "category": "🌱 ESG Report",
        "lang": "🇫🇷 FR",
        "description": "Rapport ESG 2024"
    },
    # ==================== SPANISH (ES) ====================
    {
        "file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p31.png",
        "query": "¿Cuáles son las emisiones de gases de efecto invernadero de Coeur Mining?",
        "category": "🌱 ESG Report",
        "lang": "🇪🇸 ES",
        "description": "Coeur Mining ESG - Emisiones"
    },
    {
        "file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p39.png",
        "query": "¿Qué medidas de seguridad implementa Coeur Mining?",
        "category": "🌱 ESG Report",
        "lang": "🇪🇸 ES",
        "description": "Coeur Mining ESG - Seguridad"
    },
    {
        "file": "spanisch/Informe-Economico-Regional-2022-2023_p112.png",
        "query": "¿Cuál es la situación económica regional en 2023?",
        "category": "📈 Economic Report",
        "lang": "🇪🇸 ES",
        "description": "Informe Económico Regional"
    },
    {
        "file": "spanisch/Informe-Sostenibilidad-ESG-2024_p16.png",
        "query": "¿Cuáles son los objetivos de sostenibilidad para 2024?",
        "category": "🌱 ESG Report",
        "lang": "🇪🇸 ES",
        "description": "Informe Sostenibilidad ESG"
    },
    {
        "file": "spanisch/MAPs_PLAN_DESARROLLO_p14.png",
        "query": "¿Cuántas propuestas de renovables se recibieron en España?",
        "category": "⚡ Energy Infrastructure",
        "lang": "🇪🇸 ES",
        "description": "Plan Desarrollo Red Eléctrica"
    },
    {
        "file": "spanisch/Presupuestos_p32.png",
        "query": "¿Cuál es el presupuesto total asignado?",
        "category": "💰 Budget Document",
        "lang": "🇪🇸 ES",
        "description": "Presupuestos Generales"
    },
]
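
# Example pages live under demopics/<language>/ next to this file; entries
# whose image is missing are skipped at load time, so the corpus degrades
# gracefully when files are removed.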

def get_all_example_images() -> List[Tuple[str, Image.Image]]:
    """Load all example images for multi-document indexing."""
    examples_dir = Path(__file__).parent / "demopics"
    images = []
    for example in EXAMPLE_CONFIG:
        filepath = examples_dir / example["file"]
        if filepath.exists():
            try:
                img = Image.open(filepath).convert("RGB")
                images.append((str(filepath), img))
            except Exception as e:
                print(f"Error loading {filepath}: {e}")
    return images


def get_available_examples():
    """Load examples for the Gradio Examples component (shuffled)."""
    import random

    examples_dir = Path(__file__).parent / "demopics"
    available = []
    for example in EXAMPLE_CONFIG:
        filepath = examples_dir / example["file"]
        if filepath.exists():
            available.append([str(filepath), example["query"]])
    # Shuffle to mix languages
    random.seed(42)  # Consistent shuffle
    random.shuffle(available)
    return available if available else None


def get_example_gallery_data():
    """Get data for the example gallery with categories."""
    examples_dir = Path(__file__).parent / "demopics"
    gallery_data = []
    for example in EXAMPLE_CONFIG:
        filepath = examples_dir / example["file"]
        if filepath.exists():
            gallery_data.append({
                "path": str(filepath),
                "query": example["query"],
                "label": f"{example['lang']} {example['category']}: {example['description']}",
                "category": example["category"],
                "lang": example["lang"],
            })
    return gallery_data


def load_colpali_model(model_choice: str):
    """Load the selected ColPali model with proper device placement."""
    global loaded_colpali_model, loaded_colpali_processor, loaded_colpali_model_name
    model_name = COLPALI_MODELS[model_choice]
    if loaded_colpali_model_name == model_name and loaded_colpali_model is not None:
        gr.Info(f"✅ {model_choice} ready!")
        return loaded_colpali_model, loaded_colpali_processor
    gr.Info(f"⏳ Loading {model_choice}... Please wait.")
    if loaded_colpali_model is not None:
        gr.Info("🔄 Unloading previous model...")
        try:
            del loaded_colpali_model
            del loaded_colpali_processor
        except Exception:
            pass
        # Rebind the globals so a failed load below cannot leave them undefined
        loaded_colpali_model = None
        loaded_colpali_processor = None
        torch.cuda.empty_cache()
    try:
        import flash_attn  # noqa: F401
        attn_impl = "flash_attention_2"
        gr.Info("⚡ Using Flash Attention 2")
    except ImportError:
        attn_impl = "sdpa"
        gr.Info("🔧 Using SDPA attention")
    print(f"Loading {model_name} with attention: {attn_impl}")
    if "ColLFM2" in model_name:
        gr.Info("📥 Downloading model weights...")
        loaded_colpali_model = ColLFM2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_impl,
            token=hf_token,
        ).eval().to("cuda")
        gr.Info("📥 Downloading processor...")
        loaded_colpali_processor = ColLFM2Processor.from_pretrained(model_name, token=hf_token)
    elif "ColQwen3" in model_name:
        gr.Info("📥 Downloading model weights...")
        loaded_colpali_model = ColQwen3.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_impl,
            device_map="cuda",
            token=hf_token,
        ).eval()
        gr.Info("📥 Downloading processor...")
        loaded_colpali_processor = ColQwen3Processor.from_pretrained(model_name, token=hf_token)
    else:
        raise ValueError(f"Unknown model type: {model_name}")
    loaded_colpali_model_name = model_name
    gr.Info(f"✅ {model_choice} loaded and ready!")
    return loaded_colpali_model, loaded_colpali_processor


def load_vlm_model():
    """Load Qwen3-VL-4B for answer generation."""
    global loaded_vlm_model, loaded_vlm_processor
    if loaded_vlm_model is not None:
        gr.Info("✅ Qwen3-VL-4B ready!")
        return loaded_vlm_model, loaded_vlm_processor
    gr.Info("⏳ Loading Qwen3-VL-4B-Instruct... Please wait.")
    from transformers import AutoModelForImageTextToText, AutoProcessor

    vlm_model_name = "Qwen/Qwen3-VL-4B-Instruct"
    print(f"Loading VLM: {vlm_model_name}")
    gr.Info("📥 Downloading VLM model weights (8GB)...")
    loaded_vlm_model = AutoModelForImageTextToText.from_pretrained(
        vlm_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=hf_token,
    ).eval()
    gr.Info("📥 Downloading VLM processor...")
    loaded_vlm_processor = AutoProcessor.from_pretrained(vlm_model_name, token=hf_token)
    gr.Info("✅ Qwen3-VL-4B loaded and ready!")
    return loaded_vlm_model, loaded_vlm_processor


def on_vlm_toggle(enabled):
    """Show hint when VLM is enabled."""
    if enabled:
        gr.Info("ℹ️ VLM (Qwen3-VL-4B) will be loaded on first analysis. This adds ~30-60 seconds.")
    return enabled
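
# ColPali-style late interaction: the query and each page are encoded into one
# embedding per token / image patch, and a page is scored with
#     MaxSim(q, d) = sum_i max_j <q_i, d_j>
# i.e. every query token is matched to its best image patch. The same
# token-to-patch dot products, laid out on the patch grid instead of being
# max-reduced, are what the heat maps below visualize.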

def get_similarity_maps_from_embeddings(
    image_embeddings: torch.Tensor,
    query_embeddings: torch.Tensor,
    n_patches: Tuple[int, int],
    image_mask: torch.Tensor,
    query_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """EXACT ColPali implementation of similarity map computation."""
    idx = 0
    n_patches_x, n_patches_y = n_patches[0], n_patches[1]
    n_image_tokens = int(image_mask[idx].sum().item())
    expected_tokens = n_patches_x * n_patches_y
    if n_image_tokens != expected_tokens:
        # Token count does not match the reported grid: fall back to the most
        # square factorization of the actual token count
        n = n_image_tokens
        sqrt_n = int(math.sqrt(n))
        for i in range(sqrt_n, 0, -1):
            if n % i == 0:
                n_patches_x, n_patches_y = n // i, i
                break
    image_embedding_grid = rearrange(
        image_embeddings[idx][image_mask[idx]],
        "(h w) c -> w h c",
        w=n_patches_x,
        h=n_patches_y,
    )
    query_emb = query_embeddings[idx]
    if query_mask is not None:
        query_emb = query_emb[query_mask[idx]]
    # One (n_patches_x, n_patches_y) map per query token
    similarity_map = torch.einsum(
        "nk,ijk->nij",
        query_emb,
        image_embedding_grid,
    )
    return similarity_map
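
# Example (illustrative shapes only): a page encoded into a 32x24 patch grid
# queried with a 12-token query yields a (12, 32, 24) tensor from the function
# above; .max(dim=0).values collapses it to the single map that gets overlaid.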

def create_heatmap_overlay(
    image: Image.Image,
    similarity_map: torch.Tensor,
    alpha: float = 0.5,
    skip_normalize: bool = False,
) -> Image.Image:
    """Create heatmap overlay following EXACT ColPali visualization."""
    import seaborn as sns

    sim_float = similarity_map.float()
    if skip_normalize:
        sim_array = sim_float.cpu().numpy()
    else:
        # Min-max normalize to [0, 1]; EPSILON guards the constant-map case
        min_val = sim_float.min()
        max_val = sim_float.max()
        sim_normalized = (sim_float - min_val) / (max_val - min_val + EPSILON)
        sim_array = sim_normalized.cpu().numpy()
    sim_array = rearrange(sim_array, "h w -> w h")
    sim_image = Image.fromarray((sim_array * 255).astype(np.uint8))
    sim_image = sim_image.resize(image.size, Image.Resampling.BICUBIC)
    sim_resized = np.array(sim_image) / 255.0
    cmap = sns.color_palette("mako", as_cmap=True)
    heatmap_rgba = cmap(sim_resized)
    heatmap = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    img_array = np.array(image.convert("RGB")).astype(np.float32)
    heatmap_float = heatmap.astype(np.float32)
    # Alpha-blend the colored map over the page image
    blended = img_array * (1 - alpha) + heatmap_float * alpha
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
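
# ColLFM2 uses a SigLIP2 vision encoder with pixel unshuffle, so its patch
# embeddings are more "holistic" than ColQwen3's and plain min-max scaling
# washes the heat map out. The helper below therefore infers the patch grid
# from the image-token count and highlights only the top decile of patches.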

def get_collfm2_heatmap(model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries):
    """Generate heatmap for ColLFM2 models (simplified version)."""
    try:
        if "input_ids" not in batch_images:
            return None
        input_ids = batch_images["input_ids"][0]
        tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor.processor.tokenizer
        image_token_id = tokenizer.convert_tokens_to_ids('<image>')
        image_mask = input_ids == image_token_id
        n_image_tokens = image_mask.sum().item()
        # Find the factor pair of the token count whose grid best matches the
        # image aspect ratio (both orientations are tried)
        img_width, img_height = image.size
        img_ratio = img_width / img_height
        best_diff = float('inf')
        n_patches_x, n_patches_y = int(math.sqrt(n_image_tokens)), int(math.sqrt(n_image_tokens))
        for i in range(1, int(math.sqrt(n_image_tokens)) + 1):
            if n_image_tokens % i == 0:
                j = n_image_tokens // i
                ratio1 = j / i
                ratio2 = i / j
                diff1 = abs(img_ratio - ratio1)
                diff2 = abs(img_ratio - ratio2)
                if diff1 < best_diff:
                    best_diff = diff1
                    n_patches_x, n_patches_y = j, i
                if diff2 < best_diff:
                    best_diff = diff2
                    n_patches_x, n_patches_y = i, j
        query_emb = query_embeddings[0]
        # Filter padding
        pad_token_id = getattr(tokenizer, "pad_token_id", 0) or 0
        query_mask = batch_queries["input_ids"][0] != pad_token_id
        query_emb = query_emb[query_mask]
        # Get image embeddings
        image_emb = image_embeddings[0][image_mask][:n_patches_x * n_patches_y]
        image_grid = rearrange(image_emb, "(h w) c -> w h c", w=n_patches_x, h=n_patches_y)
        # Compute similarity and take the per-patch max over query tokens
        similarity_map = torch.einsum("nk,ijk->nij", query_emb, image_grid)
        aggregated = similarity_map.max(dim=0).values
        # Aggressive normalization for ColLFM2: keep only the top 10% of
        # patches and stretch them into the upper half of the color range
        agg_float = aggregated.float()
        threshold = torch.quantile(agg_float.flatten(), 0.90)
        hot_mask = agg_float > threshold
        min_hot = agg_float[hot_mask].min() if hot_mask.sum() > 0 else threshold
        max_hot = agg_float.max()
        normalized = torch.zeros_like(agg_float)
        if hot_mask.sum() > 0:
            normalized[hot_mask] = 0.5 + 0.5 * (agg_float[hot_mask] - min_hot) / (max_hot - min_hot + EPSILON)
        return create_heatmap_overlay(image, normalized, skip_normalize=True)
    except Exception:
        import traceback
        print(f"ColLFM2 Heatmap error: {traceback.format_exc()}")
        return None
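
# Retrieve-then-read: once ColPali has picked the most relevant page, the raw
# page image (not extracted text) is handed to Qwen3-VL, so the answer can use
# layout, tables, and figures directly.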

def generate_vlm_answer(image: Image.Image, query: str) -> str:
    """Generate an answer using Qwen3-VL-4B-Instruct."""
    try:
        vlm_model, vlm_processor = load_vlm_model()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": f"Based on this document image, please answer the following question:\n\n{query}\n\nProvide a clear and concise answer based only on the information visible in the document."},
                ],
            }
        ]
        text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        from qwen_vl_utils import process_vision_info

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = vlm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to("cuda")
        with torch.no_grad():
            generated_ids = vlm_model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        # Strip the prompt tokens so only the newly generated answer is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        answer = vlm_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return answer
    except Exception as e:
        import traceback
        print(f"VLM error: {traceback.format_exc()}")
        return f"Error generating answer: {str(e)}"

def process_query_with_corpus(model_choice: str, image: Image.Image, query: str, enable_vlm: bool, enable_corpus: bool):
    """
    Process a query against an image and optionally against all corpus documents.

    Returns similarity score, ranking info, heatmap, and optional VLM answer.
    """
    if image is None:
        return None, "⚠️ Please upload an image.", None, "", None
    if not query.strip():
        return None, "⚠️ Please enter a search query.", None, "", None
    try:
        model, processor = load_colpali_model(model_choice)
        device = next(model.parameters()).device
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Process query
        batch_queries = processor.process_queries([query]).to(device)
        with torch.no_grad():
            query_embeddings = model(**batch_queries)
        # Process main image
        batch_images = processor.process_images([image]).to(device)
        with torch.no_grad():
            image_embeddings = model(**batch_images)
        gr.Info("🔍 Computing similarity scores...")
        scores = processor.score(query_embeddings, image_embeddings)
        main_score = scores[0][0].item()
        # Initialize for VLM
        ranking_info = None
        top_document_image = image
        top_document_label = "Your Document"
        if enable_corpus:
            gr.Info("📚 Indexing corpus documents...")
            # Index all corpus documents
            corpus_images = get_all_example_images()
            if corpus_images:
                # Check whether the user's image is one of the corpus images by
                # comparing downscaled pixels (robust to re-encoding)
                user_img_array = np.array(image.resize((64, 64)))
                user_is_example = False
                user_example_idx = -1
                for idx, (path, img) in enumerate(corpus_images):
                    corpus_img_array = np.array(img.resize((64, 64)))
                    if np.allclose(user_img_array, corpus_img_array, atol=10):
                        user_is_example = True
                        user_example_idx = idx
                        break
                all_scores = []
                # Embed ALL corpus images in small batches to bound peak memory
                batch_size = 4
                for i in range(0, len(corpus_images), batch_size):
                    batch_imgs = [img for _, img in corpus_images[i:i+batch_size]]
                    batch_corpus = processor.process_images(batch_imgs).to(device)
                    with torch.no_grad():
                        corpus_embeddings = model(**batch_corpus)
                    corpus_scores = processor.score(query_embeddings, corpus_embeddings)
                    for j, (path, img) in enumerate(corpus_images[i:i+batch_size]):
                        score = corpus_scores[0][j].item()
                        example_name = Path(path).name
                        is_user_doc = (i + j == user_example_idx)
                        label = example_name
                        for ex in EXAMPLE_CONFIG:
                            if ex["file"].endswith(example_name):
                                label = f"{ex['lang']} {ex['description']}"
                                break
                        if is_user_doc:
                            label = f"📍 {label} (Selected)"
                        all_scores.append((score, label, path, img))
                # If the user uploaded a custom document (not from the corpus), add it
                if not user_is_example:
                    all_scores.append((main_score, "📍 Your Document", None, image))
                # Sort by score descending
                all_scores.sort(key=lambda x: x[0], reverse=True)
                # Get top document for VLM
                top_score, top_label, top_path, top_img = all_scores[0]
                top_document_image = top_img
                top_document_label = top_label
                # Find rank of user's document
                user_rank = next((i for i, (_, label, _, _) in enumerate(all_scores) if "📍" in label), 0) + 1
                # Build ranking display
                total_docs = len(all_scores)
                ranking_lines = [f"### 🏆 Ranking (out of {total_docs} documents)"]
                ranking_lines.append(f"**Your document ranks #{user_rank}**\n")
                ranking_lines.append("| Rank | Score | Document |")
                ranking_lines.append("|------|-------|----------|")
                for rank, (score, label, _, _) in enumerate(all_scores[:10], 1):
                    marker = "👉 " if "📍" in label else ""
                    ranking_lines.append(f"| {rank} | {score:.4f} | {marker}{label} |")
                if len(all_scores) > 10:
                    ranking_lines.append(f"| ... | ... | *{len(all_scores) - 10} more documents* |")
                ranking_info = "\n".join(ranking_lines)
        # Generate heatmap
        gr.Info("🔥 Generating similarity heatmap...")
        heatmap_image = None
        heatmap_available = False
        if "ColQwen3" in loaded_colpali_model_name:
            try:
                spatial_merge_size = getattr(model, "spatial_merge_size", 2)
                n_patches = processor.get_n_patches(
                    image_size=image.size,
                    spatial_merge_size=spatial_merge_size,
                )
                image_mask = batch_images["input_ids"] == processor.image_token_id
                pad_token_id = getattr(processor.tokenizer, "pad_token_id", 0)
                query_mask = batch_queries["input_ids"] != pad_token_id
                similarity_map = get_similarity_maps_from_embeddings(
                    image_embeddings,
                    query_embeddings,
                    n_patches,
                    image_mask,
                    query_mask,
                )
                # Aggregate the per-token maps into one map (max over query tokens)
                aggregated = similarity_map.max(dim=0).values
                heatmap_image = create_heatmap_overlay(image, aggregated)
                heatmap_available = True
            except Exception:
                import traceback
                print(f"Heatmap error: {traceback.format_exc()}")
        elif "ColLFM2" in loaded_colpali_model_name:
            heatmap_image = get_collfm2_heatmap(
                model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries
            )
            if heatmap_image is not None:
                heatmap_available = True
        # Build result text
        if heatmap_available:
            if "ColLFM2" in loaded_colpali_model_name:
                result_text = f"""## 📊 Similarity Score: **{main_score:.4f}**

🔵 **Dark blue** = low relevance | 🟢 **Cyan/Green** = high relevance

⚠️ **ColLFM2 Heatmap Note:** This model uses a SigLIP2 vision encoder with pixel unshuffle, producing "holistic" embeddings. The heatmap shows **region-level relevance** rather than precise word-level localization. This is expected behavior - ColLFM2 excels at determining *if* a document is relevant. For precise heatmaps, try ColQwen3 models."""
            else:
                result_text = f"""## 📊 Similarity Score: **{main_score:.4f}**

The heatmap shows which areas of the document are most relevant to your query.

🔵 **Dark blue** = low relevance | 🟢 **Cyan/Green** = high relevance"""
        else:
            result_text = f"""## 📊 Similarity Score: **{main_score:.4f}**

*Heatmap visualization is not available for this model configuration.*"""
        # Generate VLM answer using the top-ranked document
        vlm_answer = ""
        if enable_vlm:
            vlm_answer = generate_vlm_answer(top_document_image, query)
            if enable_corpus and top_document_label != "📍 Your Document":
                vlm_answer = f"*[Answer based on top-ranked document: {top_document_label}]*\n\n{vlm_answer}"
        gr.Info("✅ Analysis complete!")
        return main_score, result_text, heatmap_image, vlm_answer, ranking_info
    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
        return None, error_msg, None, "", None
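
# =============================================================================
# GRADIO UI - two columns (inputs left, results right), examples underneath
# =============================================================================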
def create_demo():
    available_examples = get_available_examples()
    with gr.Blocks(title="SauerkrautLM-ColPali Demo") as demo:
        # Header with logo
        gr.HTML("""
        <div style="text-align: center; padding: 20px 0;">
            <div style="margin: 0 auto 20px auto; max-width: 800px;">
                <img src="https://vago-solutions.ai/wp-content/uploads/2025/12/Sauerkrautlm-colpali-scaled.png"
                     alt="SauerkrautLM-ColPali"
                     style="width: 75%; border-radius: 12px;"/>
            </div>
            <p style="color: #888; font-size: 1.2rem; margin: 0 0 16px 0;">
                Visual Document Retrieval with Multi-Vector Embeddings + VLM Answer Generation
            </p>
            <div style="margin-top: 16px;">
                <a href="https://huggingface.co/VAGOsolutions" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">🤗 Models</a>
                <a href="https://github.com/VAGOsolutions/sauerkrautlm-colpali" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">🐙 GitHub</a>
                <a href="https://vago-solutions.ai" target="_blank" style="color: #667eea; text-decoration: none; margin: 0 12px; font-weight: 500;">🌐 VAGO Solutions</a>
            </div>
        </div>
        """)
        with gr.Row():
            # Left Column - Inputs
            with gr.Column(scale=1):
                gr.HTML('<h3 style="color: #b0b0b0; margin-bottom: 16px;">⚙️ Configuration</h3>')
                model_dropdown = gr.Dropdown(
                    choices=list(COLPALI_MODELS.keys()),
                    value="SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)",
                    label="🔍 Retrieval Model",
                    info="ColPali-based model for document retrieval and heatmap",
                )
                with gr.Row():
                    enable_vlm = gr.Checkbox(
                        label="🤖 Enable VLM",
                        value=False,
                        info="Use Qwen3-VL-4B for answers (adds ~30-60s on first use)",
                    )
                    enable_corpus = gr.Checkbox(
                        label="📚 Compare with Corpus",
                        value=True,
                        info="Rank against 31 example documents",
                    )
                gr.HTML('<h3 style="color: #b0b0b0; margin: 24px 0 16px 0;">📄 Document</h3>')
                image_input = gr.Image(
                    label="Upload Document Image",
                    type="pil",
                    height=350,
                )
                query_input = gr.Textbox(
                    label="🔍 Search Query",
                    placeholder="e.g., What is the total revenue? / Wie hoch ist der Umsatz?",
                    lines=2,
                )
                submit_btn = gr.Button(
                    "🔍 Analyze Document",
                    variant="primary",
                    size="lg",
                )
            # Right Column - Results
            with gr.Column(scale=1):
                gr.HTML('<h3 style="color: #b0b0b0; margin-bottom: 16px;">📊 Results</h3>')
                with gr.Group():
                    score_output = gr.Number(
                        label="Similarity Score",
                        precision=4,
                    )
                    result_markdown = gr.Markdown(
                        value="*Upload an image and enter a query to get started*",
                    )
                gr.HTML('<h3 style="color: #b0b0b0; margin: 24px 0 16px 0;">🔥 Similarity Heatmap</h3>')
                heatmap_output = gr.Image(
                    label="Heatmap Visualization",
                    type="pil",
                    height=400,
                )
                gr.HTML("""
                <div style="display: flex; align-items: center; gap: 12px; padding: 12px; background: rgba(255,255,255,0.03); border-radius: 8px; margin-top: 8px;">
                    <div style="width: 150px; height: 20px; background: linear-gradient(90deg, #0b0924 0%, #1f1147 20%, #3b1c6c 35%, #4a3880 50%, #3e7a8c 70%, #5ec5c0 85%, #c3f0e4 100%); border-radius: 4px;"></div>
                    <span style="color: #888; font-size: 0.9rem;">Low → High Relevance (mako colormap)</span>
                </div>
                """)
                with gr.Accordion("🏆 Corpus Ranking", open=True):
                    ranking_output = gr.Markdown(
                        value="*Enable 'Compare with Corpus' to see how your document ranks*",
                    )
                with gr.Accordion("🤖 VLM Answer", open=True, visible=True):
                    vlm_answer_output = gr.Textbox(
                        label="Answer from Qwen3-VL-4B",
                        lines=6,
                        interactive=False,
                        placeholder="Enable VLM and analyze to get an AI-generated answer...",
                    )
        # Examples section
        if available_examples:
            gr.HTML('<h3 style="color: #b0b0b0; margin: 32px 0 16px 0;">📚 Example Documents (31 multilingual documents)</h3>')
            gr.HTML("""
            <div style="padding: 12px; background: rgba(102, 126, 234, 0.1); border-radius: 8px; margin-bottom: 16px;">
                <p style="color: #a0a0a0; margin: 0; font-size: 0.9rem;">
                    🌍 <strong>Languages:</strong> German (DE), English (EN), French (FR), Spanish (ES)<br>
                    📂 <strong>Categories:</strong> Annual Reports, ESG, Tax Forms, Scientific Papers, Energy/Hydrogen, Healthcare, Economic Forecasts
                </p>
            </div>
            """)
            gr.Examples(
                examples=available_examples,
                inputs=[image_input, query_input],
                label="Click an example to load it",
            )
        # Info section
        gr.HTML("""
        <details style="margin-top: 12px; padding: 12px; background: rgba(255,200,100,0.08); border: 1px solid rgba(255,200,100,0.2); border-radius: 8px;">
            <summary style="cursor: pointer; color: #e0c080; font-weight: 500;">ℹ️ About Heatmap Differences: ColQwen3 vs ColLFM2</summary>
            <div style="margin-top: 12px; color: #a0a0a0; font-size: 0.85rem; line-height: 1.6;">
                <p><strong style="color: #80c0ff;">ColQwen3 (Qwen-based):</strong> Uses Qwen3-VL's vision encoder, which preserves strong spatial locality in patch embeddings. Each image patch maintains distinct features, allowing query tokens to differentiate between regions. Result: <em>precise, localized heatmaps</em>.</p>
                <p style="margin-top: 8px;"><strong style="color: #ffa080;">ColLFM2 (LFM2-based):</strong> Uses the SigLIP2 NaFlex vision encoder with <em>pixel unshuffle</em> for efficient token reduction. This merges spatial information across patches, producing more "holistic" embeddings. Query tokens show high correlation (~0.97) across all patches. Result: <em>region-level relevance</em> rather than word-level precision.</p>
                <p style="margin-top: 8px;"><strong style="color: #80ffa0;">Why ColLFM2 still performs well for retrieval:</strong> The subtle similarity differences (e.g., 0.88 vs 0.94) are sufficient for ranking documents correctly. ColLFM2 excels at determining <em>if</em> a document is relevant, while ColQwen3 better shows <em>where</em> the relevance is.</p>
            </div>
        </details>
        """)
        # Footer
        gr.HTML("""
        <div style="text-align: center; padding: 32px 0 16px 0; border-top: 1px solid rgba(255,255,255,0.1); margin-top: 32px;">
            <p style="color: #666; font-size: 0.9rem;">
                💡 <b>Tip:</b> ColQwen3 models provide the best heatmap visualization. Enable VLM (Qwen3-VL-4B) for AI-generated answers.
            </p>
            <p style="color: #555; font-size: 0.85rem; margin-top: 8px;">
                Made with ❤️ by <a href="https://vago-solutions.ai" target="_blank" style="color: #667eea;">VAGO Solutions</a>
            </p>
        </div>
        """)
        # Event handlers
        submit_btn.click(
            fn=process_query_with_corpus,
            inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus],
            outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output],
        )
        query_input.submit(
            fn=process_query_with_corpus,
            inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus],
            outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output],
        )
        # Show hint when VLM is enabled
        enable_vlm.change(
            fn=on_vlm_toggle,
            inputs=[enable_vlm],
            outputs=[enable_vlm],
        )
    return demo


if __name__ == "__main__":
    install_fa2()
    demo = create_demo()
    demo.queue(max_size=10).launch(debug=True)