Created January 3, 2025 09:18
Example VAD chunking for Triton
import asyncio
import aiohttp
import json
import wave
import webrtcvad
import contextlib
import time
from typing import List, Dict, Union, Optional
from pathlib import Path
import tempfile
import os
import logging
from pydub import AudioSegment

BATCH_SIZE = 64
class VADClient:
    def __init__(self, base_url: str = "http://0.0.0.0:8003"):
        self.base_url = base_url
        self.logger = logging.getLogger(__name__)
        self.vad = webrtcvad.Vad(3)  # aggressiveness 3 = most aggressive speech/non-speech filtering
        self.max_duration_ms = 29000  # target maximum chunk length, in milliseconds
        self.frame_ms = 30  # webrtcvad accepts 10, 20 or 30 ms frames
        self.min_silence_duration = 0.5  # seconds of silence that qualify as a split point
    def _find_speech_segments(self, audio_path: str) -> List[Dict]:
        with contextlib.closing(wave.open(audio_path, "rb")) as wf:
            pcm_data = wf.readframes(wf.getnframes())
            sample_rate = wf.getframerate()
            duration = wf.getnframes() / float(sample_rate)

            # 16-bit mono PCM: 2 bytes per sample.
            frame_bytes = int(sample_rate * (self.frame_ms / 1000.0) * 2)

            # Collect silence regions (start, end, length) of at least min_silence_duration.
            silence_regions = []
            silent_start = None
            consecutive_silent = 0
            for i in range(0, len(pcm_data), frame_bytes):
                frame = pcm_data[i:i + frame_bytes]
                if len(frame) == frame_bytes:
                    is_speech = self.vad.is_speech(frame, sample_rate)
                    frame_time = i / (2 * sample_rate)
                    if not is_speech:
                        consecutive_silent += self.frame_ms / 1000
                        if silent_start is None:
                            silent_start = frame_time
                    elif silent_start is not None:
                        if consecutive_silent >= self.min_silence_duration:
                            silence_regions.append((silent_start, frame_time, consecutive_silent))
                        silent_start = None
                        consecutive_silent = 0

            # Cut the audio into chunks, preferring to split inside a silence region.
            chunks = []
            current_start = 0
            overlap_duration = 0.2
            while current_start < duration:
                max_end = min(current_start + 24.5, duration)  # leave margin below the 29 s cap
                candidate_silences = [
                    (start, end, length) for start, end, length in silence_regions
                    if start > current_start and start < max_end
                ]
                # Rank silences: closeness to the target chunk length (converted from ms to
                # seconds so the units match current_start) outweighs the silence length.
                candidate_silences.sort(key=lambda x: (
                    abs(current_start + self.max_duration_ms / 1000.0 - x[0]) * 0.7 +
                    -x[2] * 0.3
                ))
                if candidate_silences:
                    split_point = candidate_silences[0][0]
                else:
                    split_point = max_end
                chunk_start_bytes = int(current_start * sample_rate * 2)
                chunk_end_bytes = int((split_point + overlap_duration) * sample_rate * 2)
                chunk_end_bytes = min(chunk_end_bytes, len(pcm_data))
                chunks.append({
                    "start": current_start,
                    "end": split_point + overlap_duration,
                    "audio": pcm_data[chunk_start_bytes:chunk_end_bytes]
                })
                current_start = split_point
            return chunks
    def _prepare_audio_batch(self, segments: List[Dict]) -> List[Dict]:
        chunks = []
        for i, segment in enumerate(segments):
            # Write each chunk to its own temporary WAV file; the files are
            # deleted in transcribe() once the batch request has been sent.
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            temp_file.close()
            with wave.open(temp_file.name, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(16000)
                wf.writeframes(segment["audio"])
            chunks.append({
                "filepath": temp_file.name,
                "metadata": {
                    "chunk_index": i,
                    "chunk_start_time": segment["start"],
                    "chunk_end_time": segment["end"]
                }
            })
        return chunks
    async def transcribe(self,
                         audio_file: Union[str, Path],
                         language: str = "en",
                         metadata: Optional[Dict] = None) -> Dict:
        chunks = []
        resampled_path = None
        try:
            # Resample to 16 kHz mono, then split on detected silences.
            vad_start = time.time()
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                resampled_path = tmp.name
                audio = AudioSegment.from_file(str(audio_file))
                audio = audio.set_frame_rate(16000).set_channels(1)
                audio.export(tmp.name, format="wav")
                segments = self._find_speech_segments(tmp.name)
                chunks = self._prepare_audio_batch(segments)
            vad_time = time.time() - vad_start
            self.logger.info(f"Processing {len(chunks)} chunks")

            # Send the chunks to the server in batches of BATCH_SIZE.
            all_responses = []
            transcribe_start = time.time()
            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                data = aiohttp.FormData()
                metadata_list = []
                for chunk in batch:
                    # Read each chunk into memory so no file handles are left open.
                    data.add_field("files",
                                   Path(chunk["filepath"]).read_bytes(),
                                   filename=Path(chunk["filepath"]).name)
                    chunk_metadata = chunk["metadata"].copy()
                    if metadata:
                        chunk_metadata.update(metadata)
                    metadata_list.append(chunk_metadata)
                data.add_field("metadata", json.dumps(metadata_list))
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        f"{self.base_url}/transcribe/batch",
                        data=data,
                        params={"lang": language}
                    ) as response:
                        if response.status != 200:
                            raise Exception(f"Batch transcription failed: {await response.text()}")
                        result = await response.json()
                        all_responses.extend(result["responses"])
            transcribe_time = time.time() - transcribe_start

            # Reassemble the transcript in chunk order.
            all_responses.sort(key=lambda x: x["metadata"]["chunk_index"])
            return {
                "text": " ".join(r["text"].strip() for r in all_responses),
                "vad_time": vad_time,
                "transcribe_time": transcribe_time,
                "total_time": vad_time + transcribe_time,
                "audio_duration": sum(r["audio_duration"] for r in all_responses)
            }
        finally:
            # Clean up the resampled file and all per-chunk temporary files.
            if resampled_path and os.path.exists(resampled_path):
                os.unlink(resampled_path)
            for chunk in chunks:
                if os.path.exists(chunk["filepath"]):
                    os.unlink(chunk["filepath"])
async def main():
    client = VADClient()
    try:
        result = await client.transcribe("audio.wav")
        print(f"Text: {result['text']}")
        print(f"VAD time: {result['vad_time']:.2f}s")
        print(f"Transcribe time: {result['transcribe_time']:.2f}s")
        print(f"Total time: {result['total_time']:.2f}s")
        print(f"Audio duration: {result['audio_duration']:.2f}s")
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
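
For reference, here is a minimal usage sketch with a custom server address and request-level metadata, plus the response shape the client assumes. The URL, the "source" metadata key, and the commented schema are illustrative assumptions inferred from the code above, not a documented server API.

# Usage sketch (assumptions: the server address and metadata values below are
# illustrative, and the response schema is inferred from how transcribe()
# parses the reply, not from server documentation).
async def demo():
    client = VADClient(base_url="http://localhost:8003")
    result = await client.transcribe(
        "meeting.wav",                # any format pydub/ffmpeg can decode
        language="en",
        metadata={"source": "demo"},  # merged into every chunk's metadata
    )
    print(result["text"])

# Each element of the server's "responses" list is expected to look roughly like:
# {
#     "text": "transcribed chunk text",
#     "audio_duration": 24.7,
#     "metadata": {"chunk_index": 0, "chunk_start_time": 0.0, "chunk_end_time": 24.7}
# }
#
# asyncio.run(demo())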