Skip to content

Instantly share code, notes, and snippets.

@Jonovono
Created September 17, 2024 02:32
Show Gist options
  • Save Jonovono/85ad4d5e7c4d0da753749516c72e4e12 to your computer and use it in GitHub Desktop.
Save Jonovono/85ad4d5e7c4d0da753749516c72e4e12 to your computer and use it in GitHub Desktop.
import asyncio
import logging
import ssl
from typing import cast, Dict

import pyaudio
from docarray import BaseDoc
from docarray.typing import AudioBytes
from qh3.asyncio.client import connect
from qh3.asyncio.protocol import QuicConnectionProtocol
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.events import ConnectionTerminated, QuicEvent, StreamDataReceived
logger = logging.getLogger("client")

# PyAudio parameters; FORMAT, CHANNELS and RATE are used for both the
# microphone (input) stream and the playback (output) stream.
CHUNK_SIZE = 1024  # frames per microphone read; also the input buffer size
FORMAT = pyaudio.paInt16  # signed 16-bit samples
CHANNELS = 1  # mono
RATE = 44_100  # samples per second (44.1 kHz)
class AudioDoc(BaseDoc):
    """docarray document holding a single chunk of audio bytes.

    The client serializes it with ``to_bytes(protocol='protobuf')`` before
    sending and parses replies with ``from_bytes(..., protocol='protobuf')``.
    """

    bytes_: AudioBytes  # the audio payload of this chunk
