import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import gaussian_filter

# Make the Tipsomaly package importable from the repository root.
ROOT_DIR = Path(__file__).resolve().parent
TIPSOMALY_DIR = ROOT_DIR / "Tipsomaly"
MODEL_DIR = TIPSOMALY_DIR / "model"
if str(TIPSOMALY_DIR) not in sys.path:
    sys.path.insert(0, str(TIPSOMALY_DIR))
if str(MODEL_DIR) not in sys.path:
    sys.path.insert(0, str(MODEL_DIR))

from Tipsomaly.datasets import input_transforms
from Tipsomaly.model import omaly, tips
from Tipsomaly.utils.visualize import apply_ad_scoremap, normalize


@dataclass
class DemoConfig:
    image_size: int = int(os.getenv("IMAGE_SIZE", "518"))
    models_dir: str = os.getenv("TIPS_MODELS_DIR", str(ROOT_DIR / "tips"))
    model_version: str = os.getenv("TIPS_MODEL_VERSION", "l14h")
    object_name: str = os.getenv("OBJECT_NAME", "object")
    sigma: float = float(os.getenv("ANOMALY_SMOOTH_SIGMA", "4"))
    use_local_to_global: bool = os.getenv("USE_LOCAL_TO_GLOBAL", "true").lower() == "true"
    prompt_learn_method: str = os.getenv("PROMPT_LEARN_METHOD", "concat")
    n_prompt: int = int(os.getenv("N_PROMPT", "8"))
    n_deep_tokens: int = int(os.getenv("N_DEEP_TOKENS", "0"))
    d_deep_tokens: int = int(os.getenv("D_DEEP_TOKENS", "0"))
    checkpoint_epoch: int = int(os.getenv("LEARNABLE_PROMPT_EPOCH", "2"))


CHECKPOINTS: Dict[str, str] = {
    "mvtec": "Tipsomaly/workspaces/trained_on_mvtec_default/vegan-arkansas/checkpoints",
    "visa": "Tipsomaly/workspaces/trained_on_visa_default/vegan-arkansas/checkpoints",
}


def calc_soft_score(vis_feat: torch.Tensor, txt_feat: torch.Tensor, temp: torch.Tensor) -> torch.Tensor:
    # Temperature-scaled softmax over the text prompts for each visual token.
    return F.softmax((vis_feat @ txt_feat.permute(0, 2, 1)) / temp, dim=-1)


def regrid_upsample_smooth(flat_scores: torch.Tensor, size: int, sigma: float) -> torch.Tensor:
    # Reshape flat patch scores back onto the square patch grid.
    h_w = int(flat_scores.shape[1] ** 0.5)
    regridded = flat_scores.reshape(flat_scores.shape[0], h_w, h_w, -1).permute(0, 3, 1, 2)
    upsampled = F.interpolate(regridded, (size, size), mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
    # Fold the (normal, anomalous) channel pair into a single per-pixel score.
    rough_maps = (1 - upsampled[..., 0] + upsampled[..., 1]) / 2
    anomaly_map = torch.stack(
        [torch.from_numpy(gaussian_filter(one_map.numpy(), sigma=sigma)) for one_map in rough_maps.detach().cpu()],
        dim=0,
    )
    return anomaly_map


def make_heatmap_rgb(image: Image.Image, anomaly_map: np.ndarray, image_size: int) -> Image.Image:
    # Reuse Tipsomaly visualization utilities.
    vis_image = np.asarray(image.convert("RGB").resize((image_size, image_size)))
    overlay = apply_ad_scoremap(vis_image, normalize(anomaly_map))
    return Image.fromarray(overlay, mode="RGB")
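
# Shape walk-through (illustrative; this assumes the "l14h" TIPS variant uses
# 14-pixel patches, which is an assumption about the backbone): at the default
# IMAGE_SIZE=518 the patch-level features form a 37x37 grid, so calc_soft_score
# on them yields flat_scores of shape (B, 1369, 2) -- a softmax over the
# (normal, anomalous) prompt pair per patch. regrid_upsample_smooth reshapes
# that to (B, 2, 37, 37), bilinearly upsamples to (B, 518, 518, 2), folds the
# pair into one score via (1 - p_normal + p_anomalous) / 2 (which equals
# p_anomalous when the pair sums to 1), and Gaussian-smooths each map with the
# configured sigma.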


class TipsomalyDemo:
    def __init__(self, config: DemoConfig) -> None:
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transform, _ = input_transforms.create_transforms_tips(config.image_size)
        self._init_tips_backbone()

    def _init_tips_backbone(self) -> None:
        vision_backbone, text_backbone, tokenizer, temperature = tips.load_model.get_model(
            self.config.models_dir, self.config.model_version
        )
        self.temperature = (
            temperature.to(self.device)
            if torch.is_tensor(temperature)
            else torch.tensor(temperature, device=self.device)
        )
        self.tokenizer = tokenizer
        self.text_backbone = text_backbone.to(self.device).eval()
        self.text_embd_dim = self.text_backbone.transformer.width
        self.vision_encoder = (
            omaly.vision_encoder(vision_backbone.to(self.device).eval(), "tips").to(self.device).eval()
        )

    def _build_text_encoder(self, domain: str, prompt_learn_method: str):
        return omaly.text_encoder(
            self.tokenizer,
            self.text_backbone,
            "tips",
            self.text_embd_dim,
            64,
            prompt_learn_method,
            domain,
            self.config.n_prompt,
            self.config.n_deep_tokens,
            self.config.d_deep_tokens,
        ).to(self.device).eval()

    def _resolve_checkpoint_path(self, token_source: str, custom_checkpoint: str) -> Optional[Path]:
        if token_source == "none":
            return None
        if token_source == "custom":
            if not custom_checkpoint.strip():
                raise gr.Error("Custom checkpoint selected, but path is empty.")
            path = Path(custom_checkpoint.strip())
        else:
            if token_source not in CHECKPOINTS:
                raise gr.Error(f"Unknown token source: {token_source}")
            base = ROOT_DIR / CHECKPOINTS[token_source]
            path = base / f"learnable_params_{self.config.checkpoint_epoch}.pth"
        if not path.exists():
            raise gr.Error(f"Checkpoint not found: {path}")
        return path

    def _load_learnable_prompts(self, text_encoder, checkpoint_path: Optional[Path]) -> bool:
        if checkpoint_path is None:
            return False
        checkpoint = torch.load(str(checkpoint_path), map_location=self.device, weights_only=False)
        text_encoder.learnable_prompts = (
            checkpoint["learnable_prompts"] if isinstance(checkpoint, dict) else checkpoint
        )
        return True

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        image = image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image,
        domain: str,
        token_source: str,
        custom_checkpoint: str,
    ) -> Tuple[Image.Image, float]:
        if image is None:
            raise gr.Error("Please upload an image.")

        checkpoint_path = self._resolve_checkpoint_path(token_source, custom_checkpoint)
        prompt_learn_method = self.config.prompt_learn_method if checkpoint_path else "none"
        text_encoder = self._build_text_encoder(domain, prompt_learn_method=prompt_learn_method)
        has_learned = self._load_learnable_prompts(text_encoder, checkpoint_path)

        fixed_text_features = text_encoder([self.config.object_name], self.device, learned=False)
        fixed_text_features = fixed_text_features / fixed_text_features.norm(dim=-1, keepdim=True)
        seg_text_features = fixed_text_features
        if has_learned:
            learned_text_features = text_encoder([self.config.object_name], self.device, learned=True)
            learned_text_features = learned_text_features / learned_text_features.norm(dim=-1, keepdim=True)
            seg_text_features = learned_text_features

        image_tensor = self._preprocess_image(image)
        vision_features = self.vision_encoder(image_tensor)
        vision_features = [feature / feature.norm(dim=-1, keepdim=True) for feature in vision_features]

        # Decoupled prompts: image-level score uses fixed prompts, pixel-level map can use learned prompts.
        img_scr0 = calc_soft_score(vision_features[0], fixed_text_features, self.temperature).squeeze(dim=1).detach()
        img_scr1 = calc_soft_score(vision_features[1], fixed_text_features, self.temperature).squeeze(dim=1).detach()
        img_map = calc_soft_score(vision_features[2], seg_text_features, self.temperature).detach()

        if self.config.use_local_to_global:
            max_local = torch.max(img_map, dim=1)[0]
            img_scr0 = img_scr0 + max_local
            img_scr1 = img_scr1 + max_local

        pixel_scores = regrid_upsample_smooth(img_map, self.config.image_size, self.config.sigma)
        anomaly_map = pixel_scores[0].cpu().numpy()
        anomaly_score = float(img_scr1[0][1].item())
        return make_heatmap_rgb(image, anomaly_map, self.config.image_size), anomaly_score
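
# DemoConfig reads all of its defaults from environment variables, so the
# module-level CONFIG/MODEL below can be retargeted without code changes; a
# hypothetical launch (values and the file name "app.py" are placeholders):
#
#   TIPS_MODELS_DIR=/data/tips TIPS_MODEL_VERSION=l14h IMAGE_SIZE=518 python app.py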
CONFIG = DemoConfig()
MODEL = TipsomalyDemo(CONFIG)


def predict(
    image: Image.Image,
    domain: str,
    token_source: str,
    custom_checkpoint: str = "",
) -> Tuple[Image.Image, float]:
    return MODEL.infer(image, domain, token_source, custom_checkpoint)


with gr.Blocks(title="Tipsomaly Demo") as demo:
    gr.Markdown(
        "# Tipsomaly Anomaly Detection & Localization Demo\n"
        "Upload an image, choose a domain, and select a learnable-token checkpoint."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            domain_input = gr.Radio(
                choices=["industrial", "medical"],
                value="industrial",
                label="Domain",
            )
            token_source_input = gr.Radio(
                choices=["none", "mvtec", "visa", "custom"],
                value="none",
                label="Learnable Tokens",
                info="Use pretrained prompt tokens from workspace checkpoints, or point to a custom file.",
            )
            custom_checkpoint_input = gr.Textbox(
                label="Custom Checkpoint Path",
                placeholder="Path to learnable_params_*.pth (used when 'custom' is selected)",
            )
            run_btn = gr.Button("Run Detection", variant="primary")
    with gr.Row():
        anomaly_map_output = gr.Image(type="pil", label="Anomaly Map")
        anomaly_score_output = gr.Number(label="Anomaly Score")

    run_btn.click(
        fn=predict,
        inputs=[image_input, domain_input, token_source_input, custom_checkpoint_input],
        outputs=[anomaly_map_output, anomaly_score_output],
    )


if __name__ == "__main__":
    demo.launch()
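
# Programmatic usage sketch (not executed by the app; "sample.png" is a
# hypothetical file): the Gradio UI is a thin wrapper around MODEL.infer, which
# can be called directly for offline or batch scoring, e.g.
#
#   img = Image.open("sample.png")
#   heatmap, score = MODEL.infer(img, domain="industrial", token_source="mvtec", custom_checkpoint="")
#   heatmap.save("sample_heatmap.png")
#   print(f"anomaly score: {score:.4f}")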