# Hugging Face Spaces app (status at capture time: Running)
| import os | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from scipy.ndimage import gaussian_filter | |
# Make the Tipsomaly package importable from the repository root.
# Both the package directory and its model/ subdirectory are prepended to
# sys.path (idempotently) so the absolute imports below resolve no matter
# where this script is launched from.
ROOT_DIR = Path(__file__).resolve().parent
TIPSOMALY_DIR = ROOT_DIR / "Tipsomaly"
MODEL_DIR = TIPSOMALY_DIR / "model"
if str(TIPSOMALY_DIR) not in sys.path:
    sys.path.insert(0, str(TIPSOMALY_DIR))
if str(MODEL_DIR) not in sys.path:
    sys.path.insert(0, str(MODEL_DIR))
| from Tipsomaly.datasets import input_transforms | |
| from Tipsomaly.model import omaly, tips | |
| from Tipsomaly.utils.visualize import apply_ad_scoremap, normalize | |
@dataclass
class DemoConfig:
    """Demo configuration resolved from environment variables at import time.

    Every field has a default read from the environment; applying
    ``@dataclass`` (the import was present but unused) additionally allows
    per-instance overrides, e.g. ``DemoConfig(image_size=256)``, and gives a
    readable repr for debugging.
    """

    # Square side length images are resized to before encoding.
    image_size: int = int(os.getenv("IMAGE_SIZE", "518"))
    # Directory holding the TIPS backbone weights.
    models_dir: str = os.getenv("TIPS_MODELS_DIR", str(ROOT_DIR / "tips"))
    # TIPS backbone variant identifier.
    model_version: str = os.getenv("TIPS_MODEL_VERSION", "l14h")
    # Object name inserted into the text prompts.
    object_name: str = os.getenv("OBJECT_NAME", "object")
    # Gaussian sigma used to smooth the pixel-level anomaly map.
    sigma: float = float(os.getenv("ANOMALY_SMOOTH_SIGMA", "4"))
    # Whether the max local (patch) score is added to the global score.
    use_local_to_global: bool = os.getenv("USE_LOCAL_TO_GLOBAL", "true").lower() == "true"
    # Prompt-learning strategy passed to the text encoder.
    prompt_learn_method: str = os.getenv("PROMPT_LEARN_METHOD", "concat")
    # Learnable-prompt hyperparameters (must match the trained checkpoints).
    n_prompt: int = int(os.getenv("N_PROMPT", "8"))
    n_deep_tokens: int = int(os.getenv("N_DEEP_TOKENS", "0"))
    d_deep_tokens: int = int(os.getenv("D_DEEP_TOKENS", "0"))
    # Epoch number embedded in the checkpoint filename to load.
    checkpoint_epoch: int = int(os.getenv("LEARNABLE_PROMPT_EPOCH", "2"))
# Repo-relative checkpoint directories for pretrained learnable-prompt
# tokens, keyed by the dataset they were trained on. Resolved against
# ROOT_DIR and the configured epoch when a token source is selected.
CHECKPOINTS: Dict[str, str] = {
    "mvtec": "Tipsomaly/workspaces/trained_on_mvtec_default/vegan-arkansas/checkpoints",
    "visa": "Tipsomaly/workspaces/trained_on_visa_default/vegan-arkansas/checkpoints",
}
def calc_soft_score(vis_feat: torch.Tensor, txt_feat: torch.Tensor, temp: torch.Tensor) -> torch.Tensor:
    """Temperature-scaled softmax over vision/text feature similarities.

    Expects batched 3-D tensors; the softmax runs over the text-feature axis.
    """
    logits = vis_feat @ txt_feat.transpose(1, 2)
    return (logits / temp).softmax(dim=-1)
