Created
September 17, 2024 02:32
-
-
Save Jonovono/85ad4d5e7c4d0da753749516c72e4e12 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
import ssl
from typing import Dict, cast

import pyaudio
from docarray import BaseDoc
from docarray.typing import AudioBytes
from qh3.asyncio.client import connect
from qh3.asyncio.protocol import QuicConnectionProtocol
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.events import ConnectionTerminated, QuicEvent, StreamDataReceived
logger = logging.getLogger("client")

# --- Capture / playback parameters shared by both PyAudio streams ---------
# Frames requested per microphone read (also the input buffer size).
CHUNK_SIZE: int = 1024
# Sample format: 16-bit integer samples.
FORMAT = pyaudio.paInt16
# Mono audio.
CHANNELS: int = 1
# Sampling rate in Hz (44.1 kHz).
RATE: int = 44100
class AudioDoc(BaseDoc):
    """Envelope for one chunk of captured audio sent to / received from the server.

    The client serializes it with ``to_bytes(protocol='protobuf')`` and parses
    replies with ``from_bytes(..., protocol='protobuf')``.
    """

    # Raw audio bytes as read from the microphone stream.
    bytes_: AudioBytes
class QuicAudioClientProtocol(QuicConnectionProtocol):
    """QUIC client protocol: one request/response round trip per stream.

    ``send_audio`` opens a fresh stream, writes a protobuf-serialized
    ``AudioDoc``, half-closes the stream and waits until the server finishes
    its side of the same stream. Incoming data is buffered per stream id and
    handed to the matching waiter once ``end_stream`` is seen. If the
    connection terminates, every pending waiter fails with ``ConnectionError``
    instead of hanging forever.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Partial response bytes, keyed by stream id, until end_stream arrives.
        self._data_buffers: Dict[int, bytearray] = {}
        # Futures awaiting the complete response of a stream, keyed by stream id.
        self._response_waiters: Dict[int, asyncio.Future] = {}
        # Set once a ConnectionTerminated event is seen; later sends fail fast.
        self._is_terminated: bool = False

    def quic_event_received(self, event: QuicEvent) -> None:
        """Dispatch QUIC events: buffer stream data, fail waiters on termination."""
        if isinstance(event, StreamDataReceived):
            self._buffer_stream_data(event)
        elif isinstance(event, ConnectionTerminated):
            self._fail_pending_responses()

    def _buffer_stream_data(self, event: StreamDataReceived) -> None:
        """Append stream data; on end_stream resolve the waiter with the full payload."""
        stream_id = event.stream_id
        buffer = self._data_buffers.setdefault(stream_id, bytearray())
        buffer.extend(event.data)
        logger.info(
            "Received %d bytes from the server on stream %d", len(event.data), stream_id
        )
        if not event.end_stream:
            return
        del self._data_buffers[stream_id]
        waiter = self._response_waiters.pop(stream_id, None)
        # A waiter may be missing or already done if its caller was cancelled.
        if waiter is not None and not waiter.done():
            waiter.set_result(bytes(buffer))

    def _fail_pending_responses(self) -> None:
        """Mark the connection dead and wake every waiter with ConnectionError."""
        self._is_terminated = True
        self._data_buffers.clear()
        waiters, self._response_waiters = self._response_waiters, {}
        for waiter in waiters.values():
            if not waiter.done():
                waiter.set_exception(ConnectionError("QUIC connection terminated"))

    async def send_audio(self, audio_data: bytes):
        """Send one audio chunk on a new stream and return the server's reply.

        Args:
            audio_data: Raw audio bytes to wrap in an ``AudioDoc``.

        Returns:
            The ``bytes_`` field of the reply ``AudioDoc``, or ``b""`` if the
            reply could not be deserialized.

        Raises:
            ConnectionError: If the connection has terminated, before or
                while waiting for the reply.
        """
        if self._is_terminated:
            raise ConnectionError("QUIC connection terminated")
        stream_id = self._quic.get_next_available_stream_id()
        # Wrap audio data in AudioDoc and serialize it using docarray
        doc = AudioDoc(bytes_=AudioBytes(audio_data))
        proto_data = doc.to_bytes(protocol='protobuf')

        # Register the waiter before sending, so correctness does not rely on
        # there being no await between transmit() and registration.
        waiter = asyncio.get_running_loop().create_future()
        self._response_waiters[stream_id] = waiter
        try:
            logger.info("Sending %d bytes to the server on stream %d", len(proto_data), stream_id)
            self._quic.send_stream_data(stream_id, proto_data, end_stream=True)
            self.transmit()
            response = await waiter
        finally:
            # Drop the waiter even if the send failed or the caller was cancelled.
            self._response_waiters.pop(stream_id, None)

        # Deserialize the response into an AudioDoc (best effort).
        try:
            logger.info("Response size: %d bytes", len(response))
            response_doc = AudioDoc.from_bytes(response, protocol='protobuf')
            return response_doc.bytes_
        except Exception as e:
            logger.error(f"Failed to deserialize response: {e}")
            return b""
# PortAudio error code for an input (capture) buffer overflow.
_PA_INPUT_OVERFLOWED = -9981


async def audio_reader(audio_stream, input_queue):
    """Read microphone chunks forever and push them onto *input_queue*.

    Each blocking ``audio_stream.read(CHUNK_SIZE)`` runs in the default
    executor so the event loop stays responsive. An input overflow
    (errno -9981) only drops that chunk and reading continues; previously it
    ended the reader for good. Any other error is logged and the reader
    returns.
    """
    loop = asyncio.get_running_loop()
    while True:
        try:
            # Read audio data using executor to avoid blocking
            audio_data = await loop.run_in_executor(None, audio_stream.read, CHUNK_SIZE)
        except OSError as e:
            if e.errno == _PA_INPUT_OVERFLOWED:
                # Transient: the capture buffer filled before we read it.
                logger.warning(f"Input overflowed: {e}")
                continue
            logger.error(f"Audio reader encountered an error: {e}")
            return
        except Exception as e:
            logger.error(f"Audio reader encountered an unexpected error: {e}")
            return
        await input_queue.put(audio_data)
async def network_sender(client, input_queue, output_queue):
    """Pump captured chunks through the server and queue the replies.

    Loops until cancelled. Each chunk taken from *input_queue* goes through
    ``client.send_audio``. A truthy reply is forwarded to *output_queue*;
    an empty one is dropped.
    """
    while True:
        outgoing = await input_queue.get()
        echoed = await client.send_audio(outgoing)
        if not echoed:
            # Nothing to play (e.g. the reply could not be decoded).
            continue
        await output_queue.put(echoed)
async def audio_player(playback_stream, output_queue):
    """Write every chunk from *output_queue* to *playback_stream*, forever.

    The write blocks, so it is off-loaded to the running loop's default
    executor instead of running on the event-loop thread.
    """
    running_loop = asyncio.get_running_loop()
    while True:
        pcm_chunk = await output_queue.get()
        await running_loop.run_in_executor(None, playback_stream.write, pcm_chunk)
def _close_audio_stream(stream, direction: str) -> None:
    """Stop (if active) and close a PyAudio stream, logging instead of raising.

    *stream* may be ``None`` when opening it failed. *direction* ("input" or
    "output") is only used in the log message.
    """
    if stream is None:
        return
    try:
        if stream.is_active():
            stream.stop_stream()
        stream.close()
    except Exception as e:
        logger.error(f"Error closing audio {direction} stream: {e}")


async def stream_audio(configuration: QuicConfiguration, host: str, port: int) -> None:
    """Capture microphone audio, round-trip it via the QUIC server, play replies.

    Three tasks are linked by two queues:
    ``audio_reader`` -> input_queue -> ``network_sender`` -> output_queue ->
    ``audio_player``. The function runs until one task raises or the
    coroutine is cancelled. The remaining tasks are then cancelled, and the
    PyAudio streams and interface are always released, even if opening a
    stream or connecting fails.
    """
    audio_interface = pyaudio.PyAudio()
    audio_stream = None
    playback_stream = None
    try:
        # Open the microphone for input
        audio_stream = audio_interface.open(format=FORMAT, channels=CHANNELS,
                                            rate=RATE, input=True, frames_per_buffer=CHUNK_SIZE)
        # Open a stream for output (playback)
        playback_stream = audio_interface.open(format=FORMAT, channels=CHANNELS,
                                               rate=RATE, output=True)
        input_queue = asyncio.Queue()
        output_queue = asyncio.Queue()
        async with connect(
            host, port, configuration=configuration, create_protocol=QuicAudioClientProtocol
        ) as client:
            client = cast(QuicAudioClientProtocol, client)
            logger.info("Connected to %s:%d", host, port)
            tasks = [
                asyncio.create_task(audio_reader(audio_stream, input_queue)),
                asyncio.create_task(network_sender(client, input_queue, output_queue)),
                asyncio.create_task(audio_player(playback_stream, output_queue)),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # gather() does not cancel siblings when one task raises, so stop
                # them explicitly before the connection and streams are torn down.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
    except KeyboardInterrupt:
        pass
    finally:
        # NOTE: a read/write already running in an executor thread is not
        # interrupted by task cancellation; it completes its current chunk.
        _close_audio_stream(audio_stream, "input")
        _close_audio_stream(playback_stream, "output")
        audio_interface.terminate()
if __name__ == "__main__":
    import os

    logging.basicConfig(level=logging.INFO)
    configuration = QuicConfiguration(
        alpn_protocols=["h3"],
        is_client=True,
    )
    # Disable certificate verification for testing purposes.
    # WARNING: this accepts any server certificate (traffic can be intercepted);
    # do not use against untrusted networks.
    configuration.verify_mode = ssl.CERT_NONE
    # Server address. The defaults preserve the original hard-coded target;
    # override e.g. with AUDIO_SERVER_HOST=localhost instead of editing the source.
    host = os.environ.get("AUDIO_SERVER_HOST", "34.121.136.5")
    port = int(os.environ.get("AUDIO_SERVER_PORT", "4433"))
    try:
        asyncio.run(stream_audio(configuration=configuration, host=host, port=port))
    except KeyboardInterrupt:
        pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
from typing import Dict, Optional | |
from qh3.asyncio import serve | |
from qh3.asyncio.protocol import QuicConnectionProtocol | |
from qh3.quic.configuration import QuicConfiguration | |
from qh3.quic.events import QuicEvent, StreamDataReceived, ProtocolNegotiated | |
from qh3.h3.connection import H3Connection, H3_ALPN | |
from qh3.h3.events import HeadersReceived, DataReceived | |
from docarray import BaseDoc | |
from docarray.typing import AudioBytes | |
# Module logger for the QUIC audio echo server.
logger: logging.Logger = logging.getLogger("server")
class AudioDoc(BaseDoc):
    """Audio chunk envelope; field layout must match the client's ``AudioDoc``.

    The echo path in ``handle_stream`` currently forwards the serialized bytes
    without decoding them into this class.
    """

    # Raw audio bytes carried by the document.
    bytes_: AudioBytes
import asyncio | |
import logging | |
from typing import Dict, Optional | |
from qh3.asyncio import serve | |
from qh3.asyncio.protocol import QuicConnectionProtocol | |
from qh3.quic.configuration import QuicConfiguration | |
from qh3.quic.events import QuicEvent, StreamDataReceived, ProtocolNegotiated | |
from qh3.h3.connection import H3Connection, H3_ALPN | |
from docarray.typing import AudioBytes | |
class QuicAudioServerProtocol(QuicConnectionProtocol):
    """QUIC server protocol that echoes each client stream back on the same stream.

    Bytes of every stream are buffered until the client sets ``end_stream``.
    The buffered payload is then written back unchanged and the stream is
    closed. If HTTP/3 is negotiated, events are additionally passed through an
    ``H3Connection`` and the resulting HTTP/3 events are logged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-stream receive buffers, keyed by QUIC stream id.
        self._streams: Dict[int, bytearray] = {}
        self._http: Optional[H3Connection] = None
        # Strong references to in-flight handler tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected.
        self._handler_tasks = set()

    def quic_event_received(self, event: QuicEvent) -> None:
        """Handle ALPN negotiation and buffer stream data; echo on end_stream."""
        if isinstance(event, ProtocolNegotiated):
            logger.info("ALPN protocol negotiated: %s", event.alpn_protocol)
            if event.alpn_protocol in H3_ALPN:
                self._http = H3Connection(self._quic)
        elif isinstance(event, StreamDataReceived):
            stream_id = event.stream_id
            data = event.data
            # Accumulate data for the stream
            self._streams.setdefault(stream_id, bytearray()).extend(data)
            logger.info("Received %d bytes on stream %d", len(data), stream_id)
            # If the stream has ended, process it
            if event.end_stream:
                task = asyncio.create_task(self.handle_stream(stream_id))
                self._handler_tasks.add(task)
                task.add_done_callback(self._handler_tasks.discard)

        # NOTE(review): when h3 is negotiated, the raw (protobuf) audio streams are
        # also fed to the HTTP/3 parser, which will try to decode them as HTTP/3
        # frames. Confirm this dual handling is intended.
        if self._http:
            for http_event in self._http.handle_event(event):
                self.http_event_received(http_event)

    def http_event_received(self, event) -> None:
        """Log HTTP/3 events produced by ``H3Connection.handle_event``."""
        if isinstance(event, HeadersReceived):
            logger.info("Headers received: %s", event.headers)
        elif isinstance(event, DataReceived):
            logger.info("HTTP/3 data received: %s", event.data)

    async def handle_stream(self, stream_id: int):
        """Echo the buffered payload of a finished stream back and close it.

        Failures (including a missing buffer) are logged, not raised.
        """
        try:
            # pop() releases the buffer even if sending below fails.
            data = bytes(self._streams.pop(stream_id))
            logger.info("Accumulated %d bytes of audio data on stream %d", len(data), stream_id)
            # The payload is echoed verbatim. To process audio, decode it with
            # AudioDoc.from_bytes(data, protocol='protobuf') and re-serialize it
            # with doc.to_bytes(protocol='protobuf') before sending.
            self._quic.send_stream_data(stream_id, data, end_stream=True)
            self.transmit()
        except Exception as e:
            logger.error(f"Failed to process stream {stream_id}: {e}")
# Bind address and port of the QUIC echo server. "0.0.0.0" listens on every
# IPv4 interface; use "localhost" to accept local clients only.
host = "0.0.0.0"
port = 4433
async def main(host: str, port: int, configuration: QuicConfiguration) -> None:
    """Serve ``QuicAudioServerProtocol`` on *host*:*port* until cancelled.

    The server object returned by ``serve`` is closed on the way out
    (cancellation or error) instead of being left to garbage collection.
    """
    server = await serve(
        host, port, configuration=configuration, create_protocol=QuicAudioServerProtocol
    )
    try:
        # Park forever; this await is cancelled when the event loop shuts down.
        await asyncio.get_running_loop().create_future()
    finally:
        server.close()
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    configuration = QuicConfiguration(
        alpn_protocols=H3_ALPN,  # Enabling HTTP/3 ALPN
        is_client=False,
    )
    # NOTE: these paths are resolved against the current working directory,
    # not against this file's location.
    configuration.load_cert_chain("../cert.pem", "../key.pem")
    try:
        print("Audio echo server running on QUIC with HTTP/3 support")
        # Use the module-level `port` instead of re-hard-coding 4433 here.
        asyncio.run(main(host=host, port=port, configuration=configuration))
    except KeyboardInterrupt:
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment