Spaces:
Build error
Build error
| import os | |
| import queue | |
| import asyncio | |
| import concurrent.futures | |
| import functools | |
| import io | |
| import sys | |
| import random | |
| from threading import Thread | |
| import time | |
| from dotenv import load_dotenv | |
| import pyaudio | |
| import speech_recognition as sr | |
| import websockets | |
| from aioconsole import ainput # for async input | |
| from pydub import AudioSegment | |
| from simpleaudio import WaveObject | |
load_dotenv()  # read settings such as AUTH_API_KEY (used by start_client) from a .env file

# Worker pool for blocking speech_recognition calls made from handle_audio.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

# True until receive_message has skipped the first binary frame from websocket2.
web2_initial_message = True

# Audio capture parameters (RATE is passed to sr.Microphone).
CHUNK, CHANNELS, RATE = 1024, 1, 44100
FORMAT = pyaudio.paInt16
class AudioPlayer:
    """Play queued WAV clips one at a time on a background thread.

    ``start_playing`` enqueues a clip and (re)starts the worker thread when
    needed. ``stop_playing`` interrupts the current clip, makes the worker
    discard anything still queued, and waits for the worker to exit.
    """

    # Seconds between stop-flag checks while idle or while a clip is playing.
    POLL_INTERVAL = 0.1

    def __init__(self):
        self.play_thread = None  # worker Thread running play_audio, or None
        self.stop_flag = False  # set by stop_playing to make the worker exit
        self.queue = queue.Queue()  # pending WAV file-like objects, FIFO

    def play_audio(self):
        """Worker loop: play queued clips in order until ``stop_flag`` is set."""
        while not self.stop_flag:
            try:
                # Block briefly instead of busy-polling with get_nowait(),
                # which would spin a CPU core whenever the queue is empty.
                wav_data = self.queue.get(timeout=self.POLL_INTERVAL)
            except queue.Empty:
                continue
            if self.stop_flag:
                break  # stop requested while waiting; don't start this clip
            play_obj = WaveObject.from_wave_file(wav_data).play()
            while play_obj.is_playing() and not self.stop_flag:
                time.sleep(self.POLL_INTERVAL)
            if self.stop_flag:
                play_obj.stop()
        # Drop clips that were still queued when the stop was requested so
        # they are not each started and immediately cut off.
        self._discard_pending()

    def _discard_pending(self):
        """Empty the queue without playing anything."""
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                return

    def start_playing(self, wav_data):
        """Queue *wav_data* and make sure a worker thread is running to play it."""
        self.stop_flag = False
        self.queue.put(wav_data)
        if self.play_thread is None or not self.play_thread.is_alive():
            self.play_thread = Thread(target=self.play_audio)
            self.play_thread.start()

    def stop_playing(self):
        """Stop the current clip and block until the worker thread has exited."""
        if self.play_thread and self.play_thread.is_alive():
            self.stop_flag = True
            self.play_thread.join()
            self.play_thread = None

    def add_to_queue(self, wav_data):
        """Queue *wav_data* without starting a worker thread."""
        self.queue.put(wav_data)
# Module-wide player: receive_message queues decoded reply audio on it and
# stops it when a '[+]' message arrives.
audio_player = AudioPlayer()
def get_input_device_id():
    """List the audio devices that have input channels and ask the user to pick one.

    Returns:
        The device index typed by the user, as an int. Re-prompts until the
        answer parses as an integer. The value is not checked against the
        listed ids.
    """
    p = pyaudio.PyAudio()
    try:
        devices = []
        for index in range(p.get_device_count()):
            info = p.get_device_info_by_index(index)  # query each device once
            if info.get('maxInputChannels'):
                devices.append((index, info['name']))
    finally:
        # Only the index is returned, so release the PortAudio handle here
        # instead of leaking it.
        p.terminate()

    print('Available devices:')
    for index, name in devices:
        print(f"Device id {index} - {name}")
    while True:
        answer = input('Please select device id: ')
        try:
            return int(answer)
        except ValueError:
            print(f"{answer!r} is not a number, please try again.")
async def handle_audio(websocket, device_id):
    """Capture speech from a microphone and stream each phrase over *websocket*.

    Opens input device *device_id* at RATE Hz and calibrates the recognizer
    against 2 s of ambient noise. It then loops forever:

    1. Wait for the next phrase (at most 30 s).
    2. Send the phrase's raw frame bytes.
    3. Pause 2 s before listening again.

    Blocking speech_recognition calls run on the module-level ``executor`` so
    other tasks on the event loop keep running.

    Args:
        websocket: open connection whose ``send`` accepts bytes.
        device_id: PyAudio input device index, e.g. from get_input_device_id().
    """
    loop = asyncio.get_running_loop()
    with sr.Microphone(device_index=device_id, sample_rate=RATE) as source:
        recognizer = sr.Recognizer()
        print('Source sample rate: ', source.SAMPLE_RATE)
        print('Source width: ', source.SAMPLE_WIDTH)
        print('Adjusting for ambient noise...Wait for 2 seconds')
        # Recognizer tuning. adjust_for_ambient_noise() recalibrates
        # energy_threshold, so 5000 is presumably only its starting point.
        recognizer.energy_threshold = 5000
        recognizer.dynamic_energy_ratio = 6
        recognizer.dynamic_energy_adjustment_damping = 0.85
        recognizer.non_speaking_duration = 0.5
        recognizer.pause_threshold = 0.8
        recognizer.phrase_threshold = 0.5
        # Calibration blocks for the full 2 s; run it off the event loop
        # rather than stalling the receive task meanwhile.
        await loop.run_in_executor(
            executor,
            functools.partial(recognizer.adjust_for_ambient_noise, source, duration=2),
        )
        listen_func = functools.partial(recognizer.listen, source, phrase_time_limit=30)
        print('Okay, start talking!')
        while True:
            # flush=True: the markers have no newline and would otherwise sit
            # in the stdout buffer instead of showing the listening state.
            print('[*]', end="", flush=True)  # listening
            audio = await loop.run_in_executor(executor, listen_func)
            await websocket.send(audio.frame_data)
            print('[-]', end="", flush=True)  # phrase captured and sent
            await asyncio.sleep(2)
async def handle_text(websocket):
    """Forward every line typed on the console to *websocket*, forever.

    Args:
        websocket: open connection whose ``send`` accepts str.
    """
    # The prompt has no trailing newline, so flush it explicitly or it may
    # not appear before the user starts typing.
    print('You: ', end="", flush=True)
    while True:
        message = await ainput()
        await websocket.send(message)
# NOTE(review): not referenced by any code visible in this file (receive_message
# uses web2_initial_message instead); looks vestigial — confirm before removing.
initial_message = True
def _strip_speaker_prefix(text):
    """Return the part of *text* after the first '> ', or *text* unchanged if there is none."""
    _, sep, rest = text.partition('> ')
    return rest if sep else text


def _play_mp3(mp3_bytes):
    """Decode an MP3 payload to in-memory WAV and queue it on ``audio_player``."""
    wav_data = io.BytesIO()
    AudioSegment.from_mp3(io.BytesIO(mp3_bytes)).export(wav_data, format="wav")
    audio_player.start_playing(wav_data)


async def _relay_to_websocket2(websocket, websocket2, text):
    """Send *text* to websocket2, collect its reply, and forward the reply to websocket.

    Handles websocket2's control messages while waiting:

    - '[+]...' stops audio playback.
    - '[=]...' is ignored.
    - Binary frames are played as MP3, except the very first one, which is skipped.
    - Other text is accumulated until '[end]\\n'.

    Returns:
        True once a reply has been forwarded; False if receiving from
        websocket2 failed (the relay cannot continue).
    """
    global web2_initial_message
    await websocket2.send(text)
    web2_message = ''
    while True:
        try:
            message = await websocket2.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            return False
        except Exception as e:
            print("An error occurred: ", e)
            return False
        if isinstance(message, str):
            if message == '[end]\n':
                if not web2_message:
                    continue  # empty turn; keep waiting for real content
                print(web2_message)
                await websocket.send(_strip_speaker_prefix(web2_message))
                return True
            elif message.startswith('[+]'):
                audio_player.stop_playing()
            elif message.startswith('[=]'):
                pass
            else:
                web2_message += message
        elif isinstance(message, bytes):
            if web2_initial_message:
                web2_initial_message = False
                continue
            _play_mp3(message)


async def receive_message(websocket, websocket2):
    """Relay a conversation between two websocket sessions, echoing session 1 to the console.

    After reading both sessions' initial messages, this seeds session 1 with a
    fixed opening prompt. It then streams session 1's output to stdout. Each
    time session 1 finishes a turn ('[end]\\n'), the accumulated text (minus
    any 'name> ' prefix) is sent to session 2, and session 2's reply is fed
    back to session 1. Binary frames from session 1 are played as MP3 audio.

    Returns when a connection fails or an unexpected message type arrives.
    """
    web1_init_message = await websocket.recv()
    print('web1_init_message: ', web1_init_message)
    web2_init_message = await websocket2.recv()
    print('web2_init_message: ', web2_init_message)
    message_to_websocket1 = "Suppose I'm Steve Jobs now. What question do you have for me?"
    await websocket.send(message_to_websocket1)
    web1_message = ''
    while True:
        try:
            message = await websocket.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            break
        except Exception as e:
            print("An error occurred: ", e)
            break
        if isinstance(message, str):
            if message == '[end]\n':
                if not web1_message:
                    continue
                turn_text = _strip_speaker_prefix(web1_message)
                # Start the next turn from scratch; otherwise every later
                # forward would resend all earlier turns as well.
                web1_message = ''
                if not await _relay_to_websocket2(websocket, websocket2, turn_text):
                    break  # websocket2 is gone; further relaying would fail on send
            elif message.startswith('[+]'):
                # stop playing audio; the transcription is done
                audio_player.stop_playing()
                print(f"\n{message}", end="\n", flush=False)
            elif message.startswith('[=]'):
                # the response is done
                print(f"{message}", end="\n", flush=False)
            else:
                web1_message += message
                # Streamed tokens carry no newline; flush so they appear live.
                print(f"{message}", end="", flush=True)
        elif isinstance(message, bytes):
            _play_mp3(message)
        else:
            print("Unexpected message")
            break
def select_model(read_input=input):
    """Ask the user to choose an LLM model, re-prompting until the choice is valid.

    Args:
        read_input: callable that shows a prompt and returns the user's answer;
            defaults to the built-in ``input``. Injectable for tests or
            non-interactive use.

    Returns:
        The selected model name: 'gpt-3.5-turbo-16k', 'gpt-4' or 'claude-2'.
    """
    models = {
        '1': 'gpt-3.5-turbo-16k',
        '2': 'gpt-4',
        '3': 'claude-2',
    }
    while True:
        llm_model_selection = read_input(
            '1: gpt-3.5-turbo-16k \n'
            '2: gpt-4 \n'
            '3: claude-2 \n'
            'Select llm model:').strip()
        if llm_model_selection in models:
            return models[llm_model_selection]
        # An unrecognised choice used to fall through with llm_model unbound;
        # ask again instead of crashing.
        print(f"Unknown selection {llm_model_selection!r}; please enter 1, 2 or 3.")
async def start_client(client_id, url):
    """Open two server sessions and run the conversation relay until either side ends.

    Session 1 (path id *client_id*) uses the character the user picks.
    Session 2 (fixed id 9999999) is always sent character '6'. Both sessions
    use the AUTH_API_KEY environment variable and the chosen LLM model. The
    user then picks audio or text input for session 1, and receive_message()
    relays replies between the two sessions.

    Args:
        client_id: numeric id placed in the first websocket path.
        url: server host[:port], without the ws:// scheme.
    """
    from urllib.parse import urlencode

    api_key = os.getenv('AUTH_API_KEY')
    llm_model = select_model()
    # urlencode escapes characters such as '&', '+' or '=' in the key that
    # would otherwise corrupt the query string.
    query = urlencode({'api_key': api_key, 'llm_model': llm_model})
    uri = f"ws://{url}/ws/{client_id}?{query}"
    uri2 = f"ws://{url}/ws/9999999?{query}"
    async with websockets.connect(uri) as websocket, websockets.connect(uri2) as websocket2:
        # send client platform info on both connections
        await websocket.send('terminal')
        await websocket2.send('terminal')
        print(f"Client #{client_id} connected to websocket1")
        print(f"Client 9999999 connected to websocket2")
        welcome_message = await websocket.recv()
        await websocket2.recv()  # websocket2's welcome text is consumed but not shown
        print(f"{welcome_message}")
        character = input('Select character: ')
        await websocket.send(character)
        await websocket2.send('6')  # session 2 always uses character option '6'
        mode = input('Select mode (1: audio, 2: text): ')
        if mode.strip() == '1':
            device_id = get_input_device_id()
            send_task = asyncio.create_task(handle_audio(websocket, device_id))
        else:
            send_task = asyncio.create_task(handle_text(websocket))
        receive_task = asyncio.create_task(receive_message(websocket, websocket2))
        # handle_audio/handle_text loop forever, so gather() would hang once
        # receive_message returns; finish as soon as either task ends instead.
        done, pending = await asyncio.wait(
            {send_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        for task in done:
            task.result()  # re-raise any exception from the task that finished
async def main(url):
    """Run one client session under a random client id until it ends or is interrupted.

    Args:
        url: server host[:port], passed through to start_client().
    """
    client_id = random.randint(0, 1000000)
    task = asyncio.create_task(start_client(client_id, url))
    try:
        await task
    except (KeyboardInterrupt, asyncio.CancelledError):
        # Under asyncio.run() (3.11+), Ctrl+C cancels this coroutine instead of
        # raising KeyboardInterrupt here, so treat both as "stop requested".
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass  # expected: the client task was cancelled just above
        print("Client stopped by user")
if __name__ == "__main__":
    # Server address from the first CLI argument, defaulting to localhost:8000.
    url = sys.argv[1] if len(sys.argv) > 1 else 'localhost:8000'
    try:
        asyncio.run(main(url))
    except KeyboardInterrupt:
        # Ctrl+C can escape main() entirely, e.g. while a synchronous input()
        # prompt is blocking the event loop; exit without a traceback.
        print("Client stopped by user")