| def regrid_upsample_smooth(flat_scores: torch.Tensor, size: int, sigma: float) -> torch.Tensor: | |
| h_w = int(flat_scores.shape[1] ** 0.5) | |
| regrided = flat_scores.reshape(flat_scores.shape[0], h_w, h_w, -1).permute(0, 3, 1, 2) | |
| upsampled = F.interpolate(regrided, (size, size), mode="bilinear", align_corners=False).permute(0, 2, 3, 1) | |
| rough_maps = (1 - upsampled[..., 0] + upsampled[..., 1]) / 2 | |
| anomaly_map = torch.stack( | |
| [torch.from_numpy(gaussian_filter(one_map, sigma=sigma)) for one_map in rough_maps.detach().cpu()], | |
| dim=0, | |
| ) | |
| return anomaly_map | |
def make_heatmap_rgb(image: Image.Image, anomaly_map: np.ndarray, image_size: int) -> Image.Image:
    """Overlay the normalized anomaly map onto the resized input image.

    Reuses Tipsomaly's visualization utilities for normalization and
    colormap blending; returns an RGB PIL image of (image_size, image_size).
    """
    resized = image.convert("RGB").resize((image_size, image_size))
    blended = apply_ad_scoremap(np.asarray(resized), normalize(anomaly_map))
    return Image.fromarray(blended, mode="RGB")
