Example VAD chunking for Triton
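This gist splits long audio into short, silence-aligned chunks with webrtcvad and sends them in batches to a transcription HTTP endpoint, then stitches the per-chunk transcripts back together in order. It assumes a companion server (in this setup, a Triton-backed service) exposing POST {base_url}/transcribe/batch that accepts multipart "files" entries plus a JSON "metadata" field and returns a "responses" list with "text", "metadata", and "audio_duration" per chunk; adapt the endpoint and response handling to your own server.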
import asyncio
import aiohttp
import json
import wave
import webrtcvad
import contextlib
import time
from typing import List, Dict, Union, Optional
from pathlib import Path
import tempfile
import os
import logging
from pydub import AudioSegment

BATCH_SIZE = 64


class VADClient:
    def __init__(self, base_url: str = "http://0.0.0.0:8003"):
        self.base_url = base_url
        self.logger = logging.getLogger(__name__)
        self.vad = webrtcvad.Vad(3)  # aggressiveness 0-3; 3 filters non-speech most aggressively
        self.max_duration_ms = 29000  # hard upper bound on chunk length, in milliseconds
        self.frame_ms = 30  # webrtcvad accepts 10, 20, or 30 ms frames
        self.min_silence_duration = 0.5  # seconds of silence that qualify as a split point

    def _find_speech_segments(self, audio_path: str) -> List[Dict]:
        with contextlib.closing(wave.open(audio_path, "rb")) as wf:
            pcm_data = wf.readframes(wf.getnframes())
            sample_rate = wf.getframerate()
            duration = wf.getnframes() / float(sample_rate)

        # 16-bit mono PCM: 2 bytes per sample.
        frame_bytes = int(sample_rate * (self.frame_ms / 1000.0) * 2)

        # Pass 1: collect (start, end, length) silence regions long enough to split on.
        silence_regions = []
        silent_start = None
        consecutive_silent = 0
        for i in range(0, len(pcm_data), frame_bytes):
            frame = pcm_data[i:i + frame_bytes]
            if len(frame) != frame_bytes:
                continue  # webrtcvad requires full frames; skip the trailing partial one
            is_speech = self.vad.is_speech(frame, sample_rate)
            frame_time = i / (2 * sample_rate)
            if not is_speech:
                consecutive_silent += self.frame_ms / 1000
                if silent_start is None:
                    silent_start = frame_time
            elif silent_start is not None:
                if consecutive_silent >= self.min_silence_duration:
                    silence_regions.append((silent_start, frame_time, consecutive_silent))
                # Reset once speech resumes, whether or not the silence was long enough.
                silent_start = None
                consecutive_silent = 0

        # Pass 2: cut chunks, preferring to split inside a silence region.
        chunks = []
        current_start = 0
        overlap_duration = 0.2  # seconds of audio repeated at each chunk boundary
        while current_start < duration:
            max_end = min(current_start + 24.5, duration)  # leave margin under max_duration_ms
            candidate_silences = [
                (start, end, length) for start, end, length in silence_regions
                if current_start < start < max_end
            ]
            # Weighted score: prefer silences near the duration limit (note the ms -> s
            # conversion) and, secondarily, longer silences.
            candidate_silences.sort(key=lambda x: (
                abs(current_start + self.max_duration_ms / 1000.0 - x[0]) * 0.7
                - x[2] * 0.3
            ))
            if candidate_silences:
                split_point = candidate_silences[0][0]
            else:
                split_point = max_end
            chunk_start_bytes = int(current_start * sample_rate * 2)
            chunk_end_bytes = int((split_point + overlap_duration) * sample_rate * 2)
            chunk_end_bytes = min(chunk_end_bytes, len(pcm_data))
            chunks.append({
                "start": current_start,
                "end": split_point + overlap_duration,
                "audio": pcm_data[chunk_start_bytes:chunk_end_bytes]
            })
            current_start = split_point
        return chunks

    def _prepare_audio_batch(self, segments: List[Dict]) -> List[Dict]:
        chunks = []
        for i, segment in enumerate(segments):
            # Write each chunk to its own temporary 16 kHz mono 16-bit WAV file.
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            temp_file.close()  # keep only the path; wave reopens the file below
            with wave.open(temp_file.name, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(16000)
                wf.writeframes(segment["audio"])
            chunks.append({
                "filepath": temp_file.name,
                "metadata": {
                    "chunk_index": i,
                    "chunk_start_time": segment["start"],
                    "chunk_end_time": segment["end"]
                }
            })
        return chunks

    async def transcribe(self,
                         audio_file: Union[str, Path],
                         language: str = "en",
                         metadata: Optional[Dict] = None) -> Dict:
        chunks = []
        try:
            # Normalize the input to 16 kHz mono WAV, then split it on silences.
            vad_start = time.time()
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                audio = AudioSegment.from_file(str(audio_file))
                audio = audio.set_frame_rate(16000).set_channels(1)
                audio.export(tmp.name, format="wav")
                segments = self._find_speech_segments(tmp.name)
                chunks = self._prepare_audio_batch(segments)
            os.unlink(tmp.name)  # the normalized copy is no longer needed
            vad_time = time.time() - vad_start

            self.logger.info(f"Processing {len(chunks)} chunks")
            all_responses = []
            transcribe_start = time.time()
            # Send the chunks to the server BATCH_SIZE at a time.
            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                data = aiohttp.FormData()
                metadata_list = []
                for chunk in batch:
                    # Read the bytes up front so no file handles are left open.
                    data.add_field("files",
                                   Path(chunk["filepath"]).read_bytes(),
                                   filename=Path(chunk["filepath"]).name)
                    chunk_metadata = chunk["metadata"].copy()
                    if metadata:
                        chunk_metadata.update(metadata)
                    metadata_list.append(chunk_metadata)
                data.add_field("metadata", json.dumps(metadata_list))

                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        f"{self.base_url}/transcribe/batch",
                        data=data,
                        params={"lang": language}
                    ) as response:
                        if response.status != 200:
                            raise Exception(f"Batch transcription failed: {await response.text()}")
                        result = await response.json()
                        all_responses.extend(result["responses"])
            transcribe_time = time.time() - transcribe_start

            # Reassemble the transcript in original chunk order.
            all_responses.sort(key=lambda x: x["metadata"]["chunk_index"])
            return {
                "text": " ".join(r["text"].strip() for r in all_responses),
                "vad_time": vad_time,
                "transcribe_time": transcribe_time,
                "total_time": vad_time + transcribe_time,
                "audio_duration": sum(r["audio_duration"] for r in all_responses)
            }
        finally:
            # Always remove the per-chunk temporary WAV files.
            for chunk in chunks:
                if os.path.exists(chunk["filepath"]):
                    os.unlink(chunk["filepath"])


async def main():
    client = VADClient()
    try:
        result = await client.transcribe("audio.wav")
        print(f"Text: {result['text']}")
        print(f"VAD time: {result['vad_time']:.2f}s")
        print(f"Transcribe time: {result['transcribe_time']:.2f}s")
        print(f"Total time: {result['total_time']:.2f}s")
        print(f"Audio duration: {result['audio_duration']:.2f}s")
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
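Note: the script depends on aiohttp, webrtcvad, and pydub (pydub needs ffmpeg on the PATH to decode non-WAV inputs). The endpoint path, the lang query parameter, and the shape of the JSON response are assumptions about the companion transcription server rather than a standard Triton API, so adjust them to match your deployment.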