Created
August 11, 2024 06:40
-
-
Save hirocarma/3625b9ac112ffe433f369d6f42cfb7c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
#--- description | |
# Using WebRTC Voice Activity Detector (VAD) model, | |
# extract the periods of time a person is speaking | |
# from the audio file and calculate the percentage. | |
#--- reference | |
# https://webrtc.org/ | |
# https://github.com/wiseman/py-webrtcvad/tree/master | |
import collections | |
import contextlib | |
import sys | |
import wave | |
import pathlib | |
import polars as pl | |
import matplotlib.pyplot as plt | |
import webrtcvad | |
def read_wave(path):
    """Read a wav file and return ``(pcm_data, sample_rate)``.

    The file must be mono, 16-bit PCM, at one of the sample rates that
    webrtcvad supports (8/16/32/48 kHz).

    Args:
        path: path to the wav file.

    Returns:
        Tuple of (raw PCM bytes, sample rate in Hz).

    Raises:
        ValueError: if the file is not mono 16-bit audio at a supported rate.
    """
    with contextlib.closing(wave.open(path, "rb")) as wf:
        # Validate with explicit exceptions rather than `assert`, which is
        # silently stripped when Python runs with the -O flag.
        num_channels = wf.getnchannels()
        if num_channels != 1:
            raise ValueError(f"expected mono audio, got {num_channels} channels")
        sample_width = wf.getsampwidth()
        if sample_width != 2:
            raise ValueError(f"expected 16-bit samples, got width {sample_width}")
        sample_rate = wf.getframerate()
        if sample_rate not in (8000, 16000, 32000, 48000):
            raise ValueError(f"unsupported sample rate: {sample_rate}")
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate
def write_wave(path, audio, sample_rate):
    """Persist raw PCM bytes to *path* as a mono 16-bit wav file.

    Args:
        path: destination file path.
        audio: raw little-endian 16-bit PCM payload.
        sample_rate: sampling rate in Hz written into the wav header.
    """
    with contextlib.closing(wave.open(path, "wb")) as out:
        out.setnchannels(1)   # mono, matching read_wave()'s contract
        out.setsampwidth(2)   # 2 bytes per sample (16-bit)
        out.setframerate(sample_rate)
        out.writeframes(audio)
class Frame(object):
    """A single fixed-duration chunk of PCM audio.

    Attributes:
        bytes: raw PCM payload of the frame.
        timestamp: start time of the frame in seconds.
        duration: frame length in seconds.
    """

    # One Frame is created per 30 ms of audio, so a long recording produces
    # many instances; __slots__ drops the per-instance __dict__ to save memory.
    __slots__ = ("bytes", "timestamp", "duration")

    def __init__(self, bytes, timestamp, duration):
        # NOTE: the parameter name `bytes` shadows the builtin; it is kept
        # unchanged for backward compatibility with existing callers.
        self.bytes = bytes
        self.timestamp = timestamp
        self.duration = duration

    def __repr__(self):
        return (
            f"Frame(len={len(self.bytes)}, timestamp={self.timestamp}, "
            f"duration={self.duration})"
        )
def frame_generator(frame_duration_ms, audio, sample_rate):
    """Yield successive ``Frame`` objects of *frame_duration_ms* from *audio*.

    Assumes 16-bit (2 bytes per sample) mono PCM, matching read_wave().
    Any trailing data shorter than one full frame is dropped.

    Args:
        frame_duration_ms: frame length in milliseconds.
        audio: raw PCM bytes.
        sample_rate: sampling rate in Hz.

    Yields:
        Frame objects with increasing timestamps.
    """
    # Bytes per frame: samples-per-frame * 2 bytes per 16-bit sample.
    frame_bytes = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    # Frame length in seconds (the /2.0 undoes the byte-per-sample factor).
    frame_sec = (float(frame_bytes) / sample_rate) / 2.0
    start = 0
    clock = 0.0
    while start + frame_bytes < len(audio):
        yield Frame(audio[start : start + frame_bytes], clock, frame_sec)
        clock += frame_sec
        start += frame_bytes
