ASureevaA committed on
Commit
de91dc1
·
1 Parent(s): ff2cc71

fix image q

Browse files
Files changed (1) hide show
  1. app.py +50 -30
app.py CHANGED
@@ -15,6 +15,8 @@ from transformers import (
15
  SamProcessor,
16
  VitsModel,
17
  pipeline,
 
 
18
  )
19
 
20
 
@@ -94,6 +96,17 @@ def get_zero_shot_audio_pipeline():
94
  return MODEL_STORE["audio_zero_shot_clap"]
95
 
96
 
 
 
 
 
 
 
 
 
 
 
 
97
  def get_vision_pipeline(model_key: str):
98
  if model_key in MODEL_STORE:
99
  return MODEL_STORE[model_key]
@@ -185,16 +198,15 @@ def get_silero_tts_model():
185
  return MODEL_STORE["silero_tts_model"]
186
 
187
 
188
- def get_mms_tts_components() -> Tuple[VitsModel, AutoTokenizer]:
189
- if "mms_tts_model" not in MODEL_STORE or "mms_tts_tokenizer" not in MODEL_STORE:
190
- vits_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
191
- vits_tokenizer = AutoTokenizer.from_pretrained("kakao-enterprise/vits-ljs")
192
- MODEL_STORE["mms_tts_model"] = vits_model
193
- MODEL_STORE["mms_tts_tokenizer"] = vits_tokenizer
 
194
 
195
- vits_model = MODEL_STORE["mms_tts_model"]
196
- vits_tokenizer = MODEL_STORE["mms_tts_tokenizer"]
197
- return vits_model, vits_tokenizer
198
 
199
 
200
  def get_sam_components() -> Tuple[SamModel, SamProcessor]:
@@ -262,18 +274,6 @@ def recognize_speech(audio_path: str, model_key: str) -> str:
262
 
263
 
264
  def synthesize_speech(text_value: str, model_key: str):
265
- if model_key == "silero":
266
- silero_model = get_silero_tts_model()
267
-
268
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
269
- silero_model.save_wav(
270
- text=text_value,
271
- speaker="aidar",
272
- sample_rate=48000,
273
- audio_path=file_object.name,
274
- )
275
- return file_object.name
276
-
277
  if model_key == "Google TTS":
278
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
279
  text_to_speech_engine = gTTS(text=text_value, lang="ru")
@@ -281,18 +281,18 @@ def synthesize_speech(text_value: str, model_key: str):
281
  return file_object.name
282
 
283
  if model_key == "vits-ljs":
284
- vits_model, vits_tokenizer = get_mms_tts_components()
285
- tokenized_input = vits_tokenizer(text_value, return_tensors="pt")
286
 
287
- with torch.no_grad():
288
- waveform_tensor = vits_model(**tokenized_input).waveform
 
 
289
 
290
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
291
- waveform_array = waveform_tensor.numpy().squeeze()
292
  soundfile_module.write(
293
  file_object.name,
294
- waveform_array,
295
- vits_model.config.sampling_rate,
296
  )
297
  return file_object.name
298
 
@@ -381,6 +381,26 @@ def answer_visual_question(image_object, question_text: str, model_key: str) ->
381
  if not question_text.strip():
382
  return "Пожалуйста, введите вопрос об изображении."
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  vqa_pipeline = get_vision_pipeline(model_key)
385
 
386
  vqa_result = vqa_pipeline(
@@ -680,9 +700,9 @@ def build_interface():
680
  lines=3,
681
  )
682
  tts_model_selector = gr.Dropdown(
683
- choices=["vits-ljs", "Google TTS",],
684
  label="Выберите модель",
685
- value="silero",
686
  info=(
687
  "kakao-enterprise/vits-ljs"
688
  "Google TTS"
 
15
  SamProcessor,
16
  VitsModel,
17
  pipeline,
18
+ BlipForQuestionAnswering,
19
+ BlipProcessor,
20
  )
21
 
22
 
 
96
  return MODEL_STORE["audio_zero_shot_clap"]
97
 
98
 
99
def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
    """Lazily load and cache the BLIP visual-question-answering components.

    Returns:
        A ``(model, processor)`` pair for ``Salesforce/blip-vqa-base``.
        Both objects are created on first call and reused from
        ``MODEL_STORE`` afterwards.
    """
    checkpoint = "Salesforce/blip-vqa-base"
    # Load both pieces together so the cache never holds a partial pair.
    if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
        MODEL_STORE["blip_vqa_processor"] = BlipProcessor.from_pretrained(checkpoint)
        MODEL_STORE["blip_vqa_model"] = BlipForQuestionAnswering.from_pretrained(checkpoint)
    return MODEL_STORE["blip_vqa_model"], MODEL_STORE["blip_vqa_processor"]
110
  def get_vision_pipeline(model_key: str):
111
  if model_key in MODEL_STORE:
112
  return MODEL_STORE[model_key]
 
198
  return MODEL_STORE["silero_tts_model"]
199
 
200
 
201
def get_mms_tts_components():
    """Lazily build and cache the VITS text-to-speech pipeline.

    Returns:
        The cached ``transformers`` text-to-speech pipeline backed by
        ``kakao-enterprise/vits-ljs``; constructed on first use only.
    """
    cache_key = "mms_tts_pipeline"
    if cache_key not in MODEL_STORE:
        MODEL_STORE[cache_key] = pipeline(
            task="text-to-speech",
            model="kakao-enterprise/vits-ljs",
        )
    return MODEL_STORE[cache_key]
 
 
210
 
211
 
212
  def get_sam_components() -> Tuple[SamModel, SamProcessor]:
 
274
 
275
 
276
  def synthesize_speech(text_value: str, model_key: str):
 
 
 
 
 
 
 
 
 
 
 
 
277
  if model_key == "Google TTS":
278
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
279
  text_to_speech_engine = gTTS(text=text_value, lang="ru")
 
281
  return file_object.name
282
 
283
  if model_key == "vits-ljs":
284
+ tts_pipeline = get_mms_tts_components()
 
285
 
286
+ tts_output = tts_pipeline(text_value)
287
+
288
+ audio_array = tts_output["audio"]
289
+ sampling_rate_value = tts_output["sampling_rate"]
290
 
291
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
 
292
  soundfile_module.write(
293
  file_object.name,
294
+ audio_array,
295
+ sampling_rate_value,
296
  )
297
  return file_object.name
298
 
 
381
  if not question_text.strip():
382
  return "Пожалуйста, введите вопрос об изображении."
383
 
384
+ if model_key == "vqa_blip_base":
385
+ blip_model, blip_processor = get_blip_vqa_components()
386
+
387
+ inputs = blip_processor(
388
+ images=image_object,
389
+ text=question_text,
390
+ return_tensors="pt",
391
+ )
392
+
393
+ with torch.no_grad():
394
+ output_ids = blip_model.generate(**inputs)
395
+
396
+ decoded_answers = blip_processor.batch_decode(
397
+ output_ids,
398
+ skip_special_tokens=True,
399
+ )
400
+ answer_text = decoded_answers[0] if decoded_answers else ""
401
+
402
+ return answer_text or "Модель не смогла сгенерировать ответ."
403
+
404
  vqa_pipeline = get_vision_pipeline(model_key)
405
 
406
  vqa_result = vqa_pipeline(
 
700
  lines=3,
701
  )
702
  tts_model_selector = gr.Dropdown(
703
+ choices=["vits-ljs", "Google TTS"],
704
  label="Выберите модель",
705
+ value="vits-ljs",
706
  info=(
707
  "kakao-enterprise/vits-ljs"
708
  "Google TTS"