class QuicAudioClientProtocol(QuicConnectionProtocol):
    """QUIC client protocol that exchanges one audio chunk per stream.

    ``send_audio`` opens a new stream, sends a protobuf-serialized
    ``AudioDoc`` with ``end_stream=True``, and waits until the server ends the
    same stream. The bytes received on that stream are decoded back into audio.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Response bytes received so far, keyed by stream id.
        self._data_buffers: Dict[int, bytearray] = {}
        # Futures awaiting the complete response of a stream, keyed by stream id.
        self._response_waiters: Dict[int, asyncio.Future] = {}

    def quic_event_received(self, event: QuicEvent) -> None:
        """Collect response data per stream and resolve the matching waiter.

        A waiter receives the full payload once the server sets ``end_stream``.
        Data on streams that no ``send_audio`` call is waiting for is not
        buffered. If the connection terminates, every pending waiter fails with
        ``ConnectionError`` so callers do not hang forever.
        """
        if isinstance(event, ConnectionTerminated):
            self._fail_pending_waiters(
                f"QUIC connection terminated: {event.reason_phrase}"
            )
            return
        if not isinstance(event, StreamDataReceived):
            return

        stream_id = event.stream_id
        logger.info(
            "Received %d bytes from the server on stream %d", len(event.data), stream_id
        )
        if stream_id not in self._response_waiters:
            # Nobody owns this stream (e.g. a stream the peer opened itself, or
            # a request whose caller was cancelled); buffering it would only leak.
            return

        self._data_buffers.setdefault(stream_id, bytearray()).extend(event.data)
        if not event.end_stream:
            return

        response_data = bytes(self._data_buffers.pop(stream_id))
        waiter = self._response_waiters.pop(stream_id)
        if not waiter.done():
            waiter.set_result(response_data)

    def _fail_pending_waiters(self, reason: str) -> None:
        """Fail every outstanding ``send_audio`` waiter with ConnectionError(reason)."""
        waiters = list(self._response_waiters.values())
        self._response_waiters.clear()
        self._data_buffers.clear()
        for waiter in waiters:
            if not waiter.done():
                # A fresh exception per waiter avoids sharing one traceback.
                waiter.set_exception(ConnectionError(reason))

    async def send_audio(self, audio_data: bytes) -> bytes:
        """Send one audio chunk and return the audio bytes of the server's reply.

        Args:
            audio_data: Raw audio bytes to wrap in an ``AudioDoc``.

        Returns:
            The ``bytes_`` field of the ``AudioDoc`` decoded from the reply, or
            ``b""`` if the reply could not be deserialized.

        Raises:
            ConnectionError: The connection terminated before the reply stream ended.
        """
        # Wrap audio data in AudioDoc and serialize it using docarray.
        doc = AudioDoc(bytes_=AudioBytes(audio_data))
        proto_data = doc.to_bytes(protocol='protobuf')

        stream_id = self._quic.get_next_available_stream_id()
        # Register the waiter *before* transmitting, so reply data can never
        # arrive for a stream that has no owner yet.
        waiter = asyncio.get_running_loop().create_future()
        self._response_waiters[stream_id] = waiter
        try:
            logger.info(
                "Sending %d bytes to the server on stream %d", len(proto_data), stream_id
            )
            self._quic.send_stream_data(stream_id, proto_data, end_stream=True)
            self.transmit()
            response = await waiter
        finally:
            # Cleanup on cancellation or failure; a no-op after a normal reply.
            self._response_waiters.pop(stream_id, None)
            self._data_buffers.pop(stream_id, None)

        logger.info("Response size: %d bytes", len(response))
        try:
            response_doc = AudioDoc.from_bytes(response, protocol='protobuf')
        except Exception as e:
            logger.error("Failed to deserialize response: %s", e)
            return b""
        return response_doc.bytes_
async def audio_reader(audio_stream, input_queue):
    """Continuously read microphone chunks and put them on ``input_queue``.

    ``audio_stream.read`` blocks, so each read runs in the default executor.

    An input overflow (OSError errno -9981) loses only that read, and capture
    continues. Previously an overflow ended the reader for good and silently
    stalled the whole pipeline. Any other error is logged and re-raised, so the
    surrounding ``gather`` fails instead of waiting forever on an empty queue.
    """
    loop = asyncio.get_running_loop()
    while True:
        try:
            # Read audio data using the executor to avoid blocking the loop.
            audio_data = await loop.run_in_executor(None, audio_stream.read, CHUNK_SIZE)
        except OSError as e:
            if e.errno == -9981:
                logger.warning("Input overflowed: %s", e)
                continue
            logger.error("Audio reader encountered an error: %s", e)
            raise
        except Exception as e:
            logger.error("Audio reader encountered an unexpected error: %s", e)
            raise
        await input_queue.put(audio_data)
async def network_sender(client, input_queue, output_queue):
    """Forward queued audio chunks to the server, one round trip at a time.

    Each chunk waits for its reply before the next one is taken from
    ``input_queue``. Non-empty replies go onto ``output_queue``; empty
    (falsy) replies are dropped.
    """
    while True:
        chunk = await input_queue.get()
        reply = await client.send_audio(chunk)
        if not reply:
            continue
        await output_queue.put(reply)
async def audio_player(playback_stream, output_queue):
    """Write queued audio chunks to ``playback_stream`` in arrival order.

    ``playback_stream.write`` blocks, so it runs in the default executor to
    keep the event loop responsive. The running loop is obtained with
    ``get_running_loop()``, the form preferred inside coroutines.
    """
    loop = asyncio.get_running_loop()
    while True:
        audio_data = await output_queue.get()
        await loop.run_in_executor(None, playback_stream.write, audio_data)
def _close_stream(stream, direction: str) -> None:
    """Stop (if active) and close a PyAudio stream; log instead of raising."""
    if stream is None:
        return
    try:
        if stream.is_active():
            stream.stop_stream()
        stream.close()
    except Exception as e:
        logger.error("Error closing audio %s stream: %s", direction, e)


async def _cancel_tasks(tasks) -> None:
    """Cancel ``tasks`` and wait until every one of them has finished."""
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


async def stream_audio(configuration: QuicConfiguration, host: str, port: int) -> None:
    """Capture microphone audio, exchange it with the server over QUIC, play the replies.

    Opens a PyAudio input stream and output stream and connects to
    ``host``:``port``. It then runs three tasks linked by two queues
    (reader -> sender -> player) until one fails or this coroutine is
    cancelled.

    On every exit path the remaining tasks are cancelled, and the audio streams
    and the PyAudio instance are released. This includes a failed ``open()``
    or ``connect()``.
    """
    audio_interface = pyaudio.PyAudio()
    audio_stream = None
    playback_stream = None
    try:
        # Open the microphone for input.
        audio_stream = audio_interface.open(
            format=FORMAT, channels=CHANNELS, rate=RATE, input=True,
            frames_per_buffer=CHUNK_SIZE,
        )
        # Open a stream for output (playback).
        playback_stream = audio_interface.open(
            format=FORMAT, channels=CHANNELS, rate=RATE, output=True,
        )
        input_queue: asyncio.Queue = asyncio.Queue()
        output_queue: asyncio.Queue = asyncio.Queue()
        async with connect(
            host, port, configuration=configuration, create_protocol=QuicAudioClientProtocol
        ) as client:
            client = cast(QuicAudioClientProtocol, client)
            logger.info("Connected to %s:%d", host, port)
            tasks = [
                asyncio.create_task(audio_reader(audio_stream, input_queue)),
                asyncio.create_task(network_sender(client, input_queue, output_queue)),
                asyncio.create_task(audio_player(playback_stream, output_queue)),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # gather() does not cancel the other tasks when one fails; stop
                # them all before the connection and audio streams are closed.
                # NOTE(review): an executor thread already blocked in read()/write()
                # is not interrupted by cancellation — confirm PyAudio tolerates
                # the stream being closed underneath it.
                await _cancel_tasks(tasks)
    except KeyboardInterrupt:
        pass
    finally:
        # Clean up streams and terminate the audio interface.
        _close_stream(audio_stream, "input")
        _close_stream(playback_stream, "output")
        audio_interface.terminate()
if __name__ == "__main__":
    import argparse

    # Host/port used to be hard-coded; they are now CLI options with the same
    # defaults (pass --host localhost for a local server).
    parser = argparse.ArgumentParser(
        description="Stream microphone audio to the QUIC audio server and play back its replies."
    )
    parser.add_argument("--host", default="34.121.136.5", help="server address")
    parser.add_argument("--port", type=int, default=4433, help="server UDP port")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    configuration = QuicConfiguration(
        alpn_protocols=["h3"],
        is_client=True,
    )
    # Disable certificate verification for testing purposes.
    # NOTE(review): this accepts any server certificate; load the server's
    # CA/certificate instead before using this outside a test setup.
    configuration.verify_mode = ssl.CERT_NONE
    try:
        asyncio.run(stream_audio(configuration=configuration, host=args.host, port=args.port))
    except KeyboardInterrupt:
        pass
import asyncio
import logging
from typing import Dict, Optional
from qh3.asyncio import serve
from qh3.asyncio.protocol import QuicConnectionProtocol
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.events import QuicEvent, StreamDataReceived, ProtocolNegotiated
from qh3.h3.connection import H3Connection, H3_ALPN
from qh3.h3.events import HeadersReceived, DataReceived
from docarray import BaseDoc
from docarray.typing import AudioBytes
logger = logging.getLogger("server")


class AudioDoc(BaseDoc):
    """docarray document carrying one chunk of audio bytes."""

    bytes_: AudioBytes  # the audio payload of this chunk
import asyncio
import logging
from typing import Dict, Optional
from qh3.asyncio import serve
from qh3.asyncio.protocol import QuicConnectionProtocol
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.events import QuicEvent, StreamDataReceived, ProtocolNegotiated
from qh3.h3.connection import H3Connection, H3_ALPN
from docarray.typing import AudioBytes
class QuicAudioServerProtocol(QuicConnectionProtocol):
    """QUIC server protocol that echoes each client stream back when it ends.

    Data on bidirectional streams is buffered per stream. Once the client sets
    ``end_stream``, :meth:`handle_stream` writes the buffered bytes back on the
    same stream and ends it. When the negotiated ALPN is in ``H3_ALPN``, every
    QUIC event is also fed to an ``H3Connection`` and the resulting HTTP/3
    events are logged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bytes received so far on each bidirectional stream, keyed by stream id.
        self._streams: Dict[int, bytearray] = {}
        self._http: Optional[H3Connection] = None
        # Strong references to pending handle_stream() tasks. The event loop
        # keeps only weak references, so an unreferenced task can be
        # garbage-collected before it completes.
        self._pending_tasks: set = set()

    def quic_event_received(self, event: QuicEvent) -> None:
        """Dispatch a QUIC event to the echo logic and, when enabled, to HTTP/3."""
        if isinstance(event, ProtocolNegotiated):
            logger.info("ALPN protocol negotiated: %s", event.alpn_protocol)
            if event.alpn_protocol in H3_ALPN:
                self._http = H3Connection(self._quic)
        elif isinstance(event, StreamDataReceived):
            self._buffer_stream_data(event)

        # NOTE(review): the echoed payload is raw protobuf, not HTTP/3 frames, so
        # H3Connection may misinterpret it or close the connection with an
        # HTTP/3 protocol error — confirm HTTP/3 is really wanted here.
        if self._http:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

    def _buffer_stream_data(self, event: StreamDataReceived) -> None:
        """Append stream data to its buffer; schedule the echo at end of stream."""
        stream_id = event.stream_id
        logger.info("Received %d bytes on stream %d", len(event.data), stream_id)
        if stream_id & 0x2:
            # Bit 0x2 marks a unidirectional stream (RFC 9000, section 2.1). The
            # server cannot write on a client-opened one, so no echo is possible
            # and buffering would only grow memory.
            return
        self._streams.setdefault(stream_id, bytearray()).extend(event.data)
        if event.end_stream:
            task = asyncio.create_task(self.handle_stream(stream_id))
            self._pending_tasks.add(task)
            task.add_done_callback(self._pending_tasks.discard)

    def http_event_received(self, event: QuicEvent) -> None:
        """Log HTTP/3 header and data events produced by the H3Connection."""
        if isinstance(event, HeadersReceived):
            logger.info("Headers received: %s", event.headers)
        elif isinstance(event, DataReceived):
            logger.info("HTTP/3 data received: %s", event.data)

    async def handle_stream(self, stream_id: int):
        """Send the buffered bytes of ``stream_id`` back and end the stream.

        The stream's buffer is always released, even when sending fails.
        """
        buffer = self._streams.pop(stream_id, None)
        if buffer is None:
            logger.error("Failed to process stream %d: no buffered data", stream_id)
            return
        data = bytes(buffer)
        logger.info("Accumulated %d bytes of audio data on stream %d", len(data), stream_id)
        try:
            # The payload (expected: the client's protobuf-serialized AudioDoc)
            # is echoed unchanged. Decode, process and re-encode it here to do
            # real work.
            self._quic.send_stream_data(stream_id, data, end_stream=True)
            self.transmit()
        except Exception as e:
            logger.error("Failed to process stream %d: %s", stream_id, e)
# Bind address for the server: every IPv4 interface ("localhost" would accept
# local clients only).
host = "0.0.0.0"
port = 4433  # UDP port the QUIC server listens on
async def main(host: str, port: int, configuration: QuicConfiguration) -> None:
    """Serve QuicAudioServerProtocol on ``host``:``port`` until cancelled.

    The server object returned by ``serve`` was previously never closed; it is
    now closed in ``finally`` when this coroutine exits (normally by
    cancellation).
    """
    server = await serve(
        host, port, configuration=configuration, create_protocol=QuicAudioServerProtocol
    )
    try:
        # A future that is never resolved: run until cancelled.
        await asyncio.get_running_loop().create_future()
    finally:
        server.close()
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    configuration = QuicConfiguration(
        alpn_protocols=H3_ALPN,  # Enabling HTTP/3 ALPN
        is_client=False,
    )
    # NOTE: these paths are resolved against the current working directory.
    configuration.load_cert_chain("../cert.pem", "../key.pem")
    try:
        print("Audio echo server running on QUIC with HTTP/3 support")
        # Pass the module-level `port` setting rather than repeating the literal,
        # so changing the setting actually changes the listening port.
        asyncio.run(main(host=host, port=port, configuration=configuration))
    except KeyboardInterrupt:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment