ASureevaA committed · Commit de91dc1 · 1 Parent(s): ff2cc71
fix image q
app.py CHANGED
@@ -15,6 +15,8 @@ from transformers import (
     SamProcessor,
     VitsModel,
     pipeline,
+    BlipForQuestionAnswering,
+    BlipProcessor,
 )
 
 
@@ -94,6 +96,17 @@ def get_zero_shot_audio_pipeline():
     return MODEL_STORE["audio_zero_shot_clap"]
 
 
+def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
+    if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
+        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+        blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
+        MODEL_STORE["blip_vqa_model"] = blip_model
+        MODEL_STORE["blip_vqa_processor"] = blip_processor
+
+    blip_model = MODEL_STORE["blip_vqa_model"]
+    blip_processor = MODEL_STORE["blip_vqa_processor"]
+    return blip_model, blip_processor
+
 def get_vision_pipeline(model_key: str):
     if model_key in MODEL_STORE:
         return MODEL_STORE[model_key]
@@ -185,16 +198,15 @@ def get_silero_tts_model():
     return MODEL_STORE["silero_tts_model"]
 
 
-def get_mms_tts_components()
-    if "
-
-
-
-
+def get_mms_tts_components():
+    if "mms_tts_pipeline" not in MODEL_STORE:
+        tts_pipeline = pipeline(
+            task="text-to-speech",
+            model="kakao-enterprise/vits-ljs",
+        )
+        MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
 
-
-    vits_tokenizer = MODEL_STORE["mms_tts_tokenizer"]
-    return vits_model, vits_tokenizer
+    return MODEL_STORE["mms_tts_pipeline"]
 
 
 def get_sam_components() -> Tuple[SamModel, SamProcessor]:
@@ -262,18 +274,6 @@ def recognize_speech(audio_path: str, model_key: str) -> str:
 
 
 def synthesize_speech(text_value: str, model_key: str):
-    if model_key == "silero":
-        silero_model = get_silero_tts_model()
-
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
-            silero_model.save_wav(
-                text=text_value,
-                speaker="aidar",
-                sample_rate=48000,
-                audio_path=file_object.name,
-            )
-        return file_object.name
-
     if model_key == "Google TTS":
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
             text_to_speech_engine = gTTS(text=text_value, lang="ru")
@@ -281,18 +281,18 @@ def synthesize_speech(text_value: str, model_key: str):
         return file_object.name
 
     if model_key == "vits-ljs":
-
-        tokenized_input = vits_tokenizer(text_value, return_tensors="pt")
+        tts_pipeline = get_mms_tts_components()
 
-
-
+        tts_output = tts_pipeline(text_value)
+
+        audio_array = tts_output["audio"]
+        sampling_rate_value = tts_output["sampling_rate"]
 
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
-            waveform_array = waveform_tensor.numpy().squeeze()
             soundfile_module.write(
                 file_object.name,
-
-
+                audio_array,
+                sampling_rate_value,
             )
         return file_object.name
 
@@ -381,6 +381,26 @@ def answer_visual_question(image_object, question_text: str, model_key: str) ->
     if not question_text.strip():
         return "Пожалуйста, введите вопрос об изображении."
 
+    if model_key == "vqa_blip_base":
+        blip_model, blip_processor = get_blip_vqa_components()
+
+        inputs = blip_processor(
+            images=image_object,
+            text=question_text,
+            return_tensors="pt",
+        )
+
+        with torch.no_grad():
+            output_ids = blip_model.generate(**inputs)
+
+        decoded_answers = blip_processor.batch_decode(
+            output_ids,
+            skip_special_tokens=True,
+        )
+        answer_text = decoded_answers[0] if decoded_answers else ""
+
+        return answer_text or "Модель не смогла сгенерировать ответ."
+
     vqa_pipeline = get_vision_pipeline(model_key)
 
     vqa_result = vqa_pipeline(
@@ -680,9 +700,9 @@ def build_interface():
                 lines=3,
             )
             tts_model_selector = gr.Dropdown(
-                choices=["vits-ljs", "Google TTS"
+                choices=["vits-ljs", "Google TTS"],
                 label="Выберите модель",
-                value="
+                value="vits-ljs",
                 info=(
                     "kakao-enterprise/vits-ljs"
                     "Google TTS"
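For reference, a minimal standalone sketch of the BLIP VQA flow this commit wires into answer_visual_question. It is not part of app.py: the image file name and the question are placeholders, and it assumes transformers, torch, and Pillow are installed.

# Illustrative sketch of the BLIP VQA path added above (not part of app.py).
# Assumes transformers, torch and Pillow are installed; "example.jpg" is a placeholder image.
import torch
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

image = Image.open("example.jpg").convert("RGB")
inputs = blip_processor(images=image, text="What is shown in the picture?", return_tensors="pt")

with torch.no_grad():  # inference only, no gradients needed
    output_ids = blip_model.generate(**inputs)

answers = blip_processor.batch_decode(output_ids, skip_special_tokens=True)
print(answers[0] if answers else "")  # short free-form answer, e.g. "a cat"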