Created
December 15, 2024 03:32
-
-
Save kwindla/917ef95ae16204dfc9bb35fbbf612086 to your computer and use it in GitHub Desktop.
Gemini Multimodal Live compositional function calling test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # usage: gemini-lights-repro.py --initial-message "Make a pretty sequence with the lights. Turn them on and off a few times. Pick nice colors." | |
| import argparse | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import pyaudio | |
| import re | |
| import shutil | |
| import websockets | |
# Gemini Multimodal Live model and API host.
MODEL = "models/gemini-2.0-flash-exp"
HOST = "generativelanguage.googleapis.com"

# System prompt describing the two light-control functions the model may call.
SYSTEM_INSTRUCTION = {
    "parts": [
        {
            "text": "You have access to two functions: turn_on_the_lights to turn the lights on and turn_off_the_lights to turn the lights off. Use these functions if you need to manipulate the lights. If the user asks to turn on the lights without specifying a color, use a default color that is a warm white."
        }
    ]
}

# Audio format: 16 kHz mono S16_LE microphone capture, 24 kHz playback.
MIC_SAMPLE_RATE = 16000
SPEAKER_SAMPLE_RATE = 24000
FORMAT = "S16_LE"
CHANNELS = 1

# Mutable configuration globals populated by parse_args().
INITIAL_MESSAGE = ""
SEARCH = False
CODE_EXECUTION = False
def parse_args():
    """Parse command-line flags into the module-level configuration globals.

    Recognized flags:
      --use-search / --no-use-search              enable the google_search tool
      --use-code-execution / --no-use-code-execution
                                                  enable the code_execution tool
      --initial-message TEXT                      text turn sent right after setup

    Returns:
        The parsed ``argparse.Namespace`` (the values are also stored in the
        ``INITIAL_MESSAGE``/``SEARCH``/``CODE_EXECUTION`` globals).
    """
    parser = argparse.ArgumentParser(description="Gemini Talk with optional search functionality")
    # Explicit defaults: without them argparse stores None for unset flags,
    # silently overriding the ""/False defaults declared at module level.
    parser.add_argument("--use-search", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--use-code-execution", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--initial-message", type=str, default="")
    args = parser.parse_args()
    global INITIAL_MESSAGE, SEARCH, CODE_EXECUTION
    INITIAL_MESSAGE = args.initial_message
    SEARCH = args.use_search
    CODE_EXECUTION = args.use_code_execution
    return args
class AudioStreamer:
    """Streams microphone audio to the Gemini Multimodal Live WebSocket API,
    plays back returned audio, and answers tool (function) calls.

    Lifecycle: ``run()`` opens the PyAudio streams and the WebSocket, sends
    the setup message, then idles while background tasks receive events and
    the mic callback pushes audio upstream. ``self.running`` is the global
    stop flag flipped by any worker that hits an error.
    """

    def __init__(self):
        self.p = pyaudio.PyAudio()
        # Raw S16_LE PCM waiting to be drained by the speaker callback.
        self.speaker_audio_buffer = bytearray()
        # Global stop flag; checked by both audio callbacks and async workers.
        self.running = True
        # Captured in run(); used to schedule coroutines from the mic callback.
        self.event_loop = None
        # Tool calls deferred until an interruption; only populated if the
        # commented-out append in ws_receive_worker is re-enabled.
        self.pending_tool_calls = []

    def mic_audio_in_callback(self, in_data, frame_count, time_info, status):
        # PyAudio invokes this on its own (non-asyncio) thread per mic chunk.
        async def send_audio(self, raw_audio):
            # Base64-encode the PCM chunk and forward it as a realtimeInput
            # media chunk; any send failure stops the whole streamer.
            payload = base64.b64encode(raw_audio).decode("utf-8")
            try:
                msg = json.dumps(
                    {
                        "realtimeInput": {
                            "mediaChunks": [
                                {
                                    "mimeType": f"audio/pcm;rate={MIC_SAMPLE_RATE}",
                                    "data": payload,
                                }
                            ],
                        },
                    }
                )
                await self.ws.send(msg)
            except Exception as e:
                print(f"Exception: {e}")
                self.running = False
        if not self.running:
            return (None, pyaudio.paComplete)
        # NOTE(review): create_task is called from the PyAudio callback thread,
        # but asyncio loops are not thread-safe for task creation —
        # run_coroutine_threadsafe would be the safe form. Confirm.
        self.event_loop.create_task(send_audio(self, in_data))
        return (None, pyaudio.paContinue)

    def speaker_audio_out_callback(self, in_data, frame_count, time_info, status):
        # PyAudio output callback: drain the playback buffer, padding with
        # silence when it runs short (2 bytes per S16_LE mono sample).
        if not self.running:
            return (bytes(frame_count * CHANNELS * 2), pyaudio.paComplete)
        audio = bytes(self.speaker_audio_buffer[: frame_count * CHANNELS * 2])
        del self.speaker_audio_buffer[: frame_count * CHANNELS * 2]
        # Zero-pad so the device always receives a full frame.
        audio += b"\0" * (frame_count * CHANNELS * 2 - len(audio))
        return (audio, pyaudio.paContinue)

    async def handle_tool_call(self, tool_call):
        """Acknowledge every function call in `tool_call` with a success status."""
        # print(f" <- handling tool call {tool_call}")
        responses = []
        for f in tool_call.get("functionCalls", []):
            print(f" <- Function call: {f}")
            # No real lights exist: report success unconditionally so the
            # model can continue its multi-step plan.
            responses.append(
                {
                    "id": f.get("id"),
                    "name": f.get("name"),
                    "response": {"status": "success"},
                }
            )
        msg = json.dumps(
            {
                "toolResponse": {
                    "functionResponses": responses,
                }
            }
        )
        print(f" -> {msg}")
        await self.ws.send(msg)

    async def print_audio_output_buffer_info(self):
        """Every 2 seconds, report how much audio is queued for playback."""
        while self.running:
            if self.speaker_audio_buffer:
                # bytes / (rate * 2 bytes-per-sample) = seconds of audio.
                print(
                    f"Current audio buffer size: {len(self.speaker_audio_buffer) / (SPEAKER_SAMPLE_RATE * 2):.2f} seconds"
                )
            await asyncio.sleep(2)

    def print_evt(self, evt, response):
        """Log a terminal-width-truncated view of a server event, plus any
        grounding chunks and non-audio/non-text model parts.

        `response` is the raw message text; it is currently unused here.
        """
        columns, rows = shutil.get_terminal_size()
        maxl = columns - 5
        print(str(evt)[:maxl] + " ...")
        # Grounding metadata appears when the google_search tool was used.
        if grounding := evt.get("serverContent", {}).get("groundingMetadata"):
            for chunk in grounding.get("groundingChunks", []):
                # NOTE(review): reusing double quotes inside an f-string
                # requires Python 3.12+ — confirm target interpreter.
                print(f" <- {chunk.get("web").get("title")}")
        if parts := evt.get("serverContent", {}).get("modelTurn", {}).get("parts"):
            for part in parts:
                # Audio and text parts are handled in ws_receive_worker;
                # only surface anything else (e.g. executableCode).
                if part.get("inlineData") or part.get("text"):
                    continue
                print(f" <- {part}")

    async def ws_receive_worker(self):
        """Consume server events: start the mic on setupComplete, buffer audio,
        print text, and respond to tool calls."""
        try:
            async for m in self.ws:
                if not self.running:
                    break
                evt = json.loads(m)
                self.print_evt(evt, m)
                if evt.get("setupComplete", None) is not None:
                    # Handshake done: optionally seed the conversation, then
                    # begin streaming microphone audio.
                    await self.send_initial_message()
                    print("Ready: say something to Gemini")
                    self.mic_audio_in.start_stream()
                elif sc := evt.get("serverContent"):
                    if sc.get("interrupted"):
                        # Barge-in: drop queued playback and flush any
                        # deferred tool calls (empty unless deferral is
                        # re-enabled in the toolCall branch below).
                        print("Interrupted by server")
                        self.speaker_audio_buffer.clear()
                        for tool_call in self.pending_tool_calls:
                            await self.handle_tool_call(tool_call)
                        self.pending_tool_calls.clear()
                        continue
                    if parts := sc.get("modelTurn", {}).get("parts"):
                        # NOTE(review): only parts[0] is inspected; any later
                        # parts in the same turn are ignored — confirm intended.
                        if text := parts[0].get("text"):
                            print(f" <- {text}")
                        elif inline_data := parts[0].get("inlineData"):
                            mime_str = inline_data.get("mimeType")
                            # e.g. "audio/pcm;rate=24000" -> ("audio/pcm", "24000")
                            mime_type, sample_rate = re.match(
                                r"([\w/]+);rate=(\d+)", mime_str
                            ).groups()
                            if mime_type == "audio/pcm" and sample_rate == str(SPEAKER_SAMPLE_RATE):
                                audio = base64.b64decode(inline_data.get("data"))
                                self.speaker_audio_buffer.extend(audio)
                            else:
                                print(f"Unsupported mime type or sample rate: {mime_str}")
                        if code := parts[0].get("executableCode"):
                            # Executable-code parts are deliberately ignored.
                            pass
                elif tool_call := evt.get("toolCall"):
                    # Respond immediately; deferring until interruption is the
                    # commented-out alternative under test.
                    # self.pending_tool_calls.append(tool_call)
                    await self.handle_tool_call(tool_call)
        except Exception as e:
            print(f"Exception: {e}")
            self.running = False

    async def setup_model(self):
        """Send the BidiGenerateContent setup message: model, generation
        config, system instruction, and the light-control tool declarations."""
        try:
            setup = {
                "setup": {
                    "model": MODEL,
                    "generation_config": {
                        # "response_modalities": ["AUDIO"],
                        "response_modalities": ["TEXT"],
                        "speech_config": {
                            "voice_config": {"prebuilt_voice_config": {"voice_name": "Charon"}},
                        },
                    },
                    "system_instruction": SYSTEM_INSTRUCTION,
                    "tools": [
                        {
                            "function_declarations": [
                                {
                                    "name": "turn_on_the_lights",
                                    "description": "Turns the lights on. Color argument is required, but default to a warm white color unless specified.",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "color": {
                                                "type": "string",
                                                "description": "Color in hex format",
                                            },
                                        },
                                        "required": ["color"],
                                    },
                                },
                                {
                                    "name": "turn_off_the_lights",
                                    "description": "Turns the lights off. Takes no argument.",
                                    "parameters": None,
                                },
                                # NOTE(review): turn_off_the_lights is declared
                                # a second time with no description/parameters —
                                # presumably deliberate for this repro; confirm.
                                {"name": "turn_off_the_lights"},
                            ]
                        }
                    ],
                },
            }
            if SEARCH:
                print("Search enabled")
                setup["setup"]["tools"].append({"google_search": {}})
            if CODE_EXECUTION:
                print("Code execution enabled")
                setup["setup"]["tools"].append({"code_execution": {}})
            print("Sending setup", setup)
            await self.ws.send(json.dumps(setup))
        except Exception as e:
            print(f"Exception: {e}")

    async def send_initial_message(self):
        """If --initial-message was given, send it as a completed user turn."""
        try:
            if INITIAL_MESSAGE:
                initial_message = {
                    "client_content": {
                        "turns": [
                            {"parts": [{"text": INITIAL_MESSAGE}], "role": "user"},
                        ],
                        "turn_complete": True,
                    }
                }
                print("Sending initial message", initial_message)
                await self.ws.send(json.dumps(initial_message))
        except Exception as e:
            print(f"Exception: {e}")

    async def run(self):
        """Open audio devices, connect to Gemini, send setup, then idle until
        stopped or cancelled; tears down audio and the socket on exit."""
        # NOTE(review): asyncio.get_event_loop() is deprecated inside
        # coroutines; get_running_loop() is the modern equivalent.
        self.event_loop = asyncio.get_event_loop()
        self.mic_audio_in = self.p.open(
            format=pyaudio.paInt16,
            channels=CHANNELS,
            rate=MIC_SAMPLE_RATE,
            input=True,
            stream_callback=self.mic_audio_in_callback,
            frames_per_buffer=int(MIC_SAMPLE_RATE / 1000) * 2 * 50,  # 50ms (S16_LE is 2 bytes)
            # Mic stays stopped until the server's setupComplete arrives.
            start=False,
        )
        self.speaker_audio_out = self.p.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=SPEAKER_SAMPLE_RATE,
            output=True,
            frames_per_buffer=256,
            stream_callback=self.speaker_audio_out_callback,
        )
        try:
            self.ws = await websockets.connect(
                uri=f'wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={os.getenv("GEMINI_API_KEY")}'
            )
            print("Connected to Gemini")
        except Exception as e:
            print(f"Exception: {e}")
            return
        asyncio.create_task(self.ws_receive_worker())
        asyncio.create_task(self.print_audio_output_buffer_info())
        try:
            await self.setup_model()
            # Keep the loop alive; workers flip self.running on failure.
            while self.running:
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            print(f"Exception: {e}")
        finally:
            print("Exiting...")
            self.running = False
            self.p.terminate()
            await self.ws.close()
            # await self.cleanup()
if __name__ == "__main__":
    # Parse CLI flags into the module globals, then run the streamer until
    # it stops itself or the process is interrupted.
    parse_args()
    asyncio.run(AudioStreamer().run())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment