| import tempfile | |
| from typing import List, Tuple, Any | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as torch_functional | |
| from PIL import Image, ImageDraw | |
| from transformers import ( | |
| CLIPModel, | |
| CLIPProcessor, | |
| SamModel, | |
| SamProcessor, | |
| BlipForQuestionAnswering, | |
| BlipProcessor, | |
| pipeline, | |
| ) | |
# Process-wide cache of lazily loaded pipelines, models and processors, keyed by
# short string identifiers chosen by the get_* loader functions in this module.
MODEL_STORE: dict[str, Any] = {}
def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
    """Convert a Gradio gallery / file-upload value into a list of PIL images.

    Accepted entries (also when *gallery_value* is a single entry rather than a list):
      * ``PIL.Image.Image`` - used as-is;
      * a file path (``str`` or any path-like object) - opened and converted to RGB;
      * an ``(image, caption)``-style list/tuple - its first element, resolved recursively;
      * a dict - the first of its ``"image"``, ``"value"`` or ``"path"`` entries
        that resolves to an image (resolved recursively).

    Entries that cannot be resolved (unsupported types, unreadable or non-image
    files) are skipped, so the result may be shorter than the input.
    """
    if not gallery_value:
        return []
    # A lone image/path/dict (e.g. a single-file upload) is treated as a one-item
    # gallery instead of being iterated character-by-character or failing to iterate.
    if isinstance(gallery_value, (str, Image.Image, dict)) or hasattr(gallery_value, "__fspath__"):
        gallery_value = [gallery_value]
    normalized_images: List[Image.Image] = []
    for item in gallery_value:
        resolved_image = _load_gallery_entry(item)
        if resolved_image is not None:
            normalized_images.append(resolved_image)
    return normalized_images


def _load_gallery_entry(item: Any) -> Image.Image | None:
    """Resolve one gallery entry to a PIL image, or return None if it cannot be resolved."""
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, str) or hasattr(item, "__fspath__"):
        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy that outlives it.
            with Image.open(item) as opened_image:
                return opened_image.convert("RGB")
        except Exception:
            # Deliberately best-effort: uploaded folders may contain non-image files.
            return None
    if isinstance(item, (list, tuple)):
        return _load_gallery_entry(item[0]) if item else None
    if isinstance(item, dict):
        for key in ("image", "value", "path"):
            candidate = item.get(key)
            if candidate is None:
                continue
            resolved_image = _load_gallery_entry(candidate)
            if resolved_image is not None:
                return resolved_image
    return None
def get_vision_pipeline(model_key: str):
    """Return the transformers ``pipeline`` registered under *model_key*.

    The pipeline is built on first request and memoised in ``MODEL_STORE`` under
    the same key, so later calls reuse the already-loaded model.

    Raises:
        ValueError: if *model_key* is neither cached nor one of the known keys below.
    """
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]
    # model key -> (pipeline task, Hugging Face Hub checkpoint)
    known_pipelines = {
        "object_detection_conditional_detr": (
            "object-detection",
            "microsoft/conditional-detr-resnet-50",
        ),
        "object_detection_yolos_small": ("object-detection", "hustvl/yolos-small"),
        "segmentation": (
            "image-segmentation",
            "nvidia/segformer-b0-finetuned-ade-512-512",
        ),
        "depth_estimation": ("depth-estimation", "Intel/dpt-hybrid-midas"),
        "captioning_blip_base": (
            "image-to-text",
            "Salesforce/blip-image-captioning-base",
        ),
        "captioning_blip_large": (
            "image-to-text",
            "Salesforce/blip-image-captioning-large",
        ),
        "vqa_blip_base": ("visual-question-answering", "Salesforce/blip-vqa-base"),
        "vqa_vilt_b32": (
            "visual-question-answering",
            "dandelin/vilt-b32-finetuned-vqa",
        ),
    }
    if model_key not in known_pipelines:
        raise ValueError(f"Неизвестный тип модели: {model_key}")
    task_name, checkpoint_name = known_pipelines[model_key]
    MODEL_STORE[model_key] = pipeline(task=task_name, model=checkpoint_name)
    return MODEL_STORE[model_key]
def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
    """Return the ``(model, processor)`` pair for a CLIP variant, loading it once.

    Both objects are cached in ``MODEL_STORE`` under ``clip_model_<key>`` and
    ``clip_processor_<key>``; the pair is (re)loaded if either entry is missing.

    Raises:
        ValueError: if the pair is not cached and *clip_key* is not a known variant.
    """
    model_slot = f"clip_model_{clip_key}"
    processor_slot = f"clip_processor_{clip_key}"
    if not (model_slot in MODEL_STORE and processor_slot in MODEL_STORE):
        checkpoint_by_variant = (
            ("clip_large_patch14", "openai/clip-vit-large-patch14"),
            ("clip_base_patch32", "openai/clip-vit-base-patch32"),
        )
        for variant_key, checkpoint_name in checkpoint_by_variant:
            if clip_key == variant_key:
                break
        else:
            raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")
        # Load both before caching either, so a failed processor load caches nothing.
        loaded_model = CLIPModel.from_pretrained(checkpoint_name)
        loaded_processor = CLIPProcessor.from_pretrained(checkpoint_name)
        MODEL_STORE[model_slot] = loaded_model
        MODEL_STORE[processor_slot] = loaded_processor
    return MODEL_STORE[model_slot], MODEL_STORE[processor_slot]
def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
    """Return the cached BLIP VQA ``(model, processor)`` pair, loading it on first use.

    Uses the ``Salesforce/blip-vqa-base`` checkpoint; both objects live in
    ``MODEL_STORE`` under ``"blip_vqa_model"`` / ``"blip_vqa_processor"``.
    """
    cache_ready = all(
        slot in MODEL_STORE for slot in ("blip_vqa_model", "blip_vqa_processor")
    )
    if not cache_ready:
        checkpoint = "Salesforce/blip-vqa-base"
        loaded_processor = BlipProcessor.from_pretrained(checkpoint)
        loaded_model = BlipForQuestionAnswering.from_pretrained(checkpoint)
        MODEL_STORE["blip_vqa_model"] = loaded_model
        MODEL_STORE["blip_vqa_processor"] = loaded_processor
    return MODEL_STORE["blip_vqa_model"], MODEL_STORE["blip_vqa_processor"]
def get_sam_components() -> Tuple[SamModel, SamProcessor]:
    """Return the cached SlimSAM ``(model, processor)`` pair, loading it on first use.

    Uses the ``Zigeng/SlimSAM-uniform-77`` checkpoint; both objects live in
    ``MODEL_STORE`` under ``"sam_model"`` / ``"sam_processor"``.
    """
    if any(slot not in MODEL_STORE for slot in ("sam_model", "sam_processor")):
        checkpoint = "Zigeng/SlimSAM-uniform-77"
        loaded_model = SamModel.from_pretrained(checkpoint)
        loaded_processor = SamProcessor.from_pretrained(checkpoint)
        MODEL_STORE["sam_model"] = loaded_model
        MODEL_STORE["sam_processor"] = loaded_processor
    return MODEL_STORE["sam_model"], MODEL_STORE["sam_processor"]
def detect_objects_on_image(image_object, model_key: str):
    """Run an object-detection pipeline and draw the predicted boxes.

    Args:
        image_object: input PIL image (as supplied by ``gr.Image(type="pil")``), or None.
        model_key: pipeline key understood by ``get_vision_pipeline``.

    Returns:
        A new RGB copy of the image with a red rectangle and a "label: score"
        caption per detection, or None when no image is given or inference fails
        (the error is printed). The input image is left unmodified.
    """
    if image_object is None:
        return None
    try:
        detector_pipeline = get_vision_pipeline(model_key)
        detection_results = detector_pipeline(image_object)
        # Draw on an RGB copy: never mutate the caller's image, and "red" stays a
        # real colour even when the source is greyscale ("L") or palette ("P").
        annotated_image = image_object.convert("RGB")
        drawer_object = ImageDraw.Draw(annotated_image)
        for detection_item in detection_results:
            box_data = detection_item["box"]
            drawer_object.rectangle(
                [box_data["xmin"], box_data["ymin"], box_data["xmax"], box_data["ymax"]],
                outline="red",
                width=3,
            )
            drawer_object.text(
                (box_data["xmin"], box_data["ymin"]),
                f"{detection_item['label']}: {detection_item['score']:.2f}",
                fill="red",
            )
        return annotated_image
    except Exception as e:
        print(f"Ошибка: {str(e)}")
        return None
def segment_image(image_object):
    """Return the mask of the first segment produced by the "segmentation" pipeline.

    Returns None when no image is given or when inference fails (the error is printed).
    """
    if image_object is None:
        return None
    try:
        segmenter = get_vision_pipeline("segmentation")
        segments = segmenter(image_object)
        first_segment = segments[0]
        return first_segment["mask"]
    except Exception as error:
        print(f"Ошибка: {error!s}")
        return None
def estimate_image_depth(image_object):
    """Estimate a depth map for *image_object* and render it as greyscale.

    The pipeline's raw ``predicted_depth`` tensor is resized to the input size
    with bicubic interpolation and scaled so its maximum value maps to 255.

    Returns:
        An 8-bit greyscale ("L") PIL image with the input's size; all black when
        the maximum predicted value is not positive (or is NaN). None when no
        image is given or inference fails (the error is printed).
    """
    if image_object is None:
        return None
    try:
        depth_pipeline = get_vision_pipeline("depth_estimation")
        depth_output = depth_pipeline(image_object)
        predicted_depth_tensor = depth_output["predicted_depth"].float()
        # F.interpolate needs a 4-D (N, C, H, W) tensor: accept a 2-D (H, W) map,
        # a 3-D map (presumably (1, H, W)), or one that is already 4-D.
        if predicted_depth_tensor.ndim == 3:
            predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
        elif predicted_depth_tensor.ndim == 2:
            predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
        elif predicted_depth_tensor.ndim != 4:
            raise ValueError(
                f"Неожиданная размерность: {predicted_depth_tensor.shape}"
            )
        resized_depth_tensor = torch_functional.interpolate(
            predicted_depth_tensor,
            size=image_object.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
            mode="bicubic",
            align_corners=False,
        )
        depth_array = resized_depth_tensor.squeeze().cpu().numpy()
        max_value = float(depth_array.max())
        # `not (x > 0)` also catches a NaN maximum, which would poison the scaling.
        if not max_value > 0.0:
            return Image.new("L", image_object.size, color=0)
        # Bicubic resampling can overshoot below zero (and a model may emit negative
        # values); clip before the uint8 cast so such pixels become black instead of
        # wrapping around to bright values.
        scaled_depth_array = (depth_array * 255.0 / max_value).clip(0.0, 255.0)
        # A 2-D uint8 array is inferred as mode "L" by Image.fromarray.
        return Image.fromarray(scaled_depth_array.astype("uint8"))
    except Exception as e:
        print(f"Ошибка: {str(e)}")
        return None
def generate_image_caption(image_object, model_key: str) -> str:
    """Caption *image_object* with the image-to-text pipeline selected by *model_key*.

    Returns the first generated text, an upload prompt when no image is given,
    or an "Ошибка: ..." message when inference fails.
    """
    if image_object is not None:
        try:
            captioner = get_vision_pipeline(model_key)
            first_caption = captioner(image_object)[0]
            return first_caption["generated_text"]
        except Exception as error:
            return f"Ошибка: {error!s}"
    return "Загрузите изображение"
def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
    """Answer a free-text question about an image.

    For ``"vqa_blip_base"`` the answer is generated directly with the BLIP VQA
    model; any other *model_key* goes through the visual-question-answering
    pipeline from ``get_vision_pipeline``, reporting the top answer and its score.

    Returns:
        The answer text; a user prompt when the image or question is missing;
        "Модель не смогла ответить" when the model yields no answer; or an
        "Ошибка: ..." message when inference fails.
    """
    if image_object is None:
        return "Загрузите изображение"
    # A None question is treated like a blank one instead of raising
    # AttributeError outside the error handler below.
    if not question_text or not question_text.strip():
        return "Введите вопрос"
    try:
        if model_key == "vqa_blip_base":
            blip_model, blip_processor = get_blip_vqa_components()
            inputs = blip_processor(
                images=image_object,
                text=question_text,
                return_tensors="pt",
            )
            with torch.no_grad():
                output_ids = blip_model.generate(**inputs)
            decoded_answers = blip_processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
            )
            answer_text = decoded_answers[0] if decoded_answers else ""
            return answer_text or "Модель не смогла ответить"
        vqa_pipeline = get_vision_pipeline(model_key)
        vqa_result = vqa_pipeline(
            image=image_object,
            question=question_text,
        )
        # Same as the BLIP branch: an empty prediction list means "no answer",
        # not an IndexError reported as a generic error.
        if not vqa_result:
            return "Модель не смогла ответить"
        top_item = vqa_result[0]
        return f"{top_item['answer']} (уверенность: {top_item['score']:.3f})"
    except Exception as e:
        return f"Ошибка: {str(e)}"
def perform_zero_shot_classification(
    image_object,
    class_texts: str,
    clip_key: str,
) -> str:
    """Score an image against comma-separated candidate labels with CLIP.

    Args:
        image_object: PIL image, or None.
        class_texts: labels separated by commas; blank entries are ignored and
            None is treated as no labels.
        clip_key: CLIP variant understood by ``get_clip_components``.

    Returns:
        "Результаты:" followed by one "label: probability" line per label, in
        input order (softmax over the labels); or a user-facing message when
        input is missing or inference fails.
    """
    if image_object is None:
        return "Загрузите изображение"
    try:
        # Validate the labels before loading CLIP, so blank input never triggers
        # a (potentially large) model download/load.
        class_list = [
            class_name.strip()
            for class_name in (class_texts or "").split(",")
            if class_name.strip()
        ]
        if not class_list:
            return "Укажите классы для классификации"
        clip_model, clip_processor = get_clip_components(clip_key)
        input_batch = clip_processor(
            text=class_list,
            images=image_object,
            return_tensors="pt",
            padding=True,
            # Truncate over-long labels to the tokenizer's maximum length instead
            # of failing inside the text encoder.
            truncation=True,
        )
        with torch.no_grad():
            clip_outputs = clip_model(**input_batch)
        probabilities = clip_outputs.logits_per_image.softmax(dim=1)[0].tolist()
        result_lines = ["Результаты:"]
        result_lines.extend(
            f"{class_name}: {probability_value:.4f}"
            for class_name, probability_value in zip(class_list, probabilities)
        )
        return "\n".join(result_lines)
    except Exception as e:
        return f"Ошибка: {str(e)}"
def retrieve_best_image(
    gallery_value: Any,
    query_text: str,
    clip_key: str,
    batch_size: int = 16,
) -> Tuple[str, Image.Image | None]:
    """Return the uploaded image that best matches *query_text* under CLIP.

    Images and the query are embedded with the selected CLIP model,
    L2-normalised, and compared by dot product (cosine similarity).

    Args:
        gallery_value: value of the upload widget; converted with
            ``_normalize_gallery_images`` (unresolvable entries are dropped).
        query_text: free-text search query.
        clip_key: CLIP variant understood by ``get_clip_components``.
        batch_size: number of images encoded per CLIP forward pass. Bounds peak
            memory when a large folder is uploaded; values below 1 act as 1.

    Returns:
        ``(description, image)`` for the best match, or ``(message, None)`` when
        the inputs are missing or inference fails.
    """
    image_list = _normalize_gallery_images(gallery_value)
    if not image_list or not query_text or not query_text.strip():
        return "Загрузите изображения и введите запрос", None
    try:
        clip_model, clip_processor = get_clip_components(clip_key)
        chunk_size = max(1, int(batch_size))
        image_feature_chunks = []
        with torch.no_grad():
            # Encode images in fixed-size chunks rather than one batch holding
            # every uploaded image at once.
            for chunk_start in range(0, len(image_list), chunk_size):
                image_inputs = clip_processor(
                    images=image_list[chunk_start:chunk_start + chunk_size],
                    return_tensors="pt",
                )
                image_feature_chunks.append(clip_model.get_image_features(**image_inputs))
            text_inputs = clip_processor(
                text=[query_text],
                return_tensors="pt",
                padding=True,
                # Over-long queries are truncated to the tokenizer's maximum length.
                truncation=True,
            )
            text_features = clip_model.get_text_features(**text_inputs)
        image_features = torch.cat(image_feature_chunks, dim=0)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # (num_images, 1) similarity matrix -> one score per image.
        similarity_scores = (image_features @ text_features.T).squeeze(-1)
        best_index_value = int(similarity_scores.argmax().item())
        best_score_value = similarity_scores[best_index_value].item()
        description_text = (
            f"Изображение #{best_index_value + 1} "
            f"(схожесть: {best_score_value:.4f})"
        )
        return description_text, image_list[best_index_value]
    except Exception as e:
        return f"Ошибка: {str(e)}", None
def segment_image_with_sam_points(
    image_object,
    point_coordinates_list: List[List[int]],
) -> Image.Image:
    """Segment the object indicated by foreground click points using SlimSAM.

    All points form one prompt for a single mask and are labelled foreground (1).
    SAM runs with ``multimask_output=True``; the candidate mask with the highest
    predicted IoU score is returned.

    Args:
        image_object: PIL image to segment.
        point_coordinates_list: ``[x, y]`` pixel coordinates in the original image.

    Returns:
        A binary "L"-mode mask (0/255) the size of the input. An all-black mask
        is returned when no points are given or inference fails (error printed).

    Raises:
        ValueError: if *image_object* is None.
    """
    if image_object is None:
        raise ValueError("Изображение не передано")
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)
    try:
        sam_model, sam_processor = get_sam_components()
        # Nesting follows the HF SamProcessor examples: points are
        # [image][point][x, y], labels are [image][point].
        batched_points: List[List[List[int]]] = [point_coordinates_list]
        batched_labels: List[List[int]] = [[1] * len(point_coordinates_list)]
        sam_inputs = sam_processor(
            # Passed positionally: the processor's image parameter is `images`,
            # so the previous `image=` keyword never reached the image processor.
            image_object.convert("RGB"),
            input_points=batched_points,
            input_labels=batched_labels,
            return_tensors="pt",
        )
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs, multimask_output=True)
        # post_process_masks needs the raw pred_masks: each per-image entry must
        # stay 4-D (point_batch, num_masks, h, w) for its F.interpolate calls, so
        # no squeeze here. Results are resized to the original image size and
        # binarised (its default).
        processed_masks_list = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu(),
        )
        image_masks = processed_masks_list[0]
        if image_masks.ndim != 4 or image_masks.shape[0] == 0 or image_masks.shape[1] == 0:
            return Image.new("L", image_object.size, color=0)
        candidate_masks = image_masks[0]  # (num_masks, H, W) for our single prompt
        iou_scores = sam_outputs.iou_scores.reshape(-1).cpu()
        # Prefer the mask SAM itself rates best; fall back to the first one if
        # the score layout does not line up with the candidates.
        if iou_scores.numel() == candidate_masks.shape[0]:
            best_index = int(iou_scores.argmax().item())
        else:
            best_index = 0
        mask_array = candidate_masks[best_index].numpy()
        binary_mask_array = ((mask_array > 0.5) * 255).astype("uint8")
        # A 2-D uint8 array is inferred as mode "L" by Image.fromarray.
        return Image.fromarray(binary_mask_array)
    except Exception as e:
        print(f"Ошибка: {str(e)}")
        return Image.new("L", image_object.size, color=0)
def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image | None:
    """Gradio adapter: parse click coordinates from text and run SAM segmentation.

    Args:
        image_object: PIL image, or None.
        coordinates_text: "x,y" pairs separated by ";" or newlines, e.g.
            "120,80; 200,150". Values may be decimals (rounded to the nearest
            pixel); malformed pairs are skipped. None counts as empty text.

    Returns:
        None when no image is given; an all-black "L" mask when no valid point
        was parsed; otherwise the result of ``segment_image_with_sam_points``.
    """
    if image_object is None:
        return None
    point_coordinates_list = _parse_point_coordinates(coordinates_text)
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)
    return segment_image_with_sam_points(image_object, point_coordinates_list)


def _parse_point_coordinates(coordinates_text: str | None) -> List[List[int]]:
    """Parse "x,y" pairs separated by ";" or newlines into ``[x, y]`` integer lists."""
    parsed_points: List[List[int]] = []
    for raw_pair in (coordinates_text or "").replace("\n", ";").split(";"):
        parts = raw_pair.split(",")
        if len(parts) != 2:
            # Blank segments and pairs with the wrong number of values are ignored.
            continue
        try:
            # float() tolerates surrounding whitespace and decimals; round to a pixel.
            x_value, y_value = (int(round(float(part))) for part in parts)
        except (ValueError, OverflowError):
            # Non-numeric text, NaN (ValueError) or infinity (OverflowError).
            continue
        parsed_points.append([x_value, y_value])
    return parsed_points
def build_interface():
    """Build the Gradio Blocks UI with one tab per vision task.

    Each tab creates its input widgets, an action button and an output widget,
    and binds the button's click event to the matching handler in this module.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    clip_variants = ["clip_large_patch14", "clip_base_patch32"]

    def model_selector(choices, default_choice):
        # Every tab's model dropdown has the same label; only options/default differ.
        return gr.Dropdown(choices=list(choices), label="Модель", value=default_choice)

    with gr.Blocks(title="Vision Processing Demo") as demo:
        gr.Markdown("# Система обработки изображений")

        with gr.Tab("Детекция объектов"):
            detection_image = gr.Image(label="Загрузите изображение", type="pil")
            detection_model = model_selector(
                ["object_detection_conditional_detr", "object_detection_yolos_small"],
                "object_detection_conditional_detr",
            )
            detection_run = gr.Button("Выполнить детекцию")
            detection_result = gr.Image(label="Результат")
            detection_run.click(
                fn=detect_objects_on_image,
                inputs=[detection_image, detection_model],
                outputs=detection_result,
            )

        with gr.Tab("Сегментация"):
            segmentation_image = gr.Image(label="Загрузите изображение", type="pil")
            segmentation_run = gr.Button("Запустить сегментацию")
            segmentation_result = gr.Image(label="Маска")
            segmentation_run.click(
                fn=segment_image,
                inputs=segmentation_image,
                outputs=segmentation_result,
            )

        with gr.Tab("Оценка глубины"):
            depth_image = gr.Image(label="Загрузите изображение", type="pil")
            depth_run = gr.Button("Оценить глубину")
            depth_result = gr.Image(label="Карта глубины")
            depth_run.click(
                fn=estimate_image_depth,
                inputs=depth_image,
                outputs=depth_result,
            )

        with gr.Tab("Описание"):
            caption_image = gr.Image(label="Загрузите изображение", type="pil")
            caption_model = model_selector(
                ["captioning_blip_base", "captioning_blip_large"],
                "captioning_blip_base",
            )
            caption_run = gr.Button("Создать описание")
            caption_result = gr.Textbox(label="Описание", lines=3)
            caption_run.click(
                fn=generate_image_caption,
                inputs=[caption_image, caption_model],
                outputs=caption_result,
            )

        with gr.Tab("VQA"):
            vqa_image = gr.Image(label="Загрузите изображение", type="pil")
            vqa_question = gr.Textbox(label="Вопрос", lines=2)
            vqa_model = model_selector(["vqa_blip_base", "vqa_vilt_b32"], "vqa_blip_base")
            vqa_run = gr.Button("Задать вопрос")
            vqa_result = gr.Textbox(label="Ответ", lines=3)
            vqa_run.click(
                fn=answer_visual_question,
                inputs=[vqa_image, vqa_question, vqa_model],
                outputs=vqa_result,
            )

        with gr.Tab("Zero-Shot"):
            zero_shot_image = gr.Image(label="Загрузите изображение", type="pil")
            zero_shot_labels = gr.Textbox(
                label="Классы",
                placeholder="Введите классы через запятую",
                lines=2,
            )
            zero_shot_model = model_selector(clip_variants, "clip_large_patch14")
            zero_shot_run = gr.Button("Классифицировать")
            zero_shot_result = gr.Textbox(label="Результаты", lines=8)
            zero_shot_run.click(
                fn=perform_zero_shot_classification,
                inputs=[zero_shot_image, zero_shot_labels, zero_shot_model],
                outputs=zero_shot_result,
            )

        with gr.Tab("Поиск"):
            search_folder = gr.File(
                label="Загрузите папку",
                file_count="directory",
                file_types=["image"],
                type="filepath",
            )
            search_query = gr.Textbox(label="Текстовый запрос", lines=2)
            search_model = model_selector(clip_variants, "clip_large_patch14")
            search_run = gr.Button("Найти изображение")
            search_summary = gr.Textbox(label="Результат")
            search_match = gr.Image(label="Найденное изображение")
            search_run.click(
                fn=retrieve_best_image,
                inputs=[search_folder, search_query, search_model],
                outputs=[search_summary, search_match],
            )

    return demo
if __name__ == "__main__":
    # NOTE: share=True also publishes a temporary public Gradio link, and
    # server_name="0.0.0.0" listens on every network interface.
    interface = build_interface()
    interface.launch(server_name="0.0.0.0", share=True)