class TipsomalyDemo:
    """Wraps the TIPS backbones and Tipsomaly heads for single-image inference."""

    def __init__(self, config: DemoConfig) -> None:
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only the first (eval) transform is used; the second return value is discarded.
        self.transform, _ = input_transforms.create_transforms_tips(config.image_size)
        self._init_tips_backbone()

    def _init_tips_backbone(self) -> None:
        """Load the TIPS vision/text backbones once and pin them to eval mode."""
        vision_backbone, text_backbone, tokenizer, temperature = tips.load_model.get_model(
            self.config.models_dir, self.config.model_version
        )
        # get_model may hand back the temperature as a tensor or a plain number.
        self.temperature = (
            temperature.to(self.device)
            if torch.is_tensor(temperature)
            else torch.tensor(temperature, device=self.device)
        )
        self.tokenizer = tokenizer
        self.text_backbone = text_backbone.to(self.device).eval()
        self.text_embd_dim = self.text_backbone.transformer.width
        self.vision_encoder = (
            omaly.vision_encoder(vision_backbone.to(self.device).eval(), "tips")
            .to(self.device)
            .eval()
        )

    def _build_text_encoder(self, domain: str, prompt_learn_method: str):
        """Build a fresh text encoder; prompt_learn_method='none' disables learned prompts."""
        return omaly.text_encoder(
            self.tokenizer,
            self.text_backbone,
            "tips",
            self.text_embd_dim,
            64,
            prompt_learn_method,
            domain,
            self.config.n_prompt,
            self.config.n_deep_tokens,
            self.config.d_deep_tokens,
        ).to(self.device).eval()

    def _resolve_checkpoint_path(self, token_source: str, custom_checkpoint: str) -> Optional[Path]:
        """Map a token-source choice to a checkpoint path.

        Returns None for "none". Raises gr.Error when a custom path is empty,
        the source is unknown, or the resolved file does not exist.
        """
        if token_source == "none":
            return None
        if token_source == "custom":
            if not custom_checkpoint.strip():
                raise gr.Error("Custom checkpoint selected, but path is empty.")
            path = Path(custom_checkpoint.strip())
        else:
            if token_source not in CHECKPOINTS:
                raise gr.Error(f"Unknown token source: {token_source}")
            base = ROOT_DIR / CHECKPOINTS[token_source]
            path = base / f"learnable_params_{self.config.checkpoint_epoch}.pth"
        if not path.exists():
            raise gr.Error(f"Checkpoint not found: {path}")
        return path

    def _load_learnable_prompts(self, text_encoder, checkpoint_path: Optional[Path]) -> bool:
        """Install learned prompt tokens from a checkpoint; return True if loaded."""
        if checkpoint_path is None:
            return False
        # NOTE(security): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(str(checkpoint_path), map_location=self.device, weights_only=False)
        # Checkpoints are either a dict with a "learnable_prompts" entry or the raw tokens.
        text_encoder.learnable_prompts = (
            checkpoint["learnable_prompts"] if isinstance(checkpoint, dict) else checkpoint
        )
        return True

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a (1, C, H, W) tensor on the model device."""
        image = image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    @torch.no_grad()  # pure inference; skip autograd bookkeeping to save memory
    def infer(
        self,
        image: Image.Image,
        domain: str,
        token_source: str,
        custom_checkpoint: str,
    ) -> Tuple[Image.Image, float]:
        """Run anomaly detection on one image.

        Returns (heatmap overlay image, scalar anomaly score). Raises gr.Error
        when no image was supplied or the checkpoint selection is invalid.
        """
        if image is None:
            raise gr.Error("Please upload an image.")
        checkpoint_path = self._resolve_checkpoint_path(token_source, custom_checkpoint)
        prompt_learn_method = self.config.prompt_learn_method if checkpoint_path else "none"
        text_encoder = self._build_text_encoder(domain, prompt_learn_method=prompt_learn_method)
        has_learned = self._load_learnable_prompts(text_encoder, checkpoint_path)

        fixed_text_features = text_encoder([self.config.object_name], self.device, learned=False)
        fixed_text_features = fixed_text_features / fixed_text_features.norm(dim=-1, keepdim=True)
        # Decoupled prompts: the image-level score always uses fixed prompts,
        # while the pixel-level map switches to learned prompts when available.
        seg_text_features = fixed_text_features
        if has_learned:
            learned_text_features = text_encoder([self.config.object_name], self.device, learned=True)
            learned_text_features = learned_text_features / learned_text_features.norm(dim=-1, keepdim=True)
            seg_text_features = learned_text_features

        image_tensor = self._preprocess_image(image)
        vision_features = self.vision_encoder(image_tensor)
        vision_features = [feature / feature.norm(dim=-1, keepdim=True) for feature in vision_features]
        # Only vision_features[1] contributes to the reported image-level score;
        # a parallel score from vision_features[0] was computed but never read,
        # so that dead computation has been dropped.
        img_scr1 = calc_soft_score(vision_features[1], fixed_text_features, self.temperature).squeeze(dim=1).detach()
        img_map = calc_soft_score(vision_features[2], seg_text_features, self.temperature).detach()
        if self.config.use_local_to_global:
            # Fold the strongest local (patch) response into the global score.
            img_scr1 = img_scr1 + torch.max(img_map, dim=1)[0]
        pixel_scores = regrid_upsample_smooth(img_map, self.config.image_size, self.config.sigma)
        anomaly_map = pixel_scores[0].cpu().numpy()
        # Index 1 of the (normal, anomalous) softmax is the anomaly probability.
        anomaly_score = float(img_scr1[0][1].item())
        return make_heatmap_rgb(image, anomaly_map, self.config.image_size), anomaly_score
# Module-level singletons: the config is read and the model (backbones
# included) is loaded once at import time, then shared by all requests.
CONFIG = DemoConfig()
MODEL = TipsomalyDemo(CONFIG)
def predict(
    image: Image.Image,
    domain: str,
    token_source: str,
    custom_checkpoint: str = "",
) -> Tuple[Image.Image, float]:
    """Gradio callback: delegate inference to the shared model instance."""
    heatmap, score = MODEL.infer(image, domain, token_source, custom_checkpoint)
    return heatmap, score
# Gradio UI: image + controls on one row, heatmap and score outputs below.
with gr.Blocks(title="Tipsomaly Demo") as demo:
    gr.Markdown(
        "# Tipsomaly Anomaly Detection & Localization Demo\n"
        "Upload an image, choose domain, and choose learnable-token checkpoint."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            domain_input = gr.Radio(
                choices=["industrial", "medical"],
                value="industrial",
                label="Domain",
            )
            # "custom" is supported by predict() and the checkpoint resolver;
            # expose it here together with a path textbox.
            token_source_input = gr.Radio(
                choices=["none", "mvtec", "visa", "custom"],
                value="none",
                label="Learnable Tokens",
                info="Use pretrained prompt tokens from workspace checkpoints.",
            )
            custom_checkpoint_input = gr.Textbox(
                value="",
                label="Custom Checkpoint Path",
                info="Path to a learnable_params_*.pth file; used only when 'custom' is selected.",
            )
    run_btn = gr.Button("Run Detection", variant="primary")
    with gr.Row():
        anomaly_map_output = gr.Image(type="pil", label="Anomaly Map")
        anomaly_score_output = gr.Number(label="Anomaly Score")
    run_btn.click(
        fn=predict,
        inputs=[image_input, domain_input, token_source_input, custom_checkpoint_input],
        outputs=[anomaly_map_output, anomaly_score_output],
    )

if __name__ == "__main__":
    demo.launch()