ASureevaA committed
Commit 3680138 · Parent(s): c14e744

fix image q

Files changed (1)
  1. app.py +112 -37
app.py CHANGED
@@ -1,5 +1,5 @@
 import tempfile
-from typing import List, Tuple
+from typing import List, Tuple, Any

 import gradio as gr
 import soundfile as soundfile_module
@@ -20,8 +20,32 @@ from transformers import (

 MODEL_STORE = {}

+def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
+    if not gallery_value:
+        return []
+
+    normalized_images: List[Image.Image] = []
+
+    for item in gallery_value:
+        if isinstance(item, Image.Image):
+            normalized_images.append(item)
+            continue
+
+        if isinstance(item, (list, tuple)) and item:
+            candidate = item[0]
+            if isinstance(candidate, Image.Image):
+                normalized_images.append(candidate)
+            continue
+
+        if isinstance(item, dict):
+            candidate = item.get("image") or item.get("value")
+            if isinstance(candidate, Image.Image):
+                normalized_images.append(candidate)
+            continue


+    return normalized_images
+
 def get_audio_pipeline(model_key: str):
     if model_key in MODEL_STORE:
         return MODEL_STORE[model_key]
@@ -310,9 +334,18 @@ def estimate_image_depth(image_object):

     predicted_depth_tensor = depth_output["predicted_depth"]

+    if predicted_depth_tensor.ndim == 3:
+        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
+    elif predicted_depth_tensor.ndim == 2:
+        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
+    else:
+        raise ValueError(
+            f"Неожиданная размерность predicted_depth: {predicted_depth_tensor.shape}"
+        )
+
     resized_depth_tensor = torch_functional.interpolate(
-        predicted_depth_tensor.unsqueeze(0).unsqueeze(0),
-        size=image_object.size[::-1],  # (width, height) -> (H, W)
+        predicted_depth_tensor,
+        size=image_object.size[::-1],
         mode="bicubic",
         align_corners=False,
     )
@@ -335,13 +368,24 @@ def generate_image_caption(image_object, model_key: str) -> str:


 def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
+    if image_object is None:
+        return "Пожалуйста, сначала загрузите изображение."
+
+    if not question_text.strip():
+        return "Пожалуйста, введите вопрос об изображении."
+
     vqa_pipeline = get_vision_pipeline(model_key)
-    vqa_result = vqa_pipeline(image_object, question_text)

-    answer_text = vqa_result[0]["answer"]
-    confidence_value = vqa_result[0]["score"]
-    return f"{answer_text} (confidence: {confidence_value:.3f})"
+    vqa_result = vqa_pipeline(
+        image=image_object,
+        question=question_text,
+    )
+
+    top_item = vqa_result[0]
+    answer_text = top_item["answer"]
+    confidence_value = top_item["score"]

+    return f"{answer_text} (confidence: {confidence_value:.3f})"

 def perform_zero_shot_classification(
     image_object,
@@ -379,11 +423,13 @@ def perform_zero_shot_classification(


 def retrieve_best_image(
-    image_list: List,
+    gallery_value: Any,
     query_text: str,
     clip_key: str,
-):
-    if not image_list or not query_text:
+) -> Tuple[str, Image.Image | None]:
+    image_list = _normalize_gallery_images(gallery_value)
+
+    if not image_list or not query_text.strip():
         return "Пожалуйста, загрузите изображения и введите запрос", None

     clip_model, clip_processor = get_clip_components(clip_key)
@@ -426,44 +472,83 @@

 def segment_image_with_sam_points(
     image_object,
-    point_coordinates_list: List[List[int]] | None,
-) -> Image:
+    point_coordinates_list: List[List[int]],
+) -> Image.Image:
+    if image_object is None:
+        raise ValueError("Изображение не передано в segment_image_with_sam_points")

     if not point_coordinates_list:
         return Image.new("L", image_object.size, color=0)

     sam_model, sam_processor = get_sam_components()

-    batched_points = [point_coordinates_list]
-    batched_labels = [[1 for _ in point_coordinates_list]]
+    batched_points: List[List[List[int]]] = [point_coordinates_list]
+    batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]

     sam_inputs = sam_processor(
-        image_object,
+        image=image_object,
         input_points=batched_points,
         input_labels=batched_labels,
         return_tensors="pt",
     )

     with torch.no_grad():
-        sam_outputs = sam_model(**sam_inputs)
+        sam_outputs = sam_model(**sam_inputs, multimask_output=True)

-    post_processed_masks_list = sam_processor.image_processor.post_process_masks(
-        sam_outputs.pred_masks.cpu(),
+    processed_masks_list = sam_processor.image_processor.post_process_masks(
+        sam_outputs.pred_masks.squeeze(1).cpu(),
         sam_inputs["original_sizes"].cpu(),
         sam_inputs["reshaped_input_sizes"].cpu(),
     )

-    batched_masks_tensor = post_processed_masks_list[0]  # shape: [num_masks, H, W]
-    if batched_masks_tensor.ndim != 3 or batched_masks_tensor.shape[0] == 0:
+    batch_masks_tensor = processed_masks_list[0]
+
+    if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
         return Image.new("L", image_object.size, color=0)

-    first_mask_tensor = batched_masks_tensor[0]  # [H, W]
-    mask_array = first_mask_tensor.cpu().numpy()
+    first_mask_tensor = batch_masks_tensor[0]
+    mask_array = first_mask_tensor.numpy()

-    mask_image = Image.fromarray((mask_array * 255.0).astype("uint8"), mode="L")
+    binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
+
+    mask_image = Image.fromarray(binary_mask_array, mode="L")
     return mask_image


+def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
+
+    if image_object is None:
+        return None
+
+    coordinates_text_clean = coordinates_text.strip()
+    if not coordinates_text_clean:
+        return Image.new("L", image_object.size, color=0)
+
+    point_coordinates_list: List[List[int]] = []
+
+    for raw_pair in coordinates_text_clean.replace("\n", ";").split(";"):
+        raw_pair_clean = raw_pair.strip()
+        if not raw_pair_clean:
+            continue
+
+        parts = raw_pair_clean.split(",")
+        if len(parts) != 2:
+            continue
+
+        try:
+            x_value = int(parts[0].strip())
+            y_value = int(parts[1].strip())
+        except ValueError:
+            continue
+
+        point_coordinates_list.append([x_value, y_value])
+
+    if not point_coordinates_list:
+        return Image.new("L", image_object.size, color=0)
+
+    return segment_image_with_sam_points(image_object, point_coordinates_list)
+
+
 def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
     if not coordinates_text.strip():
         return []
@@ -485,16 +570,6 @@ def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:

     return point_list

-
-def segment_image_with_sam_points_ui(
-    image_object,
-    coordinates_text: str,
-):
-    point_coordinates_list = parse_point_coordinates_text(coordinates_text)
-    return segment_image_with_sam_points(image_object, point_coordinates_list)
-
-
-
 def build_interface():
     with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
         gr.Markdown("#Мультимодальные AI модели")
@@ -588,9 +663,8 @@ def build_interface():
                 inputs=[asr_audio_input_component, asr_model_selector],
                 outputs=asr_output_component,
             )
-
         with gr.Tab("Синтез речи"):
-            gr.Markdown("## Text-to-Speech (TTS)")
+            gr.Markdown("## Text-to-Speech")
             with gr.Row():
                 with gr.Column():
                     tts_text_component = gr.Textbox(
@@ -612,11 +686,12 @@
                 with gr.Column():
                     tts_audio_output_component = gr.Audio(
                         label="Синтезированная речь",
+                        type="filepath",
                     )

             tts_button.click(
                 fn=synthesize_speech,
-                inputs=tts_text_component,
+                inputs=[tts_text_component, tts_model_selector],
                 outputs=tts_audio_output_component,
             )

@@ -706,7 +781,7 @@
                     )
                     sam_coordinates_text = gr.Textbox(
                         label="Координаты точек",
-                        placeholder="100,150; 200,220",
+                        placeholder="650,380; 600,450; 550,520",
                         lines=2,
                     )
                     sam_button = gr.Button("Сегментировать по точкам")
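For anyone trying the new gallery handling locally, a minimal smoke test of _normalize_gallery_images might look like the sketch below. It assumes app.py is importable as a module named `app` (a hypothetical setup); the three value shapes are the ones the helper accepts in this commit — plain PIL images, (image, caption) tuples, and dicts with an "image"/"value" key.

    from PIL import Image

    from app import _normalize_gallery_images  # assumes app.py is importable as `app`

    # One sample image presented in the three shapes the helper understands.
    sample = Image.new("RGB", (64, 64), color="gray")
    gallery_value = [sample, (sample, "a caption"), {"image": sample}]

    images = _normalize_gallery_images(gallery_value)
    assert len(images) == 3
    assert all(isinstance(item, Image.Image) for item in images)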
 
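The new ndim branches in estimate_image_depth appear to be there because torch.nn.functional.interpolate with mode="bicubic" only accepts a 4-D NCHW tensor, while predicted_depth may arrive as [H, W] or [batch, H, W]. A standalone sketch of the same reshaping, with a random tensor standing in for the model output:

    import torch
    import torch.nn.functional as F

    predicted_depth = torch.rand(384, 384)  # stand-in for a 2-D depth map

    if predicted_depth.ndim == 3:    # [batch, H, W] -> [batch, 1, H, W]
        predicted_depth = predicted_depth.unsqueeze(1)
    elif predicted_depth.ndim == 2:  # [H, W] -> [1, 1, H, W]
        predicted_depth = predicted_depth.unsqueeze(0).unsqueeze(0)

    resized = F.interpolate(
        predicted_depth,
        size=(480, 640),  # (H, W), i.e. image.size[::-1] for a 640x480 PIL image
        mode="bicubic",
        align_corners=False,
    )
    print(resized.shape)  # torch.Size([1, 1, 480, 640])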
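The VQA change above switches to keyword arguments when calling the transformers visual-question-answering pipeline. A minimal standalone sketch of that call shape; the checkpoint name is only an example, since the app picks its own model through get_vision_pipeline(model_key):

    from PIL import Image
    from transformers import pipeline

    vqa = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")

    image = Image.open("example.jpg")  # placeholder path
    result = vqa(image=image, question="What color is the car?")

    top = result[0]  # list of {"answer": ..., "score": ...}, best answer first
    print(f"{top['answer']} (confidence: {top['score']:.3f})")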