from typing import List, Tuple, Any
import gradio as gr
import torch
import torch.nn.functional as torch_functional
from PIL import Image, ImageDraw
from transformers import (
CLIPModel,
CLIPProcessor,
SamModel,
SamProcessor,
BlipForQuestionAnswering,
BlipProcessor,
pipeline,
)
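
# Module-level cache: models, processors, and pipelines are loaded lazily on
# first use and reused across requests so each checkpoint is initialized once.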
MODEL_STORE = {}
def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
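    """Coerce a Gradio gallery/file value into a flat list of PIL images.

    Accepts PIL images, file paths (opened and converted to RGB),
    (image, caption) tuples, and dicts with an "image" or "value" key;
    unreadable or unrecognized items are skipped silently.
    """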
if not gallery_value:
return []
normalized_images: List[Image.Image] = []
for item in gallery_value:
if isinstance(item, Image.Image):
normalized_images.append(item)
continue
if isinstance(item, str):
try:
image_object = Image.open(item).convert("RGB")
normalized_images.append(image_object)
except Exception:
continue
continue
if isinstance(item, (list, tuple)) and item:
candidate = item[0]
if isinstance(candidate, Image.Image):
normalized_images.append(candidate)
continue
if isinstance(item, dict):
candidate = item.get("image") or item.get("value")
if isinstance(candidate, Image.Image):
normalized_images.append(candidate)
continue
return normalized_images
def get_vision_pipeline(model_key: str):
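    """Return a cached transformers pipeline for the given model key.

    The pipeline is created on first use and stored in MODEL_STORE; a
    ValueError is raised for unknown keys.
    """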
if model_key in MODEL_STORE:
return MODEL_STORE[model_key]
if model_key == "object_detection_conditional_detr":
vision_pipeline = pipeline(
task="object-detection",
model="microsoft/conditional-detr-resnet-50",
)
elif model_key == "object_detection_yolos_small":
vision_pipeline = pipeline(
task="object-detection",
model="hustvl/yolos-small",
)
elif model_key == "segmentation":
vision_pipeline = pipeline(
task="image-segmentation",
model="nvidia/segformer-b0-finetuned-ade-512-512",
)
elif model_key == "depth_estimation":
vision_pipeline = pipeline(
task="depth-estimation",
model="Intel/dpt-hybrid-midas",
)
elif model_key == "captioning_blip_base":
vision_pipeline = pipeline(
task="image-to-text",
model="Salesforce/blip-image-captioning-base",
)
elif model_key == "captioning_blip_large":
vision_pipeline = pipeline(
task="image-to-text",
model="Salesforce/blip-image-captioning-large",
)
elif model_key == "vqa_blip_base":
vision_pipeline = pipeline(
task="visual-question-answering",
model="Salesforce/blip-vqa-base",
)
elif model_key == "vqa_vilt_b32":
vision_pipeline = pipeline(
task="visual-question-answering",
model="dandelin/vilt-b32-finetuned-vqa",
)
else:
        raise ValueError(f"Unknown model type: {model_key}")
MODEL_STORE[model_key] = vision_pipeline
return vision_pipeline
def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
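    """Load (or fetch from the cache) the CLIP model and processor for the given variant."""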
model_store_key_model = f"clip_model_{clip_key}"
model_store_key_processor = f"clip_processor_{clip_key}"
if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE:
if clip_key == "clip_large_patch14":
clip_name = "openai/clip-vit-large-patch14"
elif clip_key == "clip_base_patch32":
clip_name = "openai/clip-vit-base-patch32"
else:
            raise ValueError(f"Unknown CLIP model variant: {clip_key}")
clip_model = CLIPModel.from_pretrained(clip_name)
clip_processor = CLIPProcessor.from_pretrained(clip_name)
MODEL_STORE[model_store_key_model] = clip_model
MODEL_STORE[model_store_key_processor] = clip_processor
clip_model = MODEL_STORE[model_store_key_model]
clip_processor = MODEL_STORE[model_store_key_processor]
return clip_model, clip_processor
def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
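    """Load (or fetch from the cache) the BLIP VQA model and processor."""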
if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
MODEL_STORE["blip_vqa_model"] = blip_model
MODEL_STORE["blip_vqa_processor"] = blip_processor
blip_model = MODEL_STORE["blip_vqa_model"]
blip_processor = MODEL_STORE["blip_vqa_processor"]
return blip_model, blip_processor
def get_sam_components() -> Tuple[SamModel, SamProcessor]:
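    """Load (or fetch from the cache) the SlimSAM model and processor for point-prompted segmentation."""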
if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
MODEL_STORE["sam_model"] = sam_model
MODEL_STORE["sam_processor"] = sam_processor
sam_model = MODEL_STORE["sam_model"]
sam_processor = MODEL_STORE["sam_processor"]
return sam_model, sam_processor
def detect_objects_on_image(image_object, model_key: str):
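    """Run object detection with the selected model and draw labeled boxes.

    Boxes and "label: score" annotations are drawn in place on the provided
    image, which is then returned; None is returned on failure.
    """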
if image_object is None:
return None
try:
detector_pipeline = get_vision_pipeline(model_key)
detection_results = detector_pipeline(image_object)
drawer_object = ImageDraw.Draw(image_object)
for detection_item in detection_results:
box_data = detection_item["box"]
label_value = detection_item["label"]
score_value = detection_item["score"]
drawer_object.rectangle(
[
box_data["xmin"],
box_data["ymin"],
box_data["xmax"],
box_data["ymax"],
],
outline="red",
width=3,
)
drawer_object.text(
(box_data["xmin"], box_data["ymin"]),
f"{label_value}: {score_value:.2f}",
fill="red",
)
return image_object
except Exception as e:
        print(f"Error: {str(e)}")
return None
def segment_image(image_object):
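    """Run semantic segmentation (SegFormer, ADE20K) and return the mask of the first segment."""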
if image_object is None:
return None
try:
segmentation_pipeline = get_vision_pipeline("segmentation")
segmentation_results = segmentation_pipeline(image_object)
return segmentation_results[0]["mask"]
except Exception as e:
        print(f"Error: {str(e)}")
return None
def estimate_image_depth(image_object):
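    """Predict a depth map, resize it to the input resolution, and return it as a grayscale image."""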
if image_object is None:
return None
try:
depth_pipeline = get_vision_pipeline("depth_estimation")
depth_output = depth_pipeline(image_object)
predicted_depth_tensor = depth_output["predicted_depth"]
if predicted_depth_tensor.ndim == 3:
predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
elif predicted_depth_tensor.ndim == 2:
predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
else:
            raise ValueError(
                f"Unexpected depth tensor shape: {predicted_depth_tensor.shape}"
            )
resized_depth_tensor = torch_functional.interpolate(
predicted_depth_tensor,
size=image_object.size[::-1],
mode="bicubic",
align_corners=False,
)
        depth_array = resized_depth_tensor.squeeze().cpu().numpy()
        # Min-max normalize so the full 0-255 range is used even when the raw
        # depth values do not start at zero.
        min_value = float(depth_array.min())
        max_value = float(depth_array.max())
        if max_value - min_value <= 0.0:
            return Image.new("L", image_object.size, color=0)
        normalized_depth_array = ((depth_array - min_value) * 255.0 / (max_value - min_value)).astype("uint8")
depth_image = Image.fromarray(normalized_depth_array, mode="L")
return depth_image
except Exception as e:
        print(f"Error: {str(e)}")
return None
def generate_image_caption(image_object, model_key: str) -> str:
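    """Generate a caption for the image with the selected BLIP captioning model."""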
if image_object is None:
        return "Please upload an image"
try:
caption_pipeline = get_vision_pipeline(model_key)
caption_result = caption_pipeline(image_object)
return caption_result[0]["generated_text"]
except Exception as e:
        return f"Error: {str(e)}"
def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
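    """Answer a free-form question about the image.

    The BLIP variant is run directly (generate + decode); other variants go
    through the visual-question-answering pipeline, which also reports a
    confidence score.
    """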
if image_object is None:
        return "Please upload an image"
if not question_text.strip():
        return "Please enter a question"
try:
if model_key == "vqa_blip_base":
blip_model, blip_processor = get_blip_vqa_components()
inputs = blip_processor(
images=image_object,
text=question_text,
return_tensors="pt",
)
with torch.no_grad():
output_ids = blip_model.generate(**inputs)
decoded_answers = blip_processor.batch_decode(
output_ids,
skip_special_tokens=True,
)
answer_text = decoded_answers[0] if decoded_answers else ""
            return answer_text or "The model could not produce an answer"
vqa_pipeline = get_vision_pipeline(model_key)
vqa_result = vqa_pipeline(
image=image_object,
question=question_text,
)
top_item = vqa_result[0]
answer_text = top_item["answer"]
confidence_value = top_item["score"]
        return f"{answer_text} (confidence: {confidence_value:.3f})"
except Exception as e:
        return f"Error: {str(e)}"
def perform_zero_shot_classification(
image_object,
class_texts: str,
clip_key: str,
) -> str:
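    """Zero-shot classify the image against comma-separated class names using CLIP."""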
if image_object is None:
        return "Please upload an image"
try:
clip_model, clip_processor = get_clip_components(clip_key)
class_list = [
class_name.strip()
for class_name in class_texts.split(",")
if class_name.strip()
]
if not class_list:
            return "Please provide class names for classification"
input_batch = clip_processor(
text=class_list,
images=image_object,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
clip_outputs = clip_model(**input_batch)
logits_per_image = clip_outputs.logits_per_image
probability_tensor = logits_per_image.softmax(dim=1)
        result_lines = ["Results:"]
for class_index, class_name in enumerate(class_list):
probability_value = probability_tensor[0][class_index].item()
result_lines.append(f"{class_name}: {probability_value:.4f}")
return "\n".join(result_lines)
except Exception as e:
        return f"Error: {str(e)}"
def retrieve_best_image(
gallery_value: Any,
query_text: str,
clip_key: str,
) -> Tuple[str, Image.Image | None]:
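    """Find the uploaded image that best matches the text query using CLIP.

    Image and text embeddings are L2-normalized, so their dot product is the
    cosine similarity; the best match and its score are returned.
    """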
image_list = _normalize_gallery_images(gallery_value)
if not image_list or not query_text.strip():
        return "Please upload images and enter a query", None
try:
clip_model, clip_processor = get_clip_components(clip_key)
image_inputs = clip_processor(
images=image_list,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
image_features = clip_model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(
dim=-1,
keepdim=True,
)
text_inputs = clip_processor(
text=[query_text],
return_tensors="pt",
padding=True,
)
with torch.no_grad():
text_features = clip_model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(
dim=-1,
keepdim=True,
)
similarity_tensor = image_features @ text_features.T
best_index_tensor = similarity_tensor.argmax()
best_index_value = best_index_tensor.item()
best_score_value = similarity_tensor[best_index_value].item()
        description_text = (
            f"Image #{best_index_value + 1} "
            f"(similarity: {best_score_value:.4f})"
        )
return description_text, image_list[best_index_value]
except Exception as e:
        return f"Error: {str(e)}", None
def segment_image_with_sam_points(
image_object,
point_coordinates_list: List[List[int]],
) -> Image.Image:
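    """Segment the image with SlimSAM using the given foreground points (label 1).

    Returns the first post-processed mask as a binary grayscale image; an
    all-black mask of the input size is returned when segmentation fails.
    """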
if image_object is None:
        raise ValueError("No image was provided")
if not point_coordinates_list:
return Image.new("L", image_object.size, color=0)
try:
sam_model, sam_processor = get_sam_components()
batched_points: List[List[List[int]]] = [point_coordinates_list]
batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
sam_inputs = sam_processor(
            images=image_object,  # SamProcessor takes `images`, not `image`
input_points=batched_points,
input_labels=batched_labels,
return_tensors="pt",
)
with torch.no_grad():
sam_outputs = sam_model(**sam_inputs, multimask_output=True)
processed_masks_list = sam_processor.image_processor.post_process_masks(
sam_outputs.pred_masks.squeeze(1).cpu(),
sam_inputs["original_sizes"].cpu(),
sam_inputs["reshaped_input_sizes"].cpu(),
)
batch_masks_tensor = processed_masks_list[0]
if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
return Image.new("L", image_object.size, color=0)
first_mask_tensor = batch_masks_tensor[0]
mask_array = first_mask_tensor.numpy()
binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
mask_image = Image.fromarray(binary_mask_array, mode="L")
return mask_image
except Exception as e:
        print(f"Error: {str(e)}")
return Image.new("L", image_object.size, color=0)
def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
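    """Parse "x,y" pairs separated by ";" or newlines and run point-prompted segmentation."""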
if image_object is None:
return None
coordinates_text_clean = coordinates_text.strip()
if not coordinates_text_clean:
return Image.new("L", image_object.size, color=0)
point_coordinates_list: List[List[int]] = []
for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
raw_pair_clean = raw_pair.strip()
if not raw_pair_clean:
continue
parts = raw_pair_clean.split(",")
if len(parts) != 2:
continue
try:
x_value = int(parts[0].strip())
y_value = int(parts[1].strip())
except ValueError:
continue
point_coordinates_list.append([x_value, y_value])
if not point_coordinates_list:
return Image.new("L", image_object.size, color=0)
return segment_image_with_sam_points(image_object, point_coordinates_list)
def build_interface():
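    """Assemble the Gradio Blocks UI with one tab per vision task."""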
with gr.Blocks(title="Vision Processing Demo") as demo:
        gr.Markdown("# Image Processing System")
        with gr.Tab("Object Detection"):
            object_input_image = gr.Image(label="Upload an image", type="pil")
object_model_selector = gr.Dropdown(
choices=[
"object_detection_conditional_detr",
"object_detection_yolos_small",
],
                label="Model",
value="object_detection_conditional_detr",
)
            object_detect_button = gr.Button("Run detection")
            object_output_image = gr.Image(label="Result")
object_detect_button.click(
fn=detect_objects_on_image,
inputs=[object_input_image, object_model_selector],
outputs=object_output_image,
)
        with gr.Tab("Segmentation"):
            segmentation_input_image = gr.Image(label="Upload an image", type="pil")
            segmentation_button = gr.Button("Run segmentation")
            segmentation_output_image = gr.Image(label="Mask")
segmentation_button.click(
fn=segment_image,
inputs=segmentation_input_image,
outputs=segmentation_output_image,
)
        with gr.Tab("Depth Estimation"):
            depth_input_image = gr.Image(label="Upload an image", type="pil")
            depth_button = gr.Button("Estimate depth")
            depth_output_image = gr.Image(label="Depth map")
depth_button.click(
fn=estimate_image_depth,
inputs=depth_input_image,
outputs=depth_output_image,
)
        with gr.Tab("Captioning"):
            caption_input_image = gr.Image(label="Upload an image", type="pil")
caption_model_selector = gr.Dropdown(
choices=[
"captioning_blip_base",
"captioning_blip_large",
],
                label="Model",
value="captioning_blip_base",
)
            caption_button = gr.Button("Generate caption")
            caption_output_text = gr.Textbox(label="Caption", lines=3)
caption_button.click(
fn=generate_image_caption,
inputs=[caption_input_image, caption_model_selector],
outputs=caption_output_text,
)
with gr.Tab("VQA"):
            vqa_input_image = gr.Image(label="Upload an image", type="pil")
            vqa_question_text = gr.Textbox(label="Question", lines=2)
vqa_model_selector = gr.Dropdown(
choices=[
"vqa_blip_base",
"vqa_vilt_b32",
],
                label="Model",
value="vqa_blip_base",
)
            vqa_button = gr.Button("Ask question")
            vqa_output_text = gr.Textbox(label="Answer", lines=3)
vqa_button.click(
fn=answer_visual_question,
inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
outputs=vqa_output_text,
)
with gr.Tab("Zero-Shot"):
            zero_shot_input_image = gr.Image(label="Upload an image", type="pil")
            zero_shot_classes_text = gr.Textbox(
                label="Classes",
                placeholder="Enter class names separated by commas",
lines=2,
)
clip_model_selector = gr.Dropdown(
choices=[
"clip_large_patch14",
"clip_base_patch32",
],
                label="Model",
value="clip_large_patch14",
)
            zero_shot_button = gr.Button("Classify")
            zero_shot_output_text = gr.Textbox(label="Results", lines=8)
zero_shot_button.click(
fn=perform_zero_shot_classification,
inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
outputs=zero_shot_output_text,
)
        with gr.Tab("Search"):
            retrieval_dir = gr.File(
                label="Upload a folder of images",
file_count="directory",
file_types=["image"],
type="filepath",
)
            retrieval_query_text = gr.Textbox(label="Text query", lines=2)
retrieval_clip_selector = gr.Dropdown(
choices=[
"clip_large_patch14",
"clip_base_patch32",
],
                label="Model",
value="clip_large_patch14",
)
            retrieval_button = gr.Button("Find image")
            retrieval_output_text = gr.Textbox(label="Result")
            retrieval_output_image = gr.Image(label="Retrieved image")
retrieval_button.click(
fn=retrieve_best_image,
inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
outputs=[retrieval_output_text, retrieval_output_image],
)
return demo
if __name__ == "__main__":
interface = build_interface()
interface.launch(share=True, server_name="0.0.0.0")