# Tipsomaly / app.py
# Author: AlirezaSalehy
# Commit: "Make pipeline closer to the main codebase" (7d32de0)
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import gaussian_filter
# Make Tipsomaly package importable from repository root.
ROOT_DIR = Path(__file__).resolve().parent
TIPSOMALY_DIR = ROOT_DIR / "Tipsomaly"
MODEL_DIR = TIPSOMALY_DIR / "model"
# Prepend (insert at 0) so the vendored package shadows any installed copy;
# the membership check keeps repeated imports from duplicating entries.
if str(TIPSOMALY_DIR) not in sys.path:
    sys.path.insert(0, str(TIPSOMALY_DIR))
if str(MODEL_DIR) not in sys.path:
    sys.path.insert(0, str(MODEL_DIR))
from Tipsomaly.datasets import input_transforms
from Tipsomaly.model import omaly, tips
from Tipsomaly.utils.visualize import apply_ad_scoremap, normalize
@dataclass
class DemoConfig:
    """Demo settings; every field can be overridden via an environment
    variable, read once at class-definition (import) time."""

    # Side length (pixels) images are resized to for the backbone and heatmap.
    image_size: int = int(os.getenv("IMAGE_SIZE", "518"))
    # Directory holding the TIPS backbone weights.
    models_dir: str = os.getenv("TIPS_MODELS_DIR", str(ROOT_DIR / "tips"))
    # TIPS backbone variant identifier passed to the model loader.
    model_version: str = os.getenv("TIPS_MODEL_VERSION", "l14h")
    # Object name inserted into the text prompts.
    object_name: str = os.getenv("OBJECT_NAME", "object")
    # Gaussian-blur sigma applied to the pixel-level anomaly map.
    sigma: float = float(os.getenv("ANOMALY_SMOOTH_SIGMA", "4"))
    # If true, the max patch score is added to the image-level scores.
    use_local_to_global: bool = os.getenv("USE_LOCAL_TO_GLOBAL", "true").lower() == "true"
    # Prompt-learning strategy name understood by omaly.text_encoder.
    prompt_learn_method: str = os.getenv("PROMPT_LEARN_METHOD", "concat")
    # Number of learnable prompt tokens.
    n_prompt: int = int(os.getenv("N_PROMPT", "8"))
    # Deep-token counts/dims for the text encoder (0 disables them).
    n_deep_tokens: int = int(os.getenv("N_DEEP_TOKENS", "0"))
    d_deep_tokens: int = int(os.getenv("D_DEEP_TOKENS", "0"))
    # Epoch suffix of the learnable-prompt checkpoint file to load.
    checkpoint_epoch: int = int(os.getenv("LEARNABLE_PROMPT_EPOCH", "2"))
# Paths (relative to ROOT_DIR) of pretrained learnable-prompt checkpoints,
# keyed by the dataset they were trained on; matched against the UI's
# "Learnable Tokens" radio choices.
CHECKPOINTS: Dict[str, str] = {
    "mvtec": "Tipsomaly/workspaces/trained_on_mvtec_default/vegan-arkansas/checkpoints",
    "visa": "Tipsomaly/workspaces/trained_on_visa_default/vegan-arkansas/checkpoints",
}
def calc_soft_score(vis_feat: torch.Tensor, txt_feat: torch.Tensor, temp: torch.Tensor) -> torch.Tensor:
    """Softmax over temperature-scaled vision-text similarity.

    vis_feat is (B, N, D) and txt_feat is (B, C, D); returns (B, N, C)
    soft scores, softmax-normalized over the text/class dimension.
    """
    logits = torch.matmul(vis_feat, txt_feat.transpose(1, 2))
    return (logits / temp).softmax(dim=-1)
def regrid_upsample_smooth(flat_scores: torch.Tensor, size: int, sigma: float) -> torch.Tensor:
    """Turn flat per-patch scores into a smoothed dense anomaly map.

    flat_scores is (B, P, 2) with P a perfect square (square patch grid).
    The scores are reshaped to the grid, bilinearly upsampled to
    (size, size), combined as (1 - normal + anomalous) / 2, and blurred
    with a Gaussian of the given sigma. Returns a (B, size, size) CPU tensor.
    """
    batch, n_patches = flat_scores.shape[0], flat_scores.shape[1]
    grid = int(n_patches ** 0.5)
    # (B, P, C) -> (B, C, grid, grid) for interpolation.
    as_grid = flat_scores.reshape(batch, grid, grid, -1).permute(0, 3, 1, 2)
    dense = F.interpolate(as_grid, (size, size), mode="bilinear", align_corners=False)
    dense = dense.permute(0, 2, 3, 1)
    # Channel 0 looks like the "normal" score and channel 1 the "anomalous"
    # one; combining both keeps the map in [0, 1].
    rough = (1 - dense[..., 0] + dense[..., 1]) / 2
    smoothed = [torch.from_numpy(gaussian_filter(one_map, sigma=sigma)) for one_map in rough.detach().cpu()]
    return torch.stack(smoothed, dim=0)
def make_heatmap_rgb(image: Image.Image, anomaly_map: np.ndarray, image_size: int) -> Image.Image:
    """Overlay the normalized anomaly map on the resized input image.

    Reuses Tipsomaly visualization utilities; returns an RGB PIL image of
    (image_size, image_size).
    """
    base = image.convert("RGB").resize((image_size, image_size))
    heatmap = apply_ad_scoremap(np.asarray(base), normalize(anomaly_map))
    return Image.fromarray(heatmap, mode="RGB")
class TipsomalyDemo:
    """Loads the TIPS backbone once and serves single-image anomaly inference."""

    def __init__(self, config: DemoConfig) -> None:
        """Store config, select a device, build the image transform, load backbones."""
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only the first returned transform is used; the second return value
        # (presumably a target/mask transform) is discarded.
        self.transform, _ = input_transforms.create_transforms_tips(config.image_size)
        self._init_tips_backbone()

    def _init_tips_backbone(self) -> None:
        """Load the TIPS vision/text backbones, tokenizer, and softmax temperature."""
        vision_backbone, text_backbone, tokenizer, temperature = tips.load_model.get_model(
            self.config.models_dir, self.config.model_version
        )
        # get_model may hand back the temperature as a tensor or a plain
        # number; normalize it to a tensor on the target device either way.
        self.temperature = (
            temperature.to(self.device) if torch.is_tensor(temperature) else torch.tensor(temperature, device=self.device)
        )
        self.tokenizer = tokenizer
        self.text_backbone = text_backbone.to(self.device).eval()
        self.text_embd_dim = self.text_backbone.transformer.width
        self.vision_encoder = omaly.vision_encoder(vision_backbone.to(self.device).eval(), "tips").to(self.device).eval()

    def _build_text_encoder(self, domain: str, prompt_learn_method: str):
        """Build a fresh eval-mode text encoder for the given domain and
        prompt-learning mode ("none" disables learnable prompts)."""
        return omaly.text_encoder(
            self.tokenizer,
            self.text_backbone,
            "tips",
            self.text_embd_dim,
            64,  # NOTE(review): magic number — presumably token/context length; confirm against omaly.text_encoder
            prompt_learn_method,
            domain,
            self.config.n_prompt,
            self.config.n_deep_tokens,
            self.config.d_deep_tokens,
        ).to(self.device).eval()

    def _resolve_checkpoint_path(self, token_source: str, custom_checkpoint: str) -> Optional[Path]:
        """Map the UI token-source choice to a checkpoint path.

        Returns None for "none" (no learned prompts). Raises gr.Error for an
        empty custom path, an unknown source, or a missing checkpoint file.
        """
        if token_source == "none":
            return None
        if token_source == "custom":
            if not custom_checkpoint.strip():
                raise gr.Error("Custom checkpoint selected, but path is empty.")
            path = Path(custom_checkpoint.strip())
        else:
            if token_source not in CHECKPOINTS:
                raise gr.Error(f"Unknown token source: {token_source}")
            base = ROOT_DIR / CHECKPOINTS[token_source]
            path = base / f"learnable_params_{self.config.checkpoint_epoch}.pth"
        if not path.exists():
            raise gr.Error(f"Checkpoint not found: {path}")
        return path

    def _load_learnable_prompts(self, text_encoder, checkpoint_path: Optional[Path]) -> bool:
        """Attach learned prompt tokens from a checkpoint; return True if loaded."""
        if checkpoint_path is None:
            return False
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # ever load checkpoints from trusted sources.
        checkpoint = torch.load(str(checkpoint_path), map_location=self.device, weights_only=False)
        # Checkpoints are either a dict keyed by "learnable_prompts" or the
        # raw prompts object itself.
        text_encoder.learnable_prompts = checkpoint["learnable_prompts"] if isinstance(checkpoint, dict) else checkpoint
        return True

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Convert to RGB, apply the TIPS transform, add a batch dim, move to device."""
        image = image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image,
        domain: str,
        token_source: str,
        custom_checkpoint: str,
    ) -> Tuple[Image.Image, float]:
        """Run detection on one image.

        Returns (heatmap overlay, scalar anomaly score). Raises gr.Error on
        missing image or checkpoint problems.
        """
        if image is None:
            raise gr.Error("Please upload an image.")
        checkpoint_path = self._resolve_checkpoint_path(token_source, custom_checkpoint)
        # No checkpoint -> fall back to a prompt-free text encoder.
        prompt_learn_method = self.config.prompt_learn_method if checkpoint_path else "none"
        text_encoder = self._build_text_encoder(domain, prompt_learn_method=prompt_learn_method)
        has_learned = self._load_learnable_prompts(text_encoder, checkpoint_path)
        fixed_text_features = text_encoder([self.config.object_name], self.device, learned=False)
        fixed_text_features = fixed_text_features / fixed_text_features.norm(dim=-1, keepdim=True)
        # Segmentation prompts default to the fixed ones, upgraded to learned
        # features when a checkpoint was loaded.
        seg_text_features = fixed_text_features
        if has_learned:
            learned_text_features = text_encoder([self.config.object_name], self.device, learned=True)
            learned_text_features = learned_text_features / learned_text_features.norm(dim=-1, keepdim=True)
            seg_text_features = learned_text_features
        image_tensor = self._preprocess_image(image)
        vision_features = self.vision_encoder(image_tensor)
        vision_features = [feature / feature.norm(dim=-1, keepdim=True) for feature in vision_features]
        # Decoupled prompts: image-level score uses fixed prompts, pixel-level map can use learned prompts.
        # img_scr0 mirrors the main training pipeline; only img_scr1 feeds the
        # final score below.
        img_scr0 = calc_soft_score(vision_features[0], fixed_text_features, self.temperature).squeeze(dim=1).detach()
        img_scr1 = calc_soft_score(vision_features[1], fixed_text_features, self.temperature).squeeze(dim=1).detach()
        img_map = calc_soft_score(vision_features[2], seg_text_features, self.temperature).detach()
        if self.config.use_local_to_global:
            # Boost image-level scores with the strongest patch response.
            max_local = torch.max(img_map, dim=1)[0]
            img_scr0 = img_scr0 + max_local
            img_scr1 = img_scr1 + max_local
        pixel_scores = regrid_upsample_smooth(img_map, self.config.image_size, self.config.sigma)
        anomaly_map = pixel_scores[0].cpu().numpy()
        # Index [1] is presumably the "anomalous" class score (matches the
        # channel use in regrid_upsample_smooth) — TODO confirm.
        anomaly_score = float(img_scr1[0][1].item())
        return make_heatmap_rgb(image, anomaly_map, self.config.image_size), anomaly_score
# Module-level singletons: the model is built once at import time so the
# Gradio callback can reuse the loaded weights across requests.
CONFIG = DemoConfig()
MODEL = TipsomalyDemo(CONFIG)
def predict(
    image: Image.Image,
    domain: str,
    token_source: str,
    custom_checkpoint: str = "",
) -> Tuple[Image.Image, float]:
    """Gradio callback: delegate inference to the module-level model.

    Returns the (heatmap overlay, anomaly score) pair from TipsomalyDemo.infer.
    """
    result = MODEL.infer(image, domain, token_source, custom_checkpoint)
    return result
# --- Gradio UI -------------------------------------------------------------
# Component creation order determines the on-screen layout; keep it as-is.
with gr.Blocks(title="Tipsomaly Demo") as demo:
    gr.Markdown(
        "# Tipsomaly Anomaly Detection & Localization Demo\n"
        "Upload an image, choose domain, and choose learnable-token checkpoint."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            domain_input = gr.Radio(
                choices=["industrial", "medical"],
                value="industrial",
                label="Domain",
            )
            token_source_input = gr.Radio(
                choices=["none", "mvtec", "visa"],
                value="none",
                label="Learnable Tokens",
                info="Use pretrained prompt tokens from workspace checkpoints.",
            )
    run_btn = gr.Button("Run Detection", variant="primary")
    with gr.Row():
        anomaly_map_output = gr.Image(type="pil", label="Anomaly Map")
        anomaly_score_output = gr.Number(label="Anomaly Score")
    # NOTE(review): only three inputs are wired, so predict's
    # custom_checkpoint always gets its "" default and the "custom" token
    # source is unreachable from this UI — confirm that is intentional.
    run_btn.click(
        fn=predict,
        inputs=[image_input, domain_input, token_source_input],
        outputs=[anomaly_map_output, anomaly_score_output],
    )

if __name__ == "__main__":
    demo.launch()