def vad_collector(
    sample_rate: int,
    frame_duration_ms: int,
    padding_duration_ms: int,
    vad: webrtcvad.Vad,
    frames: list[Frame],
    voice_trigger_on_thres: float = 0.9,
    voice_trigger_off_thres: float = 0.1,
) -> list[dict]:
    """Split *frames* into alternating voiced/unvoiced segments using *vad*.

    Implements a hysteresis state machine over a sliding window of
    ``padding_duration_ms``: speech turns ON when more than
    ``voice_trigger_on_thres`` of the window is voiced, and turns OFF when
    more than ``1 - voice_trigger_off_thres`` of the window is unvoiced.

    Args:
        sample_rate: audio sample rate in Hz, passed to ``vad.is_speech``.
        frame_duration_ms: duration of each frame in milliseconds.
        padding_duration_ms: length of the decision window in milliseconds.
        vad: a configured ``webrtcvad.Vad`` instance.
        frames: frames produced by ``frame_generator``.
        voice_trigger_on_thres: voiced fraction of the window needed to
            enter the "speaking" state.
        voice_trigger_off_thres: controls leaving the "speaking" state; the
            OFF condition fires when the unvoiced fraction exceeds
            ``1 - voice_trigger_off_thres``.

    Returns:
        List of dicts, one per segment, each with keys ``"vad"`` (1 = speech,
        0 = non-speech), ``"audio_size"`` (byte count) and ``"audio_data"``
        (raw PCM bytes).
    """
    # Number of frames that fit in the decision window.
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    frame_buffer = []       # (frame, is_speech) pairs since the last segment cut
    triggered = False       # True while inside a speech segment
    voiced_frames = []      # frames accumulated for the current speech segment
    vu_segments = []        # output: alternating non-speech / speech segments
    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        frame_buffer.append((frame, is_speech))
        if not triggered:
            # Count voiced frames in the trailing window.
            num_voiced = len(
                [f for f, speech in frame_buffer[-num_padding_frames:] if speech]
            )
            if num_voiced > voice_trigger_on_thres * num_padding_frames:
                triggered = True
                # Everything before the window becomes a non-speech segment.
                audio_data = b"".join(
                    [f.bytes for f, _ in frame_buffer[:-num_padding_frames]]
                )
                vu_segments.append(
                    {"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data}
                )
                # The window itself is the start of the speech segment.
                for f, _ in frame_buffer[-num_padding_frames:]:
                    voiced_frames.append(f)
                frame_buffer = []
        else:
            voiced_frames.append(frame)
            # Count unvoiced frames in the trailing window.
            num_unvoiced = len(
                [f for f, speech in frame_buffer[-num_padding_frames:] if not speech]
            )
            if num_unvoiced > (1 - voice_trigger_off_thres) * num_padding_frames:
                triggered = False
                # Close out the current speech segment.
                audio_data = b"".join([f.bytes for f in voiced_frames])
                vu_segments.append(
                    {"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data}
                )
                voiced_frames = []
                frame_buffer = []
    # Flush whatever is pending when the input ends.
    if triggered:
        audio_data = b"".join([f.bytes for f in voiced_frames])
        vu_segments.append(
            {"vad": 1, "audio_size": len(audio_data), "audio_data": audio_data}
        )
    else:
        audio_data = b"".join([f.bytes for f, _ in frame_buffer])
        vu_segments.append(
            {"vad": 0, "audio_size": len(audio_data), "audio_data": audio_data}
        )
    return vu_segments
def segments_output(vu_segments, sample_rate, title):
    """Write each segment to ``./segments/<title>/`` and summarise durations.

    Args:
        vu_segments: segment dicts from ``vad_collector``.
        sample_rate: sample rate in Hz (used for duration and wav headers).
        title: subdirectory name for the per-segment wav files.

    Returns:
        A polars DataFrame with one row per segment: filename, vad flag,
        segment duration, and running speech/total duration totals.
    """
    out_dir = pathlib.Path("./segments/" + title + "/")
    out_dir.mkdir(parents=True, exist_ok=True)
    rows = []
    speech_sec = 0
    elapsed_sec = 0
    for idx, seg in enumerate(vu_segments):
        wav_path = out_dir.joinpath(f"segment-{idx:03d}-vad{seg['vad']}.wav")
        write_wave(str(wav_path), seg["audio_data"], sample_rate)
        # 16-bit mono audio: 2 bytes per sample.
        seg_sec = seg["audio_size"] / 2.0 / sample_rate
        if seg["vad"] == 1:
            speech_sec += seg_sec
        elapsed_sec += seg_sec
        rows.append(
            {
                "filename": wav_path.name,
                "vad": seg["vad"],
                "duration_sec": seg_sec,
                "speech_duration": speech_sec,
                "total_duration": elapsed_sec,
            }
        )
    return pl.DataFrame(rows)
def plot_vad(df, title):
    """Plot the speech/non-speech timeline and save it as a PNG.

    Draws one bar per segment (height 1 = speech, 0 = silence) against the
    cumulative time axis, shows the figure, and writes
    ``<title>-webrtc_vad.png``.

    Args:
        df: DataFrame from ``segments_output``; the last row holds the
            final cumulative totals.
        title: plot title prefix and output-file prefix.
    """
    total_sec = df["total_duration"][-1]
    speech_sec = df["speech_duration"][-1]
    ratio = round(speech_sec / total_sec * 100, 1)
    # Japanese-capable fonts for the title text.
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK JP",
        "VL PGothic",
    ]
    plt.style.use('Solarize_Light2')
    fig = plt.figure(figsize=(20, 4), dpi=100, tight_layout=True)
    ax = fig.add_subplot(111, fc="w", xlabel="sec")
    stats = (
        f" ---- 発話時間: {round(speech_sec)}sec "
        f"(total: {round(total_sec)}sec)  発話率: {ratio}% ---- "
    )
    ax.set_title(title + " 発話区間検出 " + stats)
    ax.set_xlim(0, total_sec)
    plt.yticks([])  # bar heights are just 0/1 flags; no y scale needed
    ax.bar(df["total_duration"], df["vad"], width=1.0, linewidth=0)
    fig.savefig(title + "-webrtc_vad.png", facecolor=fig.get_facecolor())
    plt.show()
def main(argv):
    """CLI entry point.

    Expects ``argv`` = [prog, aggressiveness, wav_path, title]:
    VAD aggressiveness (0-3), the input wav file, and a title used for
    output naming. Exits with status 1 on bad usage.
    """
    if len(argv) != 4:
        sys.stderr.write(
            "Usage: script.py <aggressiveness> <path to wav file> <title>\n"
        )
        sys.exit(1)
    aggressiveness, wav_path, title = int(argv[1]), argv[2], argv[3]
    detector = webrtcvad.Vad(aggressiveness)
    audio, sample_rate = read_wave(wav_path)
    # 30 ms frames, 300 ms decision window; off threshold raised to 0.8 so
    # segments end more readily than the vad_collector default.
    frame_list = list(frame_generator(30, audio, sample_rate))
    segments = vad_collector(
        sample_rate, 30, 300, detector, frame_list, voice_trigger_off_thres=0.8
    )
    df = segments_output(segments, sample_rate, title)
    print(df)
    plot_vad(df, title)


if __name__ == '__main__':
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment