import tempfile from typing import List, Tuple, Any import gradio as gr import torch import torch.nn.functional as torch_functional from PIL import Image, ImageDraw from transformers import ( CLIPModel, CLIPProcessor, SamModel, SamProcessor, BlipForQuestionAnswering, BlipProcessor, pipeline, ) MODEL_STORE = {} def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]: if not gallery_value: return [] normalized_images: List[Image.Image] = [] for item in gallery_value: if isinstance(item, Image.Image): normalized_images.append(item) continue if isinstance(item, str): try: image_object = Image.open(item).convert("RGB") normalized_images.append(image_object) except Exception: continue continue if isinstance(item, (list, tuple)) and item: candidate = item[0] if isinstance(candidate, Image.Image): normalized_images.append(candidate) continue if isinstance(item, dict): candidate = item.get("image") or item.get("value") if isinstance(candidate, Image.Image): normalized_images.append(candidate) continue return normalized_images def get_vision_pipeline(model_key: str): if model_key in MODEL_STORE: return MODEL_STORE[model_key] if model_key == "object_detection_conditional_detr": vision_pipeline = pipeline( task="object-detection", model="microsoft/conditional-detr-resnet-50", ) elif model_key == "object_detection_yolos_small": vision_pipeline = pipeline( task="object-detection", model="hustvl/yolos-small", ) elif model_key == "segmentation": vision_pipeline = pipeline( task="image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512", ) elif model_key == "depth_estimation": vision_pipeline = pipeline( task="depth-estimation", model="Intel/dpt-hybrid-midas", ) elif model_key == "captioning_blip_base": vision_pipeline = pipeline( task="image-to-text", model="Salesforce/blip-image-captioning-base", ) elif model_key == "captioning_blip_large": vision_pipeline = pipeline( task="image-to-text", model="Salesforce/blip-image-captioning-large", ) elif model_key == "vqa_blip_base": vision_pipeline = pipeline( task="visual-question-answering", model="Salesforce/blip-vqa-base", ) elif model_key == "vqa_vilt_b32": vision_pipeline = pipeline( task="visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa", ) else: raise ValueError(f"Неизвестный тип модели: {model_key}") MODEL_STORE[model_key] = vision_pipeline return vision_pipeline def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]: model_store_key_model = f"clip_model_{clip_key}" model_store_key_processor = f"clip_processor_{clip_key}" if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE: if clip_key == "clip_large_patch14": clip_name = "openai/clip-vit-large-patch14" elif clip_key == "clip_base_patch32": clip_name = "openai/clip-vit-base-patch32" else: raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}") clip_model = CLIPModel.from_pretrained(clip_name) clip_processor = CLIPProcessor.from_pretrained(clip_name) MODEL_STORE[model_store_key_model] = clip_model MODEL_STORE[model_store_key_processor] = clip_processor clip_model = MODEL_STORE[model_store_key_model] clip_processor = MODEL_STORE[model_store_key_processor] return clip_model, clip_processor def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]: if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE: blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") MODEL_STORE["blip_vqa_model"] = blip_model MODEL_STORE["blip_vqa_processor"] = blip_processor blip_model = MODEL_STORE["blip_vqa_model"] blip_processor = MODEL_STORE["blip_vqa_processor"] return blip_model, blip_processor def get_sam_components() -> Tuple[SamModel, SamProcessor]: if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE: sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77") sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77") MODEL_STORE["sam_model"] = sam_model MODEL_STORE["sam_processor"] = sam_processor sam_model = MODEL_STORE["sam_model"] sam_processor = MODEL_STORE["sam_processor"] return sam_model, sam_processor def detect_objects_on_image(image_object, model_key: str): if image_object is None: return None try: detector_pipeline = get_vision_pipeline(model_key) detection_results = detector_pipeline(image_object) drawer_object = ImageDraw.Draw(image_object) for detection_item in detection_results: box_data = detection_item["box"] label_value = detection_item["label"] score_value = detection_item["score"] drawer_object.rectangle( [ box_data["xmin"], box_data["ymin"], box_data["xmax"], box_data["ymax"], ], outline="red", width=3, ) drawer_object.text( (box_data["xmin"], box_data["ymin"]), f"{label_value}: {score_value:.2f}", fill="red", ) return image_object except Exception as e: print(f"Ошибка: {str(e)}") return None def segment_image(image_object): if image_object is None: return None try: segmentation_pipeline = get_vision_pipeline("segmentation") segmentation_results = segmentation_pipeline(image_object) return segmentation_results[0]["mask"] except Exception as e: print(f"Ошибка: {str(e)}") return None def estimate_image_depth(image_object): if image_object is None: return None try: depth_pipeline = get_vision_pipeline("depth_estimation") depth_output = depth_pipeline(image_object) predicted_depth_tensor = depth_output["predicted_depth"] if predicted_depth_tensor.ndim == 3: predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1) elif predicted_depth_tensor.ndim == 2: predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0) else: raise ValueError( f"Неожиданная размерность: {predicted_depth_tensor.shape}" ) resized_depth_tensor = torch_functional.interpolate( predicted_depth_tensor, size=image_object.size[::-1], mode="bicubic", align_corners=False, ) depth_array = resized_depth_tensor.squeeze().cpu().numpy() max_value = float(depth_array.max()) if max_value <= 0.0: return Image.new("L", image_object.size, color=0) normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8") depth_image = Image.fromarray(normalized_depth_array, mode="L") return depth_image except Exception as e: print(f"Ошибка: {str(e)}") return None def generate_image_caption(image_object, model_key: str) -> str: if image_object is None: return "Загрузите изображение" try: caption_pipeline = get_vision_pipeline(model_key) caption_result = caption_pipeline(image_object) return caption_result[0]["generated_text"] except Exception as e: return f"Ошибка: {str(e)}" def answer_visual_question(image_object, question_text: str, model_key: str) -> str: if image_object is None: return "Загрузите изображение" if not question_text.strip(): return "Введите вопрос" try: if model_key == "vqa_blip_base": blip_model, blip_processor = get_blip_vqa_components() inputs = blip_processor( images=image_object, text=question_text, return_tensors="pt", ) with torch.no_grad(): output_ids = blip_model.generate(**inputs) decoded_answers = blip_processor.batch_decode( output_ids, skip_special_tokens=True, ) answer_text = decoded_answers[0] if decoded_answers else "" return answer_text or "Модель не смогла ответить" vqa_pipeline = get_vision_pipeline(model_key) vqa_result = vqa_pipeline( image=image_object, question=question_text, ) top_item = vqa_result[0] answer_text = top_item["answer"] confidence_value = top_item["score"] return f"{answer_text} (уверенность: {confidence_value:.3f})" except Exception as e: return f"Ошибка: {str(e)}" def perform_zero_shot_classification( image_object, class_texts: str, clip_key: str, ) -> str: if image_object is None: return "Загрузите изображение" try: clip_model, clip_processor = get_clip_components(clip_key) class_list = [ class_name.strip() for class_name in class_texts.split(",") if class_name.strip() ] if not class_list: return "Укажите классы для классификации" input_batch = clip_processor( text=class_list, images=image_object, return_tensors="pt", padding=True, ) with torch.no_grad(): clip_outputs = clip_model(**input_batch) logits_per_image = clip_outputs.logits_per_image probability_tensor = logits_per_image.softmax(dim=1) result_lines = ["Результаты:"] for class_index, class_name in enumerate(class_list): probability_value = probability_tensor[0][class_index].item() result_lines.append(f"{class_name}: {probability_value:.4f}") return "\n".join(result_lines) except Exception as e: return f"Ошибка: {str(e)}" def retrieve_best_image( gallery_value: Any, query_text: str, clip_key: str, ) -> Tuple[str, Image.Image | None]: image_list = _normalize_gallery_images(gallery_value) if not image_list or not query_text.strip(): return "Загрузите изображения и введите запрос", None try: clip_model, clip_processor = get_clip_components(clip_key) image_inputs = clip_processor( images=image_list, return_tensors="pt", padding=True, ) with torch.no_grad(): image_features = clip_model.get_image_features(**image_inputs) image_features = image_features / image_features.norm( dim=-1, keepdim=True, ) text_inputs = clip_processor( text=[query_text], return_tensors="pt", padding=True, ) with torch.no_grad(): text_features = clip_model.get_text_features(**text_inputs) text_features = text_features / text_features.norm( dim=-1, keepdim=True, ) similarity_tensor = image_features @ text_features.T best_index_tensor = similarity_tensor.argmax() best_index_value = best_index_tensor.item() best_score_value = similarity_tensor[best_index_value].item() description_text = ( f"Изображение #{best_index_value + 1} " f"(схожесть: {best_score_value:.4f})" ) return description_text, image_list[best_index_value] except Exception as e: return f"Ошибка: {str(e)}", None def segment_image_with_sam_points( image_object, point_coordinates_list: List[List[int]], ) -> Image.Image: if image_object is None: raise ValueError("Изображение не передано") if not point_coordinates_list: return Image.new("L", image_object.size, color=0) try: sam_model, sam_processor = get_sam_components() batched_points: List[List[List[int]]] = [point_coordinates_list] batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]] sam_inputs = sam_processor( image=image_object, input_points=batched_points, input_labels=batched_labels, return_tensors="pt", ) with torch.no_grad(): sam_outputs = sam_model(**sam_inputs, multimask_output=True) processed_masks_list = sam_processor.image_processor.post_process_masks( sam_outputs.pred_masks.squeeze(1).cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu(), ) batch_masks_tensor = processed_masks_list[0] if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0: return Image.new("L", image_object.size, color=0) first_mask_tensor = batch_masks_tensor[0] mask_array = first_mask_tensor.numpy() binary_mask_array = (mask_array > 0.5).astype("uint8") * 255 mask_image = Image.fromarray(binary_mask_array, mode="L") return mask_image except Exception as e: print(f"Ошибка: {str(e)}") return Image.new("L", image_object.size, color=0) def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image: if image_object is None: return None coordinates_text_clean = coordinates_text.strip() if not coordinates_text_clean: return Image.new("L", image_object.size, color=0) point_coordinates_list: List[List[int]] = [] for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"): raw_pair_clean = raw_pair.strip() if not raw_pair_clean: continue parts = raw_pair_clean.split(",") if len(parts) != 2: continue try: x_value = int(parts[0].strip()) y_value = int(parts[1].strip()) except ValueError: continue point_coordinates_list.append([x_value, y_value]) if not point_coordinates_list: return Image.new("L", image_object.size, color=0) return segment_image_with_sam_points(image_object, point_coordinates_list) def build_interface(): with gr.Blocks(title="Vision Processing Demo") as demo: gr.Markdown("# Система обработки изображений") with gr.Tab("Детекция объектов"): object_input_image = gr.Image(label="Загрузите изображение", type="pil") object_model_selector = gr.Dropdown( choices=[ "object_detection_conditional_detr", "object_detection_yolos_small", ], label="Модель", value="object_detection_conditional_detr", ) object_detect_button = gr.Button("Выполнить детекцию") object_output_image = gr.Image(label="Результат") object_detect_button.click( fn=detect_objects_on_image, inputs=[object_input_image, object_model_selector], outputs=object_output_image, ) with gr.Tab("Сегментация"): segmentation_input_image = gr.Image(label="Загрузите изображение", type="pil") segmentation_button = gr.Button("Запустить сегментацию") segmentation_output_image = gr.Image(label="Маска") segmentation_button.click( fn=segment_image, inputs=segmentation_input_image, outputs=segmentation_output_image, ) with gr.Tab("Оценка глубины"): depth_input_image = gr.Image(label="Загрузите изображение", type="pil") depth_button = gr.Button("Оценить глубину") depth_output_image = gr.Image(label="Карта глубины") depth_button.click( fn=estimate_image_depth, inputs=depth_input_image, outputs=depth_output_image, ) with gr.Tab("Описание"): caption_input_image = gr.Image(label="Загрузите изображение", type="pil") caption_model_selector = gr.Dropdown( choices=[ "captioning_blip_base", "captioning_blip_large", ], label="Модель", value="captioning_blip_base", ) caption_button = gr.Button("Создать описание") caption_output_text = gr.Textbox(label="Описание", lines=3) caption_button.click( fn=generate_image_caption, inputs=[caption_input_image, caption_model_selector], outputs=caption_output_text, ) with gr.Tab("VQA"): vqa_input_image = gr.Image(label="Загрузите изображение", type="pil") vqa_question_text = gr.Textbox(label="Вопрос", lines=2) vqa_model_selector = gr.Dropdown( choices=[ "vqa_blip_base", "vqa_vilt_b32", ], label="Модель", value="vqa_blip_base", ) vqa_button = gr.Button("Задать вопрос") vqa_output_text = gr.Textbox(label="Ответ", lines=3) vqa_button.click( fn=answer_visual_question, inputs=[vqa_input_image, vqa_question_text, vqa_model_selector], outputs=vqa_output_text, ) with gr.Tab("Zero-Shot"): zero_shot_input_image = gr.Image(label="Загрузите изображение", type="pil") zero_shot_classes_text = gr.Textbox( label="Классы", placeholder="Введите классы через запятую", lines=2, ) clip_model_selector = gr.Dropdown( choices=[ "clip_large_patch14", "clip_base_patch32", ], label="Модель", value="clip_large_patch14", ) zero_shot_button = gr.Button("Классифицировать") zero_shot_output_text = gr.Textbox(label="Результаты", lines=8) zero_shot_button.click( fn=perform_zero_shot_classification, inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector], outputs=zero_shot_output_text, ) with gr.Tab("Поиск"): retrieval_dir = gr.File( label="Загрузите папку", file_count="directory", file_types=["image"], type="filepath", ) retrieval_query_text = gr.Textbox(label="Текстовый запрос", lines=2) retrieval_clip_selector = gr.Dropdown( choices=[ "clip_large_patch14", "clip_base_patch32", ], label="Модель", value="clip_large_patch14", ) retrieval_button = gr.Button("Найти изображение") retrieval_output_text = gr.Textbox(label="Результат") retrieval_output_image = gr.Image(label="Найденное изображение") retrieval_button.click( fn=retrieve_best_image, inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector], outputs=[retrieval_output_text, retrieval_output_image], ) return demo if __name__ == "__main__": interface = build_interface() interface.launch(share=True, server_name="0.0.0.0")