Skip to content

Instantly share code, notes, and snippets.

@kwindla
Created December 15, 2024 03:32
Show Gist options
  • Select an option

  • Save kwindla/917ef95ae16204dfc9bb35fbbf612086 to your computer and use it in GitHub Desktop.

Select an option

Save kwindla/917ef95ae16204dfc9bb35fbbf612086 to your computer and use it in GitHub Desktop.
Gemini Multimodal Live compositional function calling test
# usage: gemini-lights-repro.py --initial-message "Make a pretty sequence with the lights. Turn them on and off a few times. Pick nice colors."
import argparse
import asyncio
import base64
import json
import os
import pyaudio
import re
import shutil
import websockets
# --- Gemini Multimodal Live API configuration ---
MODEL = "models/gemini-2.0-flash-exp"
HOST = "generativelanguage.googleapis.com"
# System prompt advertising the two light-control functions to the model.
SYSTEM_INSTRUCTION = {
    "parts": [
        {
            "text": "You have access to two functions: turn_on_the_lights to turn the lights on and turn_off_the_lights to turn the lights off. Use these functions if you need to manipulate the lights. If the user asks to turn on the lights without specifying a color, use a default color that is a warm white."
        }
    ]
}
# Audio: 16 kHz mono mic input, 24 kHz speaker output, 16-bit little-endian PCM.
MIC_SAMPLE_RATE = 16000
SPEAKER_SAMPLE_RATE = 24000
FORMAT = "S16_LE"  # informational only; streams are opened with pyaudio.paInt16
CHANNELS = 1
# Runtime configuration; overwritten by parse_args() from the command line.
INITIAL_MESSAGE = ""
SEARCH = False
CODE_EXECUTION = False
def parse_args(argv=None):
    """Parse command-line flags into the module-level configuration globals.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (argparse's own default), so existing callers are unaffected;
            passing a list allows programmatic use and testing.

    Returns:
        The parsed ``argparse.Namespace``. The values are also mirrored into
        the module globals INITIAL_MESSAGE, SEARCH and CODE_EXECUTION.
    """
    parser = argparse.ArgumentParser(description="Gemini Talk with optional search functionality")
    # Explicit defaults: without them, an omitted BooleanOptionalAction flag or
    # an omitted --initial-message leaves the attribute as None, which would be
    # stored into the globals (they are documented/initialized as bool and str).
    parser.add_argument("--use-search", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--use-code-execution", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--initial-message", type=str, default="")
    args = parser.parse_args(argv)
    global INITIAL_MESSAGE, SEARCH, CODE_EXECUTION
    INITIAL_MESSAGE = args.initial_message
    SEARCH = args.use_search
    CODE_EXECUTION = args.use_code_execution
    return args
class AudioStreamer:
    """Streams microphone audio to the Gemini Multimodal Live WebSocket API,
    plays any returned PCM audio, logs server events, and acknowledges tool
    (function) calls with a canned success response.

    NOTE(review): this is a repro script — the light-control functions are
    declared to the model but nothing is actually driven; every tool call is
    answered with {"status": "success"}.
    """

    def __init__(self):
        self.p = pyaudio.PyAudio()
        # Raw S16_LE PCM bytes queued for the speaker output callback to drain.
        self.speaker_audio_buffer = bytearray()
        self.running = True
        # Set in run(); used to schedule websocket sends from the PyAudio thread.
        self.event_loop = None
        # Only consumed by the deferred tool-call path (the append is
        # commented out in ws_receive_worker), so this stays empty as written.
        self.pending_tool_calls = []

    def mic_audio_in_callback(self, in_data, frame_count, time_info, status):
        # PyAudio input callback (runs on PyAudio's internal thread): forward
        # each mic chunk to the websocket as a base64 "realtimeInput" message.
        async def send_audio(self, raw_audio):
            payload = base64.b64encode(raw_audio).decode("utf-8")
            try:
                msg = json.dumps(
                    {
                        "realtimeInput": {
                            "mediaChunks": [
                                {
                                    "mimeType": f"audio/pcm;rate={MIC_SAMPLE_RATE}",
                                    "data": payload,
                                }
                            ],
                        },
                    }
                )
                await self.ws.send(msg)
            except Exception as e:
                print(f"Exception: {e}")
                self.running = False
        if not self.running:
            return (None, pyaudio.paComplete)
        # NOTE(review): create_task is called from the PyAudio thread, not the
        # loop's own thread; asyncio.run_coroutine_threadsafe would be the
        # thread-safe call — confirm this works reliably in practice.
        self.event_loop.create_task(send_audio(self, in_data))
        return (None, pyaudio.paContinue)

    def speaker_audio_out_callback(self, in_data, frame_count, time_info, status):
        # PyAudio output callback: drain up to frame_count frames of buffered
        # PCM and pad the remainder with silence. 2 bytes per S16_LE sample.
        if not self.running:
            return (bytes(frame_count * CHANNELS * 2), pyaudio.paComplete)
        audio = bytes(self.speaker_audio_buffer[: frame_count * CHANNELS * 2])
        del self.speaker_audio_buffer[: frame_count * CHANNELS * 2]
        audio += b"\0" * (frame_count * CHANNELS * 2 - len(audio))
        return (audio, pyaudio.paContinue)

    async def handle_tool_call(self, tool_call):
        """Acknowledge every function call in `tool_call` with a success response.

        No real device is driven; each call is echoed back to the server as a
        toolResponse with a canned {"status": "success"} payload.
        """
        # print(f" <- handling tool call {tool_call}")
        responses = []
        for f in tool_call.get("functionCalls", []):
            print(f" <- Function call: {f}")
            responses.append(
                {
                    "id": f.get("id"),
                    "name": f.get("name"),
                    "response": {"status": "success"},
                }
            )
        msg = json.dumps(
            {
                "toolResponse": {
                    "functionResponses": responses,
                }
            }
        )
        print(f" -> {msg}")
        await self.ws.send(msg)

    async def print_audio_output_buffer_info(self):
        # Debug task: every 2 s, report how many seconds of audio are queued
        # for playback (buffer length / bytes-per-second at 2 bytes/sample).
        while self.running:
            if self.speaker_audio_buffer:
                print(
                    f"Current audio buffer size: {len(self.speaker_audio_buffer) / (SPEAKER_SAMPLE_RATE * 2):.2f} seconds"
                )
            await asyncio.sleep(2)

    def print_evt(self, evt, response):
        """Log a truncated one-line summary of a server event, plus grounding
        titles and any non-text/non-audio content parts."""
        columns, rows = shutil.get_terminal_size()
        maxl = columns - 5  # leave room for the " ..." suffix
        print(str(evt)[:maxl] + " ...")
        if grounding := evt.get("serverContent", {}).get("groundingMetadata"):
            for chunk in grounding.get("groundingChunks", []):
                # NOTE(review): reusing the outer quote style inside an
                # f-string requires Python 3.12+ (PEP 701) — confirm the
                # intended minimum Python version.
                print(f" <- {chunk.get("web").get("title")}")
        if parts := evt.get("serverContent", {}).get("modelTurn", {}).get("parts"):
            for part in parts:
                if part.get("inlineData") or part.get("text"):
                    continue  # text and audio are handled by ws_receive_worker
                print(f" <- {part}")

    async def ws_receive_worker(self):
        """Receive loop: dispatch setupComplete, serverContent and toolCall
        events from the websocket until stopped or the connection errors."""
        try:
            async for m in self.ws:
                if not self.running:
                    break
                evt = json.loads(m)
                self.print_evt(evt, m)
                if evt.get("setupComplete", None) is not None:
                    # Setup acknowledged: send the optional initial user turn
                    # and only then start capturing microphone audio.
                    await self.send_initial_message()
                    print("Ready: say something to Gemini")
                    self.mic_audio_in.start_stream()
                elif sc := evt.get("serverContent"):
                    if sc.get("interrupted"):
                        # Barge-in: drop queued playback and flush any tool
                        # calls deferred into pending_tool_calls.
                        print("Interrupted by server")
                        self.speaker_audio_buffer.clear()
                        for tool_call in self.pending_tool_calls:
                            await self.handle_tool_call(tool_call)
                        self.pending_tool_calls.clear()
                        continue
                    if parts := sc.get("modelTurn", {}).get("parts"):
                        if text := parts[0].get("text"):
                            print(f" <- {text}")
                        elif inline_data := parts[0].get("inlineData"):
                            # Inline audio: "<mime>;rate=<n>" — only 24 kHz
                            # audio/pcm is queued for playback.
                            mime_str = inline_data.get("mimeType")
                            mime_type, sample_rate = re.match(
                                r"([\w/]+);rate=(\d+)", mime_str
                            ).groups()
                            if mime_type == "audio/pcm" and sample_rate == str(SPEAKER_SAMPLE_RATE):
                                audio = base64.b64decode(inline_data.get("data"))
                                self.speaker_audio_buffer.extend(audio)
                            else:
                                print(f"Unsupported mime type or sample rate: {mime_str}")
                        # Executable code parts (code execution tool) are
                        # intentionally ignored in this repro.
                        if code := parts[0].get("executableCode"):
                            pass
                elif tool_call := evt.get("toolCall"):
                    # Immediate handling; the deferred path (queue until an
                    # interruption) is left commented out for the repro.
                    # self.pending_tool_calls.append(tool_call)
                    await self.handle_tool_call(tool_call)
        except Exception as e:
            print(f"Exception: {e}")
            self.running = False

    async def setup_model(self):
        """Send the BidiGenerateContent setup message: model, generation
        config, system instruction, and the function/tool declarations."""
        try:
            setup = {
                "setup": {
                    "model": MODEL,
                    "generation_config": {
                        # "response_modalities": ["AUDIO"],
                        "response_modalities": ["TEXT"],
                        "speech_config": {
                            "voice_config": {"prebuilt_voice_config": {"voice_name": "Charon"}},
                        },
                    },
                    "system_instruction": SYSTEM_INSTRUCTION,
                    "tools": [
                        {
                            "function_declarations": [
                                {
                                    "name": "turn_on_the_lights",
                                    "description": "Turns the lights on. Color argument is required, but default to a warm white color unless specified.",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "color": {
                                                "type": "string",
                                                "description": "Color in hex format",
                                            },
                                        },
                                        "required": ["color"],
                                    },
                                },
                                {
                                    "name": "turn_off_the_lights",
                                    "description": "Turns the lights off. Takes no argument.",
                                    "parameters": None,
                                },
                                # NOTE(review): turn_off_the_lights is declared
                                # twice — presumably deliberate for this repro;
                                # confirm before reusing this setup elsewhere.
                                {"name": "turn_off_the_lights"},
                            ]
                        }
                    ],
                },
            }
            if SEARCH:
                print("Search enabled")
                setup["setup"]["tools"].append({"google_search": {}})
            if CODE_EXECUTION:
                print("Code execution enabled")
                setup["setup"]["tools"].append({"code_execution": {}})
            print("Sending setup", setup)
            await self.ws.send(json.dumps(setup))
        except Exception as e:
            print(f"Exception: {e}")

    async def send_initial_message(self):
        """If an --initial-message was given, send it as a completed user turn."""
        try:
            if INITIAL_MESSAGE:
                initial_message = {
                    "client_content": {
                        "turns": [
                            {"parts": [{"text": INITIAL_MESSAGE}], "role": "user"},
                        ],
                        "turn_complete": True,
                    }
                }
                print("Sending initial message", initial_message)
                await self.ws.send(json.dumps(initial_message))
        except Exception as e:
            print(f"Exception: {e}")

    async def run(self):
        """Open the audio streams, connect the websocket, start the worker
        tasks, and idle until self.running is cleared or the task is cancelled."""
        self.event_loop = asyncio.get_event_loop()
        # Mic stream is created stopped; it starts once setupComplete arrives.
        self.mic_audio_in = self.p.open(
            format=pyaudio.paInt16,
            channels=CHANNELS,
            rate=MIC_SAMPLE_RATE,
            input=True,
            stream_callback=self.mic_audio_in_callback,
            frames_per_buffer=int(MIC_SAMPLE_RATE / 1000) * 2 * 50,  # 50ms (S16_LE is 2 bytes)
            # NOTE(review): frames_per_buffer counts frames, not bytes; the
            # *2 above makes this 1600 frames (~100 ms) — confirm intent.
            start=False,
        )
        # Speaker stream starts immediately and pulls from speaker_audio_buffer.
        self.speaker_audio_out = self.p.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=SPEAKER_SAMPLE_RATE,
            output=True,
            frames_per_buffer=256,
            stream_callback=self.speaker_audio_out_callback,
        )
        try:
            # API key is read from the GEMINI_API_KEY environment variable.
            self.ws = await websockets.connect(
                uri=f'wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={os.getenv("GEMINI_API_KEY")}'
            )
            print("Connected to Gemini")
        except Exception as e:
            print(f"Exception: {e}")
            return
        asyncio.create_task(self.ws_receive_worker())
        asyncio.create_task(self.print_audio_output_buffer_info())
        try:
            await self.setup_model()
            while self.running:
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            print(f"Exception: {e}")
        finally:
            print("Exiting...")
            self.running = False
            self.p.terminate()
            await self.ws.close()
            # await self.cleanup()
if __name__ == "__main__":
    # Parse CLI flags into the module globals, then run the audio/websocket
    # session until it exits or is interrupted.
    parse_args()
    asyncio.run(AudioStreamer().run())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment