-
-
Save mbaneshi/d81bf3d9312e21ee7b24637fb020e0ec to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import base64 | |
import json | |
import os | |
import pyaudio | |
from websockets.asyncio.client import connect | |
class SimpleGeminiVoice:
    """Minimal full-duplex voice client for the Gemini bidirectional websocket API.

    Three concurrent tasks cooperate: ``capture_audio`` reads the microphone and
    sends base64-encoded PCM chunks, ``stream_audio`` receives server messages
    and queues any audio they carry, and ``play_response`` plays queued audio.
    """

    def __init__(self):
        # Decoded audio chunks waiting to be written to the output device.
        self.audio_queue = asyncio.Queue()
        self.api_key = os.environ.get("GEMINI_API_KEY")
        self.model = "gemini-2.0-flash-exp"
        self.uri = f"wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
        # Audio settings
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.CHUNK = 512  # frames read from the microphone per chunk
        self.RATE = 16000  # microphone sample rate in Hz

    @staticmethod
    def _encode_chunk(data):
        """Return the JSON ``realtime_input`` message carrying *data* as base64 PCM."""
        return json.dumps(
            {
                "realtime_input": {
                    "media_chunks": [
                        {
                            "data": base64.b64encode(data).decode(),
                            "mime_type": "audio/pcm",
                        }
                    ]
                }
            }
        )

    @staticmethod
    def _extract_audio(response):
        """Return the decoded ``inlineData`` payload of every part in *response*.

        Messages without a model turn, and parts without inline data, yield
        nothing instead of raising, so non-audio messages are simply skipped.
        """
        try:
            parts = response["serverContent"]["modelTurn"]["parts"]
        except (KeyError, TypeError):
            return []
        chunks = []
        for part in parts or ():
            try:
                encoded = part["inlineData"]["data"]
            except (KeyError, TypeError):
                continue  # a part that carries no inline data
            chunks.append(base64.b64decode(encoded))
        return chunks

    @staticmethod
    def _turn_complete(response):
        """Return True when *response* carries a truthy ``serverContent.turnComplete``."""
        try:
            return bool(response["serverContent"]["turnComplete"])
        except (KeyError, TypeError):
            return False

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a blocking call in a worker thread and return its result.

        If the awaiting task is cancelled, the cancellation is propagated only
        after the worker thread has returned, so callers can safely close the
        audio stream in a ``finally`` block without racing the thread.
        """
        call = asyncio.ensure_future(asyncio.to_thread(func, *args, **kwargs))
        try:
            return await asyncio.shield(call)
        except asyncio.CancelledError:
            await asyncio.wait([call])
            if not call.cancelled():
                call.exception()  # mark as retrieved; cancellation takes precedence
            raise

    async def start(self):
        """Connect, send the setup message, and run the audio tasks until one fails.

        Raises:
            RuntimeError: if no API key was found in ``GEMINI_API_KEY``.
        """
        if not self.api_key:
            # Fail early instead of connecting with "key=None" in the URL.
            raise RuntimeError("GEMINI_API_KEY environment variable is not set")
        # Initialize websocket; the context manager closes it on any exit path.
        async with connect(
            self.uri, additional_headers={"Content-Type": "application/json"}
        ) as self.ws:
            await self.ws.send(json.dumps({"setup": {"model": f"models/{self.model}"}}))
            # Wait for the server's first reply (presumably the setup
            # acknowledgement); its content is not used.
            await self.ws.recv(decode=False)
            print("Connected to Gemini, You can start talking now")
            # Start audio streaming
            async with asyncio.TaskGroup() as tg:
                tg.create_task(self.capture_audio())
                tg.create_task(self.stream_audio())
                tg.create_task(self.play_response())

    async def capture_audio(self):
        """Read microphone chunks forever and send each one over the websocket."""
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK,
            )
            try:
                while True:
                    # Do not raise on input overflow: the device buffer can
                    # overrun while this task is busy sending.
                    data = await self._run_blocking(
                        stream.read, self.CHUNK, exception_on_overflow=False
                    )
                    await self.ws.send(self._encode_chunk(data))
            finally:
                stream.close()
        finally:
            audio.terminate()

    async def stream_audio(self):
        """Receive server messages, queue their audio, and handle end of turn."""
        async for msg in self.ws:
            response = json.loads(msg)
            for chunk in self._extract_audio(response):
                self.audio_queue.put_nowait(chunk)
            if self._turn_complete(response):
                # If you interrupt the model, it sends an end_of_turn. For interruptions to work, we need to empty out the audio queue
                print("\nEnd of turn")
                while not self.audio_queue.empty():
                    self.audio_queue.get_nowait()

    async def play_response(self):
        """Play queued audio chunks on the output device as they arrive."""
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=self.FORMAT, channels=self.CHANNELS, rate=24000, output=True
            )
            try:
                while True:
                    data = await self.audio_queue.get()
                    await self._run_blocking(stream.write, data)
            finally:
                stream.close()
        finally:
            audio.terminate()
if __name__ == "__main__":
    # Script entry point: build the client and run its coroutine to completion.
    asyncio.run(SimpleGeminiVoice().start())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import base64 | |
import json | |
import os | |
import pyaudio | |
from websockets.asyncio.client import connect | |
class SimpleGeminiVoice:
    """Minimal voice client for the Gemini bidirectional websocket API.

    Two concurrent tasks cooperate: ``send_user_audio`` reads the microphone
    and sends base64-encoded PCM chunks, while ``recv_model_audio`` receives
    server messages and plays any audio they carry as soon as it arrives.
    """

    def __init__(self):
        self.api_key = os.environ.get("GEMINI_API_KEY")
        self.model = "gemini-2.0-flash-exp"
        self.uri = f"wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
        # Audio settings
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.CHUNK = 512  # frames read from the microphone per chunk
        self.RATE = 16000  # microphone sample rate in Hz (must be an int)

    @staticmethod
    def _encode_chunk(data):
        """Return the JSON ``realtime_input`` message carrying *data* as base64 PCM."""
        return json.dumps(
            {
                "realtime_input": {
                    "media_chunks": [
                        {
                            "data": base64.b64encode(data).decode(),
                            "mime_type": "audio/pcm",
                        }
                    ]
                }
            }
        )

    @staticmethod
    def _extract_audio(response):
        """Return the decoded ``inlineData`` payload of every part in *response*.

        Messages without a model turn, and parts without inline data, yield
        nothing instead of raising, so non-audio messages are simply skipped.
        """
        try:
            parts = response["serverContent"]["modelTurn"]["parts"]
        except (KeyError, TypeError):
            return []
        chunks = []
        for part in parts or ():
            try:
                encoded = part["inlineData"]["data"]
            except (KeyError, TypeError):
                continue  # a part that carries no inline data
            chunks.append(base64.b64decode(encoded))
        return chunks

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a blocking call in a worker thread and return its result.

        If the awaiting task is cancelled, the cancellation is propagated only
        after the worker thread has returned, so callers can safely close the
        audio stream in a ``finally`` block without racing the thread.
        """
        call = asyncio.ensure_future(asyncio.to_thread(func, *args, **kwargs))
        try:
            return await asyncio.shield(call)
        except asyncio.CancelledError:
            await asyncio.wait([call])
            if not call.cancelled():
                call.exception()  # mark as retrieved; cancellation takes precedence
            raise

    async def start(self):
        """Connect, send the setup message, and run both audio tasks until one fails.

        Raises:
            RuntimeError: if no API key was found in ``GEMINI_API_KEY``.
        """
        if not self.api_key:
            # Fail early instead of connecting with "key=None" in the URL.
            raise RuntimeError("GEMINI_API_KEY environment variable is not set")
        # Initialize websocket; the context manager closes it on any exit path.
        async with connect(
            self.uri, additional_headers={"Content-Type": "application/json"}
        ) as self.ws:
            await self.ws.send(json.dumps({"setup": {"model": f"models/{self.model}"}}))
            # Wait for the server's first reply (presumably the setup
            # acknowledgement); its content is not used.
            await self.ws.recv(decode=False)
            print("Connected to Gemini, You can start talking now")
            # Start audio streaming
            async with asyncio.TaskGroup() as tg:
                tg.create_task(self.send_user_audio())
                tg.create_task(self.recv_model_audio())

    async def send_user_audio(self):
        """Read microphone chunks forever and send each one over the websocket."""
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,  # was the string "16000"; PyAudio needs an int
                input=True,
                frames_per_buffer=self.CHUNK,
            )
            try:
                while True:
                    # Do not raise on input overflow: the device buffer can
                    # overrun while this task is busy sending.
                    data = await self._run_blocking(
                        stream.read, self.CHUNK, exception_on_overflow=False
                    )
                    await self.ws.send(self._encode_chunk(data))
            finally:
                stream.close()
        finally:
            audio.terminate()

    async def recv_model_audio(self):
        """Receive server messages and play the audio they carry."""
        audio = pyaudio.PyAudio()
        try:
            stream = audio.open(
                format=self.FORMAT, channels=self.CHANNELS, rate=24000, output=True
            )
            try:
                async for msg in self.ws:
                    response = json.loads(msg)
                    for chunk in self._extract_audio(response):
                        await self._run_blocking(stream.write, chunk)
            finally:
                stream.close()
        finally:
            audio.terminate()
if __name__ == "__main__":
    # Script entry point: build the client and run its coroutine to completion.
    asyncio.run(SimpleGeminiVoice().start())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment