Last active
November 17, 2023 05:09
-
-
Save aleksandr-smechov/c789caa0b65772865a3dc1e60e0f2c5d to your computer and use it in GitHub Desktop.
Client-side distil-whisper streaming script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import platform | |
import subprocess | |
import sys | |
import asyncio | |
import websockets | |
import numpy as np | |
from typing import Tuple, Optional, Union | |
def _ffmpeg_stream(ffmpeg_command, buflen: int): | |
""" | |
Internal function to create the generator of data through ffmpeg | |
""" | |
bufsize = 2**24 | |
try: | |
with subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, bufsize=bufsize) as ffmpeg_process: | |
while True: | |
raw = ffmpeg_process.stdout.read(buflen) | |
if raw == b"": | |
break | |
yield raw | |
except FileNotFoundError as error: | |
raise ValueError("ffmpeg was not found but is required to stream audio files from filename") from error | |
def ffmpeg_microphone(
    sampling_rate: int,
    chunk_length_s: float,
    format_for_conversion: str = "f32le",
):
    """
    Helper function to read raw microphone data via an ffmpeg subprocess.

    Parameters
    ----------
    sampling_rate: target sampling rate in Hz.
    chunk_length_s: length in seconds of each yielded chunk of raw bytes.
    format_for_conversion: raw sample format, ``"s16le"`` or ``"f32le"``.

    Yields
    ------
    bytes: raw PCM audio, ``round(sampling_rate * chunk_length_s)`` samples per chunk.

    Raises
    ------
    ValueError: on an unsupported sample format or operating system, or
        (lazily, from the stream generator) when ffmpeg is not installed.
    """
    ar = f"{sampling_rate}"
    ac = "1"  # mono capture
    if format_for_conversion == "s16le":
        size_of_sample = 2
    elif format_for_conversion == "f32le":
        size_of_sample = 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")

    # Pick the platform-specific ffmpeg capture backend.
    system = platform.system()
    if system == "Linux":
        format_ = "alsa"
        input_ = "default"
    elif system == "Darwin":
        format_ = "avfoundation"
        input_ = ":0"
    elif system == "Windows":
        format_ = "dshow"
        # NOTE(review): dshow normally needs an explicit device like
        # "audio=<device name>"; "default" may not resolve — confirm on Windows.
        input_ = "default"
    else:
        # Previously an unrecognized OS fell through and `format_`/`input_`
        # were referenced while unbound, raising a confusing NameError.
        raise ValueError(f"Unsupported system `{system}`. Please use Linux, Darwin (macOS) or Windows")

    ffmpeg_command = [
        "ffmpeg",
        "-f",
        format_,
        "-i",
        input_,
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-fflags",
        "nobuffer",  # minimize latency for live capture
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]
    # One chunk = chunk_length_s worth of samples, in bytes.
    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    yield from _ffmpeg_stream(ffmpeg_command, chunk_len)
def chunk_bytes_iter(iterator, chunk_len: int, stride: Tuple[int, int], stream: bool = False):
    """
    Re-chunk raw bytes from `iterator` into windows of exactly `chunk_len` bytes.

    Consecutive windows overlap by `stride = (stride_left, stride_right)` bytes so a
    downstream consumer can discard edge effects. With `stream=True`, partial data is
    yielded as soon as it arrives (marked `"partial": True`) instead of waiting for a
    full window.
    """
    stride_left, stride_right = stride
    if stride_left + stride_right >= chunk_len:
        raise ValueError(
            f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}"
        )
    buffer = b""
    left = 0  # left stride of the next emitted chunk; the very first chunk has none
    for raw in iterator:
        buffer += raw
        if stream and len(buffer) < chunk_len:
            # Not enough for a full window yet: surface what we have as a partial result.
            yield {"raw": buffer[:chunk_len], "stride": (left, 0), "partial": True}
        else:
            while len(buffer) >= chunk_len:
                item = {"raw": buffer[:chunk_len], "stride": (left, stride_right)}
                if stream:
                    item["partial"] = False
                yield item
                left = stride_left
                # Keep the tail so adjacent windows overlap by stride_left + stride_right bytes.
                buffer = buffer[chunk_len - stride_left - stride_right :]
    # Flush whatever remains, unless it is nothing beyond the pure left-stride overlap.
    if len(buffer) > stride_left:
        item = {"raw": buffer, "stride": (left, 0)}
        if stream:
            item["partial"] = False
        yield item
def ffmpeg_microphone_live(
    sampling_rate: int,
    chunk_length_s: float,
    stream_chunk_s: Optional[int] = None,
    stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
    format_for_conversion: str = "f32le",
):
    """
    Capture live microphone audio and yield overlapping, numpy-decoded chunks.

    Reads raw bytes from `ffmpeg_microphone` in pieces of `stream_chunk_s`
    seconds (falling back to `chunk_length_s` when not given), then re-chunks
    them with `chunk_bytes_iter` into full windows of `chunk_length_s` seconds
    with left/right strides of `stride_length_s` (default: chunk_length_s / 6
    on each side). Each yielded item is a dict with keys `"raw"` (numpy array),
    `"stride"` (in samples), `"sampling_rate"`, and — because `stream=True` —
    `"partial"`.

    NOTE(review): `stream_chunk_s` is annotated `Optional[int]` but a float
    would work the same way here — confirm the intended type.
    """
    # Read from ffmpeg in smaller pieces when streaming partial results.
    if stream_chunk_s is not None:
        chunk_s = stream_chunk_s
    else:
        chunk_s = chunk_length_s
    microphone = ffmpeg_microphone(sampling_rate, chunk_s, format_for_conversion=format_for_conversion)
    if format_for_conversion == "s16le":
        dtype = np.int16
        size_of_sample = 2
    elif format_for_conversion == "f32le":
        dtype = np.float32
        size_of_sample = 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")
    if stride_length_s is None:
        # Default stride: one sixth of the chunk on each side.
        stride_length_s = chunk_length_s / 6
    # All lengths below are in BYTES (samples * size_of_sample).
    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    if isinstance(stride_length_s, (int, float)):
        stride_length_s = [stride_length_s, stride_length_s]
    stride_left = int(round(sampling_rate * stride_length_s[0])) * size_of_sample
    stride_right = int(round(sampling_rate * stride_length_s[1])) * size_of_sample
    # Track the wall-clock time each chunk *should* correspond to, so stale
    # chunks can be dropped when the consumer falls behind.
    audio_time = datetime.datetime.now()
    delta = datetime.timedelta(seconds=chunk_s)
    for item in chunk_bytes_iter(microphone, chunk_len, stride=(stride_left, stride_right), stream=True):
        # Decode raw bytes into a numpy array and convert strides to samples.
        item["raw"] = np.frombuffer(item["raw"], dtype=dtype)
        item["stride"] = (
            item["stride"][0] // size_of_sample,
            item["stride"][1] // size_of_sample,
        )
        item["sampling_rate"] = sampling_rate
        audio_time += delta
        if datetime.datetime.now() > audio_time + 10 * delta:
            # We are more than 10 chunks behind real time: skip this item.
            continue
        yield item
async def send_audio(websocket, mic_capture):
    """
    Forward raw microphone chunks from `mic_capture` to the websocket.

    Parameters
    ----------
    websocket: an open connection exposing an async ``send(bytes)`` method.
    mic_capture: an iterator of dicts with a ``"raw"`` numpy array, as
        produced by `ffmpeg_microphone_live`.
    """
    try:
        while True:
            # `next()` blocks on ffmpeg I/O, so run it off the event loop.
            # The `None` default prevents StopIteration from leaking out of
            # the worker thread into this coroutine (where PEP 479 would turn
            # it into a RuntimeError, or the broad except would print a
            # spurious error on normal generator exhaustion).
            audio_chunk = await asyncio.to_thread(next, mic_capture, None)
            if audio_chunk is None:
                break  # microphone stream exhausted: finish cleanly
            await websocket.send(audio_chunk["raw"].tobytes())
    except Exception as e:
        # Best-effort streaming: report and stop rather than crash the client.
        print(f"Error sending audio: {e}")
async def display_transcription(websocket, typing_speed=0.05):
    """
    Render incoming transcription messages in-place on stdout.

    Iterates the websocket's incoming text messages directly (``async for``)
    and rewrites the current console line for each one, re-printing the
    unchanged prefix instantly and "typing" the changed suffix one character
    every `typing_speed` seconds.

    NOTE(review): this coroutine is itself a websocket consumer; a second
    reader on the same socket would compete with it for messages.
    """
    displayed_text = ""  # what is currently visible on the console line
    previous_text = ""   # the previous transcription, for change detection
    async for transcription in websocket:
        # Index of the first character where the new text diverges from what
        # is already displayed (min_length if one is a prefix of the other).
        min_length = min(len(displayed_text), len(transcription))
        diff_index = next((i for i in range(min_length) if displayed_text[i] != transcription[i]), min_length)
        # Blank the whole current line with spaces, then return the cursor.
        sys.stdout.write('\r' + ' ' * len(displayed_text) + '\r')
        sys.stdout.flush()
        displayed_text = transcription
        if displayed_text != previous_text:
            # Re-print the common prefix at once...
            sys.stdout.write(displayed_text[:diff_index])
            # ...then type the remainder character by character.
            for char in displayed_text[diff_index:]:
                sys.stdout.write(char)
                sys.stdout.flush()
                await asyncio.sleep(typing_speed)
            previous_text = displayed_text
    print()
async def receive_transcription(websocket, display_func):
    """
    Receive transcription messages and hand each one to `display_func`.

    `display_func` must be an async callable taking (websocket, message).
    Errors (including the socket closing) are reported and end the loop.
    """
    try:
        while True:
            message = await websocket.recv()
            await display_func(websocket, message)
    except Exception as exc:
        print(f"Error receiving transcription: {exc}")
async def send_audio_and_receive_transcription(uri, sampling_rate, chunk_length_s, stream_chunk_s):
    """
    Stream microphone audio to `uri` and type out transcriptions as they arrive.

    Parameters
    ----------
    uri: websocket endpoint of the transcription server.
    sampling_rate: microphone sampling rate in Hz.
    chunk_length_s: length (seconds) of a full transcription chunk.
    stream_chunk_s: length (seconds) of the partial chunks sent while recording.
    """
    async with websockets.connect(uri) as websocket:
        # `ffmpeg_microphone_live` is a generator function: calling it only
        # builds the lazy generator, so wrapping the call in asyncio.to_thread
        # (as the original did) bought nothing.
        mic_capture = ffmpeg_microphone_live(
            sampling_rate=sampling_rate,
            chunk_length_s=chunk_length_s,
            stream_chunk_s=stream_chunk_s,
        )
        send_task = asyncio.create_task(send_audio(websocket, mic_capture))
        # `display_transcription` consumes the websocket itself (async for over
        # incoming messages), so it is the single receiver. The original code
        # additionally spawned `receive_transcription` on the same socket and
        # passed it the display *Task* (not a callable) as the callback, so the
        # two consumers raced and the callback call raised a (swallowed)
        # TypeError on every message.
        display_task = asyncio.create_task(display_transcription(websocket))
        await asyncio.gather(send_task, display_task)


if __name__ == "__main__":
    # Replace IP:PORT with the transcription server address.
    asyncio.run(send_audio_and_receive_transcription("ws://IP:PORT/ws/transcribe", 16000, 10.0, 1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment