""" SauerkrautLM-ColPali Demo Space Visual Document Retrieval with Similarity Heat Maps + VLM Answer Generation Multi-document indexing for realistic retrieval scenarios """ import os import gradio as gr import torch import numpy as np from PIL import Image from einops import rearrange from huggingface_hub import login import spaces import math from pathlib import Path from typing import List, Tuple, Optional, Dict # Import model classes at startup from sauerkrautlm_colpali.models.lfm2.collfm2.modeling_collfm2 import ColLFM2 from sauerkrautlm_colpali.models.lfm2.collfm2.processing_collfm2 import ColLFM2Processor from sauerkrautlm_colpali.models.qwen3.colqwen3.modeling_colqwen3 import ColQwen3 from sauerkrautlm_colpali.models.qwen3.colqwen3.processing_colqwen3 import ColQwen3Processor print("All model imports successful!") EPSILON = 1e-10 def install_fa2(): print("Installing Flash Attention 2...") os.system("pip install flash-attn --no-build-isolation") # HF Token for private models hf_token = os.getenv("HF_KEY") if hf_token: login(token=hf_token) # ColPali Model options COLPALI_MODELS = { "SauerkrautLM-ColLFM2-450M (Fastest, 0.9GB)": "VAGOsolutions/SauerkrautLM-ColLFM2-450M-v0.1", "SauerkrautLM-ColQwen3-1.7B-Turbo (Fast, 3.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-1.7b-Turbo-v0.1", "SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-2b-v0.1", "SauerkrautLM-ColQwen3-4B (Quality, 8GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-4b-v0.1", "SauerkrautLM-ColQwen3-8B (Best, 16GB)": "VAGOsolutions/SauerkrautLM-ColQwen3-8b-v0.1", } # Global model cache loaded_colpali_model = None loaded_colpali_processor = None loaded_colpali_model_name = None loaded_vlm_model = None loaded_vlm_processor = None # ============================================================================= # EXAMPLE CONFIGURATION - Organized by language and use case # ============================================================================= EXAMPLE_CONFIG = [ # ==================== GERMAN (DE) ==================== # Annual Reports { "file": "deutsch/2024-infineon-geschaeftsbericht-v01-00-de_p4.png", "query": "Wie hoch ist der Umsatz von Infineon?", "category": "📊 Annual Report", "lang": "🇩🇪 DE", "description": "Infineon Geschäftsbericht 2024 - Kennzahlen" }, { "file": "deutsch/BASF_Bericht_2024_p90.png", "query": "Wie ist das Risikomanagement bei BASF organisiert?", "category": "📊 Annual Report", "lang": "🇩🇪 DE", "description": "BASF Bericht 2024 - Risikomanagement" }, { "file": "deutsch/entire-dtag-gb24_p351.png", "query": "Wie setzt sich die Vorstandsvergütung bei der Deutschen Telekom zusammen?", "category": "📊 Annual Report", "lang": "🇩🇪 DE", "description": "Deutsche Telekom GB 2024 - Vergütung" }, { "file": "deutsch/entire-dtag-gb24_p77.png", "query": "Wie hoch sind die Nettofinanzverbindlichkeiten der Deutschen Telekom?", "category": "📊 Annual Report", "lang": "🇩🇪 DE", "description": "Deutsche Telekom GB 2024 - Finanzen" }, # Hydrogen / Energy { "file": "deutsch/Bildschirmfoto 2025-12-14 um 01.31.45.png", "query": "Wo verläuft die geplante Wasserstofftrasse in Leipzig?", "category": "⚡ Hydrogen/Energy", "lang": "🇩🇪 DE", "description": "Wasserstoff-Infrastruktur Leipzig" }, { "file": "deutsch/Bildschirmfoto 2025-12-14 um 01.36.01.png", "query": "Wie wird die Gewässerquerung der Wasserstoffleitung realisiert?", "category": "⚡ Hydrogen/Energy", "lang": "🇩🇪 DE", "description": "Wasserstoff-Technische Zeichnung" }, # Tax/Forms { "file": "deutsch/ESt_1_A_2022_p2.png", "query": "Wo trage ich meine Bankverbindung in der Steuererklärung ein?", "category": "📝 Tax Form", "lang": "🇩🇪 DE", "description": "Einkommensteuererklärung ESt 1A" }, # Economic Reports { "file": "deutsch/Monatsbericht---Oktober-2025_p152.png", "query": "Wie hoch sind die aktuellen Zinssätze für Wohnungsbaukredite?", "category": "💰 Financial Report", "lang": "🇩🇪 DE", "description": "Bundesbank Monatsbericht - Zinssätze" }, { "file": "deutsch/sd-2025-digital-07-wollmershaeuser-etal-ifo-konjunkturprognose-sommer-2025_p23.png", "query": "Was sind die Annahmen für den Ölpreis in der ifo-Prognose?", "category": "📈 Economic Forecast", "lang": "🇩🇪 DE", "description": "ifo Konjunkturprognose 2025" }, # Environmental { "file": "deutsch/rep0913_p197.png", "query": "Wie hat sich der Fleischkonsum in Österreich entwickelt?", "category": "🌱 Environmental", "lang": "🇩🇪 DE", "description": "Klimaschutzbericht - Umweltbundesamt" }, # ==================== ENGLISH (EN) ==================== # ESG/Sustainability { "file": "englisch/2025051910270996484_p19.png", "query": "What is the waste recycling rate at CRRC?", "category": "🌱 ESG Report", "lang": "🇬🇧 EN", "description": "CRRC ESG Report 2024 - Waste Management" }, # Scientific { "file": "englisch/6e81bb8284357ea1773e99832d21c65b_new_myosinlecture_ag_p5.png", "query": "What are the different classes of myosin?", "category": "🔬 Scientific Paper", "lang": "🇬🇧 EN", "description": "Myosin Phylogenetic Tree" }, # Historical/Vintage { "file": "englisch/ADVE_0004.png", "query": "Which cigarette brand advertised 'Call for Philip Morris'?", "category": "📜 Historical Document", "lang": "🇬🇧 EN", "description": "Vintage Philip Morris Advertisement" }, # Business Forms { "file": "englisch/Form_0033.png", "query": "How long should vendor audit records be retained?", "category": "📝 Business Form", "lang": "🇬🇧 EN", "description": "Records Retention Schedule" }, { "file": "englisch/Letter_0061.png", "query": "Who requested copies of the 'Helping Youth Decide' booklet?", "category": "✉️ Business Letter", "lang": "🇬🇧 EN", "description": "Girl Scouts Correspondence" }, # Financial Reports { "file": "englisch/NASDAQ_DDD_2024_p76.png", "query": "What is the total stockholders' equity of 3D Systems?", "category": "📊 Annual Report", "lang": "🇬🇧 EN", "description": "3D Systems Financial Statement" }, { "file": "englisch/TMUS-2024-Annual-Report_p143.png", "query": "Who is the CEO of T-Mobile US?", "category": "📊 Annual Report", "lang": "🇬🇧 EN", "description": "T-Mobile US Annual Report 2024" }, { "file": "englisch/pwc-transparency-report-2023-2024_p33.png", "query": "What are the rotation periods for audit partners at PwC?", "category": "📋 Transparency Report", "lang": "🇬🇧 EN", "description": "PwC Transparency Report" }, # ==================== FRENCH (FR) ==================== { "file": "französisch/194000315_0_p178.png", "query": "Quel est le coût du travail au SMIC en France?", "category": "💰 Labor Statistics", "lang": "🇫🇷 FR", "description": "Statistiques du travail - SMIC" }, { "file": "französisch/194000315_0_p21.png", "query": "Quelle est la prévision de croissance du PIB en zone euro?", "category": "📈 Economic Forecast", "lang": "🇫🇷 FR", "description": "Prévisions économiques Zone Euro" }, { "file": "französisch/Cours-de-physique-1v2_p44.png", "query": "Comment fonctionne la vision des couleurs?", "category": "🎓 Educational", "lang": "🇫🇷 FR", "description": "Cours de Physique - Vision" }, { "file": "französisch/CSSF_RA_2024_FR_p14.png", "query": "Quelle est la répartition des employés de la CSSF par nationalité?", "category": "📊 Annual Report", "lang": "🇫🇷 FR", "description": "CSSF Luxembourg - Rapport Annuel" }, { "file": "französisch/ICN_Definition-Nursing_Report_FR_Web_p47.png", "query": "Combien d'associations nationales d'infirmières participent au CII?", "category": "🏥 Healthcare Report", "lang": "🇫🇷 FR", "description": "ICN Rapport Infirmières" }, { "file": "französisch/rapport-cns-2024-internet_p14.png", "query": "Quelles sont les principales missions de la CNS?", "category": "🏥 Healthcare Report", "lang": "🇫🇷 FR", "description": "CNS Rapport Annuel" }, { "file": "französisch/rapport-esg-2024.pdf.coredownload.inline_p29.png", "query": "Quels sont les objectifs ESG pour 2024?", "category": "🌱 ESG Report", "lang": "🇫🇷 FR", "description": "Rapport ESG 2024" }, # ==================== SPANISH (ES) ==================== { "file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p31.png", "query": "¿Cuáles son las emisiones de gases de efecto invernadero de Coeur Mining?", "category": "🌱 ESG Report", "lang": "🇪🇸 ES", "description": "Coeur Mining ESG - Emisiones" }, { "file": "spanisch/Coeur-ESG-Report-23-May-2024-Spanish-version-compressed_p39.png", "query": "¿Qué medidas de seguridad implementa Coeur Mining?", "category": "🌱 ESG Report", "lang": "🇪🇸 ES", "description": "Coeur Mining ESG - Seguridad" }, { "file": "spanisch/Informe-Economico-Regional-2022-2023_p112.png", "query": "¿Cuál es la situación económica regional en 2023?", "category": "📈 Economic Report", "lang": "🇪🇸 ES", "description": "Informe Económico Regional" }, { "file": "spanisch/Informe-Sostenibilidad-ESG-2024_p16.png", "query": "¿Cuáles son los objetivos de sostenibilidad para 2024?", "category": "🌱 ESG Report", "lang": "🇪🇸 ES", "description": "Informe Sostenibilidad ESG" }, { "file": "spanisch/MAPs_PLAN_DESARROLLO_p14.png", "query": "¿Cuántas propuestas de renovables se recibieron en España?", "category": "⚡ Energy Infrastructure", "lang": "🇪🇸 ES", "description": "Plan Desarrollo Red Eléctrica" }, { "file": "spanisch/Presupuestos_p32.png", "query": "¿Cuál es el presupuesto total asignado?", "category": "💰 Budget Document", "lang": "🇪🇸 ES", "description": "Presupuestos Generales" }, ] def get_all_example_images() -> List[Tuple[str, Image.Image]]: """Load all example images for multi-document indexing.""" examples_dir = Path(__file__).parent / "demopics" images = [] for example in EXAMPLE_CONFIG: filepath = examples_dir / example["file"] if filepath.exists(): try: img = Image.open(filepath).convert("RGB") images.append((str(filepath), img)) except Exception as e: print(f"Error loading {filepath}: {e}") return images def get_available_examples(): """Load examples for the Gradio Examples component (shuffled).""" import random examples_dir = Path(__file__).parent / "demopics" available = [] for example in EXAMPLE_CONFIG: filepath = examples_dir / example["file"] if filepath.exists(): available.append([str(filepath), example["query"]]) # Shuffle to mix languages random.seed(42) # Consistent shuffle random.shuffle(available) return available if available else None def get_example_gallery_data(): """Get data for the example gallery with categories.""" examples_dir = Path(__file__).parent / "demopics" gallery_data = [] for example in EXAMPLE_CONFIG: filepath = examples_dir / example["file"] if filepath.exists(): gallery_data.append({ "path": str(filepath), "query": example["query"], "label": f"{example['lang']} {example['category']}: {example['description']}", "category": example["category"], "lang": example["lang"], }) return gallery_data @spaces.GPU def load_colpali_model(model_choice: str): """Load the selected ColPali model with proper device placement.""" global loaded_colpali_model, loaded_colpali_processor, loaded_colpali_model_name model_name = COLPALI_MODELS[model_choice] if loaded_colpali_model_name == model_name and loaded_colpali_model is not None: gr.Info(f"✅ {model_choice} ready!") return loaded_colpali_model, loaded_colpali_processor gr.Info(f"⏳ Loading {model_choice}... Please wait.") if loaded_colpali_model is not None: gr.Info("🔄 Unloading previous model...") try: del loaded_colpali_model del loaded_colpali_processor except Exception: pass torch.cuda.empty_cache() try: import flash_attn attn_impl = "flash_attention_2" gr.Info("⚡ Using Flash Attention 2") except ImportError: attn_impl = "sdpa" gr.Info("🔧 Using SDPA attention") print(f"Loading {model_name} with attention: {attn_impl}") if "ColLFM2" in model_name: gr.Info("📥 Downloading model weights...") loaded_colpali_model = ColLFM2.from_pretrained( model_name, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, token=hf_token, ).eval().to("cuda") gr.Info("📥 Downloading processor...") loaded_colpali_processor = ColLFM2Processor.from_pretrained(model_name, token=hf_token) elif "ColQwen3" in model_name: gr.Info("📥 Downloading model weights...") loaded_colpali_model = ColQwen3.from_pretrained( model_name, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, device_map="cuda", token=hf_token, ).eval() gr.Info("📥 Downloading processor...") loaded_colpali_processor = ColQwen3Processor.from_pretrained(model_name, token=hf_token) else: raise ValueError(f"Unknown model type: {model_name}") loaded_colpali_model_name = model_name gr.Info(f"✅ {model_choice} loaded and ready!") return loaded_colpali_model, loaded_colpali_processor @spaces.GPU def load_vlm_model(): """Load Qwen3-VL-4B for answer generation.""" global loaded_vlm_model, loaded_vlm_processor if loaded_vlm_model is not None: gr.Info("✅ Qwen3-VL-4B ready!") return loaded_vlm_model, loaded_vlm_processor gr.Info("⏳ Loading Qwen3-VL-4B-Instruct... Please wait.") from transformers import AutoModelForImageTextToText, AutoProcessor vlm_model_name = "Qwen/Qwen3-VL-4B-Instruct" print(f"Loading VLM: {vlm_model_name}") gr.Info("📥 Downloading VLM model weights (8GB)...") loaded_vlm_model = AutoModelForImageTextToText.from_pretrained( vlm_model_name, torch_dtype=torch.bfloat16, device_map="cuda", token=hf_token, ).eval() gr.Info("📥 Downloading VLM processor...") loaded_vlm_processor = AutoProcessor.from_pretrained(vlm_model_name, token=hf_token) gr.Info("✅ Qwen3-VL-4B loaded and ready!") return loaded_vlm_model, loaded_vlm_processor def on_vlm_toggle(enabled): """Show hint when VLM is enabled.""" if enabled: gr.Info("ℹ️ VLM (Qwen3-VL-4B) will be loaded on first analysis. This adds ~30-60 seconds.") return enabled def get_similarity_maps_from_embeddings( image_embeddings: torch.Tensor, query_embeddings: torch.Tensor, n_patches: Tuple[int, int], image_mask: torch.Tensor, query_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """EXACT ColPali implementation of similarity map computation.""" idx = 0 n_patches_x, n_patches_y = n_patches[0], n_patches[1] n_image_tokens = int(image_mask[idx].sum().item()) expected_tokens = n_patches_x * n_patches_y if n_image_tokens != expected_tokens: n = n_image_tokens sqrt_n = int(math.sqrt(n)) for i in range(sqrt_n, 0, -1): if n % i == 0: n_patches_x, n_patches_y = n // i, i break image_embedding_grid = rearrange( image_embeddings[idx][image_mask[idx]], "(h w) c -> w h c", w=n_patches_x, h=n_patches_y, ) query_emb = query_embeddings[idx] if query_mask is not None: query_emb = query_emb[query_mask[idx]] similarity_map = torch.einsum( "nk,ijk->nij", query_emb, image_embedding_grid, ) return similarity_map def create_heatmap_overlay( image: Image.Image, similarity_map: torch.Tensor, alpha: float = 0.5, skip_normalize: bool = False, ) -> Image.Image: """Create heatmap overlay following EXACT ColPali visualization.""" import seaborn as sns sim_float = similarity_map.float() if skip_normalize: sim_array = sim_float.cpu().numpy() else: min_val = sim_float.min() max_val = sim_float.max() sim_normalized = (sim_float - min_val) / (max_val - min_val + 1e-10) sim_array = sim_normalized.cpu().numpy() sim_array = rearrange(sim_array, "h w -> w h") sim_image = Image.fromarray((sim_array * 255).astype(np.uint8)) sim_image = sim_image.resize(image.size, Image.Resampling.BICUBIC) sim_resized = np.array(sim_image) / 255.0 cmap = sns.color_palette("mako", as_cmap=True) heatmap_rgba = cmap(sim_resized) heatmap = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8) img_array = np.array(image.convert("RGB")).astype(np.float32) heatmap_float = heatmap.astype(np.float32) blended = img_array * (1 - alpha) + heatmap_float * alpha blended = np.clip(blended, 0, 255).astype(np.uint8) return Image.fromarray(blended) def get_collfm2_heatmap(model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries): """Generate heatmap for ColLFM2 models (simplified version).""" try: if "input_ids" not in batch_images: return None input_ids = batch_images["input_ids"][0] tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor.processor.tokenizer image_token_id = tokenizer.convert_tokens_to_ids('') image_mask = input_ids == image_token_id n_image_tokens = image_mask.sum().item() # Find best grid that matches the token count and image aspect ratio img_width, img_height = image.size img_ratio = img_width / img_height best_diff = float('inf') n_patches_x, n_patches_y = int(math.sqrt(n_image_tokens)), int(math.sqrt(n_image_tokens)) for i in range(1, int(math.sqrt(n_image_tokens)) + 1): if n_image_tokens % i == 0: j = n_image_tokens // i ratio1 = j / i ratio2 = i / j diff1 = abs(img_ratio - ratio1) diff2 = abs(img_ratio - ratio2) if diff1 < best_diff: best_diff = diff1 n_patches_x, n_patches_y = j, i if diff2 < best_diff: best_diff = diff2 n_patches_x, n_patches_y = i, j query_emb = query_embeddings[0] # Filter padding pad_token_id = getattr(tokenizer, "pad_token_id", 0) or 0 query_mask = batch_queries["input_ids"][0] != pad_token_id query_emb = query_emb[query_mask] # Get image embeddings image_emb = image_embeddings[0][image_mask][:n_patches_x * n_patches_y] image_grid = rearrange(image_emb, "(h w) c -> w h c", w=n_patches_x, h=n_patches_y) # Compute similarity similarity_map = torch.einsum("nk,ijk->nij", query_emb, image_grid) aggregated = similarity_map.max(dim=0).values # Aggressive normalization for ColLFM2 agg_float = aggregated.float() threshold = torch.quantile(agg_float.flatten(), 0.90) hot_mask = agg_float > threshold min_hot = agg_float[hot_mask].min() if hot_mask.sum() > 0 else threshold max_hot = agg_float.max() normalized = torch.zeros_like(agg_float) if hot_mask.sum() > 0: normalized[hot_mask] = 0.5 + 0.5 * (agg_float[hot_mask] - min_hot) / (max_hot - min_hot + 1e-10) return create_heatmap_overlay(image, normalized, skip_normalize=True) except Exception as e: import traceback print(f"ColLFM2 Heatmap error: {traceback.format_exc()}") return None @spaces.GPU def generate_vlm_answer(image: Image.Image, query: str) -> str: """Generate an answer using Qwen3-VL-4B-Instruct.""" try: vlm_model, vlm_processor = load_vlm_model() messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": f"Based on this document image, please answer the following question:\n\n{query}\n\nProvide a clear and concise answer based only on the information visible in the document."}, ], } ] text = vlm_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) from qwen_vl_utils import process_vision_info image_inputs, video_inputs = process_vision_info(messages) inputs = vlm_processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ).to("cuda") with torch.no_grad(): generated_ids = vlm_model.generate( **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9, ) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] answer = vlm_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] return answer except Exception as e: import traceback return f"Error generating answer: {str(e)}" @spaces.GPU def process_query_with_corpus(model_choice: str, image: Image.Image, query: str, enable_vlm: bool, enable_corpus: bool): """ Process a query against an image and optionally against all corpus documents. Returns similarity score, ranking info, heatmap, and optional VLM answer. """ if image is None: return None, "⚠️ Please upload an image.", None, "", None if not query.strip(): return None, "⚠️ Please enter a search query.", None, "", None try: model, processor = load_colpali_model(model_choice) device = next(model.parameters()).device if image.mode != "RGB": image = image.convert("RGB") # Process query batch_queries = processor.process_queries([query]).to(device) with torch.no_grad(): query_embeddings = model(**batch_queries) # Process main image batch_images = processor.process_images([image]).to(device) with torch.no_grad(): image_embeddings = model(**batch_images) gr.Info("📊 Computing similarity scores...") scores = processor.score(query_embeddings, image_embeddings) main_score = scores[0][0].item() # Initialize for VLM ranking_info = None top_document_image = image top_document_label = "Your Document" if enable_corpus: gr.Info("📚 Indexing corpus documents...") # Index all corpus documents corpus_images = get_all_example_images() if corpus_images: # Check if user's image matches any corpus image by comparing pixels user_img_array = np.array(image.resize((64, 64))) user_is_example = False user_example_idx = -1 for idx, (path, img) in enumerate(corpus_images): corpus_img_array = np.array(img.resize((64, 64))) if np.allclose(user_img_array, corpus_img_array, atol=10): user_is_example = True user_example_idx = idx break all_scores = [] # Process ALL corpus images batch_size = 4 for i in range(0, len(corpus_images), batch_size): batch_imgs = [img for _, img in corpus_images[i:i+batch_size]] batch_corpus = processor.process_images(batch_imgs).to(device) with torch.no_grad(): corpus_embeddings = model(**batch_corpus) corpus_scores = processor.score(query_embeddings, corpus_embeddings) for j, (path, img) in enumerate(corpus_images[i:i+batch_size]): score = corpus_scores[0][j].item() example_name = Path(path).name is_user_doc = (i + j == user_example_idx) label = example_name for ex in EXAMPLE_CONFIG: if ex["file"].endswith(example_name): label = f"{ex['lang']} {ex['description']}" break if is_user_doc: label = f"📄 {label} (Selected)" all_scores.append((score, label, path, img)) # If user uploaded custom document (not from corpus), add it if not user_is_example: all_scores.append((main_score, "📄 Your Document", None, image)) # Sort by score descending all_scores.sort(key=lambda x: x[0], reverse=True) # Get top document for VLM top_score, top_label, top_path, top_img = all_scores[0] top_document_image = top_img top_document_label = top_label # Find rank of user's document user_rank = next((i for i, (_, label, _, _) in enumerate(all_scores) if "📄" in label), 0) + 1 # Build ranking display total_docs = len(all_scores) ranking_lines = [f"### 📊 Ranking (out of {total_docs} documents)"] ranking_lines.append(f"**Your document ranks #{user_rank}**\n") ranking_lines.append("| Rank | Score | Document |") ranking_lines.append("|------|-------|----------|") for rank, (score, label, _, _) in enumerate(all_scores[:10], 1): marker = "👉 " if "📄" in label else "" ranking_lines.append(f"| {rank} | {score:.4f} | {marker}{label} |") if len(all_scores) > 10: ranking_lines.append(f"| ... | ... | *{len(all_scores) - 10} more documents* |") ranking_info = "\n".join(ranking_lines) # Generate heatmap gr.Info("🔥 Generating similarity heatmap...") heatmap_image = None heatmap_available = False if "ColQwen3" in loaded_colpali_model_name: try: spatial_merge_size = getattr(model, "spatial_merge_size", 2) n_patches = processor.get_n_patches( image_size=image.size, spatial_merge_size=spatial_merge_size, ) image_mask = batch_images["input_ids"] == processor.image_token_id pad_token_id = getattr(processor.tokenizer, "pad_token_id", 0) query_mask = batch_queries["input_ids"] != pad_token_id similarity_map = get_similarity_maps_from_embeddings( image_embeddings, query_embeddings, n_patches, image_mask, query_mask, ) aggregated = similarity_map.max(dim=0).values heatmap_image = create_heatmap_overlay(image, aggregated) heatmap_available = True except Exception as e: import traceback print(f"Heatmap error: {traceback.format_exc()}") elif "ColLFM2" in loaded_colpali_model_name: heatmap_image = get_collfm2_heatmap( model, processor, image, image_embeddings, query_embeddings, batch_images, batch_queries ) if heatmap_image is not None: heatmap_available = True # Build result text user_image_path = None # Try to detect if user's image is from examples if hasattr(image, 'filename'): user_image_path = image.filename if heatmap_available: if "ColLFM2" in loaded_colpali_model_name: result_text = f"""## 📊 Similarity Score: **{main_score:.4f}** 🔵 **Dark blue** = low relevance | 🟢 **Cyan/Green** = high relevance ⚠️ **ColLFM2 Heatmap Note:** This model uses a SigLIP2 vision encoder with pixel unshuffle, producing "holistic" embeddings. The heatmap shows **region-level relevance** rather than precise word-level localization. This is expected behavior - ColLFM2 excels at determining *if* a document is relevant. For precise heatmaps, try ColQwen3 models.""" else: result_text = f"""## 📊 Similarity Score: **{main_score:.4f}** The heatmap shows which areas of the document are most relevant to your query. 🔵 **Dark blue** = low relevance | 🟢 **Cyan/Green** = high relevance""" else: result_text = f"""## 📊 Similarity Score: **{main_score:.4f}** *Heatmap visualization is not available for this model configuration.*""" # Generate VLM answer using top-ranked document vlm_answer = "" if enable_vlm: vlm_answer = generate_vlm_answer(top_document_image, query) if enable_corpus and top_document_label != "📄 Your Document": vlm_answer = f"*[Answer based on top-ranked document: {top_document_label}]*\n\n{vlm_answer}" gr.Info("✅ Analysis complete!") return main_score, result_text, heatmap_image, vlm_answer, ranking_info except Exception as e: import traceback error_msg = f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```" return None, error_msg, None, "", None def create_demo(): available_examples = get_available_examples() with gr.Blocks(title="SauerkrautLM-ColPali Demo") as demo: # Header with logo gr.HTML("""
SauerkrautLM-ColPali

Visual Document Retrieval with Multi-Vector Embeddings + VLM Answer Generation

🤗 Models 📖 GitHub 🌐 VAGO Solutions
""") with gr.Row(): # Left Column - Inputs with gr.Column(scale=1): gr.HTML('

⚙️ Configuration

') model_dropdown = gr.Dropdown( choices=list(COLPALI_MODELS.keys()), value="SauerkrautLM-ColQwen3-2B (Balanced, 4.4GB)", label="🔍 Retrieval Model", info="ColPali-based model for document retrieval and heatmap", ) with gr.Row(): enable_vlm = gr.Checkbox( label="🤖 Enable VLM", value=False, info="Use Qwen3-VL-4B for answers (adds ~30-60s on first use)", ) enable_corpus = gr.Checkbox( label="📚 Compare with Corpus", value=True, info="Rank against 31 example documents", ) gr.HTML('

📄 Document

') image_input = gr.Image( label="Upload Document Image", type="pil", height=350, ) query_input = gr.Textbox( label="🔍 Search Query", placeholder="e.g., What is the total revenue? / Wie hoch ist der Umsatz?", lines=2, ) submit_btn = gr.Button( "🚀 Analyze Document", variant="primary", size="lg", ) # Right Column - Results with gr.Column(scale=1): gr.HTML('

📊 Results

') with gr.Group(): score_output = gr.Number( label="Similarity Score", precision=4, ) result_markdown = gr.Markdown( value="*Upload an image and enter a query to get started*", ) gr.HTML('

🔥 Similarity Heatmap

') heatmap_output = gr.Image( label="Heatmap Visualization", type="pil", height=400, ) gr.HTML("""
Low → High Relevance (mako colormap)
""") with gr.Accordion("📚 Corpus Ranking", open=True): ranking_output = gr.Markdown( value="*Enable 'Compare with Corpus' to see how your document ranks*", ) with gr.Accordion("🤖 VLM Answer", open=True, visible=True): vlm_answer_output = gr.Textbox( label="Answer from Qwen3-VL-4B", lines=6, interactive=False, placeholder="Enable VLM and analyze to get an AI-generated answer...", ) # Examples section if available_examples: gr.HTML('

📚 Example Documents (31 multilingual documents)

') gr.HTML("""

🌍 Languages: German (DE), English (EN), French (FR), Spanish (ES)
📂 Categories: Annual Reports, ESG, Tax Forms, Scientific Papers, Energy/Hydrogen, Healthcare, Economic Forecasts

""") gr.Examples( examples=available_examples, inputs=[image_input, query_input], label="Click an example to load it", ) # Info section gr.HTML("""
ℹ️ About Heatmap Differences: ColQwen3 vs ColLFM2

ColQwen3 (Qwen-based): Uses Qwen3-VL's vision encoder which preserves strong spatial locality in patch embeddings. Each image patch maintains distinct features, allowing query tokens to differentiate between regions. Result: precise, localized heatmaps.

ColLFM2 (LFM2-based): Uses SigLIP2 NaFlex vision encoder with pixel unshuffle for efficient token reduction. This merges spatial information across patches, producing more "holistic" embeddings. Query tokens show high correlation (~0.97) across all patches. Result: region-level relevance rather than word-level precision.

Why ColLFM2 still performs well for retrieval: The subtle similarity differences (e.g., 0.88 vs 0.94) are sufficient for ranking documents correctly. ColLFM2 excels at determining if a document is relevant, while ColQwen3 better shows where the relevance is.

""") # Footer gr.HTML("""

💡 Tip: ColQwen3 models provide the best heatmap visualization. Enable VLM (Qwen3-VL-4B) for AI-generated answers.

Made with ❤️ by VAGO Solutions

""") # Event handlers submit_btn.click( fn=process_query_with_corpus, inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus], outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output], ) query_input.submit( fn=process_query_with_corpus, inputs=[model_dropdown, image_input, query_input, enable_vlm, enable_corpus], outputs=[score_output, result_markdown, heatmap_output, vlm_answer_output, ranking_output], ) # Show hint when VLM is enabled enable_vlm.change( fn=on_vlm_toggle, inputs=[enable_vlm], outputs=[enable_vlm], ) return demo if __name__ == "__main__": install_fa2() demo = create_demo() demo.queue(max_size=10).launch(debug=True)