@hirocarma
Created August 11, 2024 06:40
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#--- description
# Using the WebRTC Voice Activity Detector (VAD), extract the
# periods during which a person is speaking from an audio file
# and calculate the speech ratio (percentage of time spent speaking).
#--- reference
# https://webrtc.org/
# https://github.com/wiseman/py-webrtcvad
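#--- dependencies
# Third-party packages assumed to be installed from PyPI
# (names inferred from the imports below):
#   pip install webrtcvad polars matplotlib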
import contextlib
import pathlib
import sys
import wave

import matplotlib.pyplot as plt
import polars as pl
import webrtcvad


def read_wave(path):
    with contextlib.closing(wave.open(path, "rb")) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate
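
# Note: read_wave() only accepts 16-bit mono PCM WAV at 8/16/32/48 kHz.
# Other recordings can be converted beforehand, for example with ffmpeg
# (not used by this script, shown only as a suggestion):
#   ffmpeg -i input.wav -ac 1 -ar 16000 -c:a pcm_s16le input-16k.wav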


def write_wave(path, audio, sample_rate):
    with contextlib.closing(wave.open(path, "wb")) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(audio)


class Frame(object):
    # A single frame of PCM audio together with its start time (timestamp)
    # and duration, both in seconds.
    def __init__(self, bytes, timestamp, duration):
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    # 2 bytes per sample (16-bit PCM), so each frame is
    # sample_rate * (duration in seconds) * 2 bytes long.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset : offset + n], timestamp, duration)
        timestamp += duration
        offset += n
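
# For example, at 16 kHz with the 30 ms frames used in main() below,
# n = 16000 * 0.030 * 2 = 960 bytes per frame and duration = 0.030 s.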


def vad_collector(
    sample_rate: int,
    frame_duration_ms: int,
    padding_duration_ms: int,
    vad: webrtcvad.Vad,
    frames: list[Frame],
    voice_trigger_on_thres: float = 0.9,
    voice_trigger_off_thres: float = 0.1,
) -> list[dict]:
    # Split the frames into alternating unvoiced (vad=0) and voiced (vad=1)
    # segments, using a sliding window of num_padding_frames frames to decide
    # when speech starts and ends.
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    frame_buffer = []
    triggered = False  # True while inside a voiced segment
    voiced_frames = []
    vu_segments = []

    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        frame_buffer.append((frame, is_speech))

        if not triggered:
            # Trigger when the share of voiced frames in the window
            # exceeds voice_trigger_on_thres.
            num_voiced = len(
                [f for f, speech in frame_buffer[-num_padding_frames:] if speech]
            )
            if num_voiced > voice_trigger_on_thres * num_padding_frames:
                triggered = True
                # Everything before the window becomes an unvoiced segment.
                audio_data = b"".join(
                    [f.bytes for f, _ in frame_buffer[:-num_padding_frames]]
                )
                vu_segments.append(
                    {"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data}
                )
                # The window itself starts the voiced segment.
                for f, _ in frame_buffer[-num_padding_frames:]:
                    voiced_frames.append(f)
                frame_buffer = []
        else:
            voiced_frames.append(frame)
            # Untrigger when the share of unvoiced frames in the window
            # exceeds (1 - voice_trigger_off_thres).
            num_unvoiced = len(
                [f for f, speech in frame_buffer[-num_padding_frames:] if not speech]
            )
            if num_unvoiced > (1 - voice_trigger_off_thres) * num_padding_frames:
                triggered = False
                audio_data = b"".join([f.bytes for f in voiced_frames])
                vu_segments.append(
                    {"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data}
                )
                voiced_frames = []
                frame_buffer = []

    # Flush whatever remains at the end of the file.
    if triggered:
        audio_data = b"".join([f.bytes for f in voiced_frames])
        vu_segments.append(
            {"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data}
        )
    else:
        audio_data = b"".join([f.bytes for f, _ in frame_buffer])
        vu_segments.append(
            {"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data}
        )
    return vu_segments
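
# With the parameters passed from main() below (30 ms frames, 300 ms padding,
# voice_trigger_on_thres=0.9, voice_trigger_off_thres=0.8), the sliding window
# is 10 frames: a voiced segment starts once more than 9 of the last 10 frames
# are speech, and ends once more than (1 - 0.8) * 10 = 2 of them are not.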


def segments_output(vu_segments, sample_rate, title):
    segments_dir = pathlib.Path("./segments/" + title + "/")
    segments_dir.mkdir(parents=True, exist_ok=True)
    for_df = []
    speech_duration = 0
    total_duration = 0
    for i, segment in enumerate(vu_segments):
        path = segments_dir.joinpath(f"segment-{i:03d}-vad{segment['vad']}.wav")
        write_wave(str(path), segment["audio_data"], sample_rate)
        # 2 bytes per sample (16-bit PCM).
        duration_sec = segment["audio_size"] / 2.0 / sample_rate
        if segment["vad"] == 1:
            speech_duration += duration_sec
        total_duration += duration_sec
        for_df.append(
            {
                "filename": path.name,
                "vad": segment["vad"],
                "duration_sec": duration_sec,
                "speech_duration": speech_duration,
                "total_duration": total_duration,
            }
        )
    df = pl.DataFrame(for_df)
    return df
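
# The segment WAV files are written to ./segments/<title>/segment-NNN-vad{0,1}.wav.
# In the returned DataFrame, speech_duration and total_duration are running
# (cumulative) totals in seconds; plot_vad() uses the last row for the overall
# statistics and the total_duration column as the time axis.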


def plot_vad(df, title):
    total_duration = df["total_duration"][-1]
    speech_duration = df["speech_duration"][-1]
    speech_ratio = round(speech_duration / total_duration * 100, 1)
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK JP",
        "VL PGothic",
    ]
    plt.style.use("Solarize_Light2")
    fig = plt.figure(figsize=(20, 4), dpi=100, tight_layout=True)
    ax = fig.add_subplot(111, fc="w", xlabel="sec")
    # The title text is Japanese: 発話時間 = speech duration,
    # 発話率 = speech ratio, 発話区間検出 = speech-segment detection.
    stats = (
        " ---- 発話時間: "
        + str(round(speech_duration))
        + "sec "
        + "(total: "
        + str(round(total_duration))
        + "sec) "
        + " 発話率: "
        + str(speech_ratio)
        + "% ---- "
    )
    ax.set_title(title + " 発話区間検出 " + stats)
    ax.set_xlim(0, total_duration)
    plt.yticks([])
    ax.bar(df["total_duration"], df["vad"], width=1.0, linewidth=0)
    fig.savefig(title + "-webrtc_vad.png", facecolor=fig.get_facecolor())
    plt.show()


def main(argv):
    if len(argv) != 4:
        sys.stderr.write(
            "Usage: script.py <aggressiveness> <path to wav file> <title>\n"
        )
        sys.exit(1)
    # Aggressiveness is an integer from 0 to 3; 3 filters non-speech most
    # aggressively.
    vad = webrtcvad.Vad(int(argv[1]))
    audio, sample_rate = read_wave(argv[2])
    title = argv[3]
    frames = frame_generator(30, audio, sample_rate)
    frames = list(frames)
    vu_segments = vad_collector(
        sample_rate, 30, 300, vad, frames, voice_trigger_off_thres=0.8
    )
    df = segments_output(vu_segments, sample_rate, title)
    print(df)
    plot_vad(df, title)


if __name__ == "__main__":
    main(sys.argv)
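
# Example invocation (a sketch; "meeting.wav" and the title "meeting" are
# placeholder values):
#   python script.py 3 meeting.wav meeting
# This prints the per-segment DataFrame, writes the split WAV files under
# ./segments/meeting/, and saves the plot as meeting-webrtc_vad.png.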