Last active
November 6, 2024 14:05
-
-
Save sirk390/82032cc1f92189824dca76b98fb9df16 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import base64
import json
import logging
import os

import numpy as np
import websockets
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Bearer token sent in the Authorization header when connecting. Taken from
# the OPENAI_API_KEY environment variable so the secret does not have to be
# committed to source; falls back to the original "..." placeholder.
API_KEY = os.environ.get("OPENAI_API_KEY", "...")

# Realtime endpoint; the model is selected through the query string.
WEBSOCKET_URL = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
class RealtimeChat:
    """Full-duplex voice chat client over the realtime WebSocket endpoint.

    While the socket is open four tasks cooperate through two queues:

    * ``record_audio`` -> ``outqueue`` -> ``send_audio`` -> server
    * server -> ``receive_messages`` -> ``playqueue`` -> ``play_audio``

    ``run()`` is the entry point: it raises the ``is_running`` flag that the
    worker loops poll, runs ``main()``, and cleans up afterwards.

    ``AsyncAudio`` is not defined in the visible part of this file; it is used
    here as an object with awaitable ``record()`` and ``play(data)`` methods.
    """

    def __init__(self):
        self.ws = None  # set by main() once the connection is open
        self.audio = AsyncAudio()
        self.outqueue = asyncio.Queue()   # recorded chunks waiting to be sent
        self.playqueue = asyncio.Queue()  # decoded chunks waiting to be played
        self.cancel_play = asyncio.Event()  # set to interrupt current playback
        # Defined here so the worker loops never hit AttributeError when
        # main() is used without going through run().
        self.is_running = False

    async def main(self):
        """Open the WebSocket and run the four worker tasks until they finish.

        Raises whatever the first failing worker raises. In every case the
        remaining workers are cancelled before the connection is closed, so
        no task outlives the socket it depends on.
        """
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "OpenAI-Beta": "realtime=v1",
        }
        # NOTE(review): newer websockets releases name this keyword
        # `additional_headers` -- confirm against the installed version.
        async with websockets.connect(WEBSOCKET_URL, extra_headers=headers) as websocket:
            self.ws = websocket
            logger.info("WebSocket connection opened")
            workers = [
                asyncio.create_task(self.record_audio()),
                asyncio.create_task(self.send_audio()),
                asyncio.create_task(self.receive_messages()),
                asyncio.create_task(self.play_audio()),
            ]
            try:
                await asyncio.gather(*workers)
            finally:
                # gather() does not cancel the siblings of a failed task;
                # without this they would keep running against a closed socket.
                for worker in workers:
                    worker.cancel()
                await asyncio.gather(*workers, return_exceptions=True)

    def _drain_playqueue(self):
        """Discard every chunk currently waiting in ``playqueue``."""
        while not self.playqueue.empty():
            self.playqueue.get_nowait()

    async def play_audio(self):
        """Play queued chunks one at a time, stopping early on ``cancel_play``.

        When ``cancel_play`` is set during playback, the current chunk is
        cancelled and everything still queued is thrown away. A chunk whose
        playback raises is logged and skipped.
        """
        while self.is_running:
            chunk = await self.playqueue.get()
            # A cancel request that arrived while idle does not apply to the
            # chunk that is about to start.
            self.cancel_play.clear()
            play_task = asyncio.create_task(self.audio.play(chunk))
            cancel_task = asyncio.create_task(self.cancel_play.wait())
            try:
                done, _ = await asyncio.wait(
                    [cancel_task, play_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
            except asyncio.CancelledError:
                # asyncio.wait() leaves its tasks running when the waiter is
                # cancelled; stop them so they do not leak.
                play_task.cancel()
                cancel_task.cancel()
                raise

            if cancel_task in done:
                play_task.cancel()
                # Let the cancellation land before the next chunk can start.
                await asyncio.wait([play_task])
                self._drain_playqueue()
            else:
                cancel_task.cancel()

            # Retrieve a playback failure so it is reported here instead of
            # surfacing later as "Task exception was never retrieved".
            if play_task.done() and not play_task.cancelled():
                error = play_task.exception()
                if error is not None:
                    logger.error("Audio playback failed", exc_info=error)

    async def receive_messages(self):
        """Dispatch server events until ``is_running`` is cleared.

        * ``response.audio.delta``: base64-decode the payload and queue it
          for playback.
        * ``response.audio_transcript.delta``: print the text fragment.
        * ``input_audio_buffer.speech_stopped``: interrupt playback.

        Every other event type is ignored.
        """
        while self.is_running:
            message = await self.ws.recv()
            event = json.loads(message)
            event_type = event.get('type')
            if event_type == 'response.audio.delta':
                self.playqueue.put_nowait(base64.b64decode(event.get('delta', '')))
            elif event_type == 'response.audio_transcript.delta':
                print(event.get('delta', ''), end="", flush=True)
            elif event_type == 'input_audio_buffer.speech_stopped':
                # NOTE(review): interrupting on speech_stopped rather than
                # speech_started is kept as written -- confirm it is intended.
                self.cancel_play.set()

    async def send_audio(self):
        """Forward recorded chunks from ``outqueue`` to the server.

        Each chunk is base64-encoded and sent as an
        ``input_audio_buffer.append`` event.
        """
        while self.is_running:
            chunk = await self.outqueue.get()
            payload = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode('utf-8'),
            }
            await self.ws.send(json.dumps(payload))

    async def record_audio(self):
        """Keep recording chunks and queue them in ``outqueue`` for sending."""
        while self.is_running:
            chunk = await self.audio.record()
            self.outqueue.put_nowait(chunk)

    async def run(self):
        """Run the chat session until it ends or the task is cancelled.

        Cancellation is swallowed so it counts as a normal shutdown; any other
        exception from ``main()`` propagates after cleanup.
        """
        self.is_running = True
        try:
            await self.main()
        except asyncio.CancelledError:
            pass
        finally:
            self.is_running = False
            # main()'s `async with` normally closes the socket already; this
            # is a safety net. getattr() keeps a connection object without an
            # `open` attribute from raising here and masking the real error.
            ws = self.ws
            if ws is not None and getattr(ws, "open", False):
                await ws.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment