# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "loguru",
#     "pyannote-audio>=3.3.2",
#     "torch>=2.7.0",
#     "torchaudio>=2.7.0",
#     "tqdm",
# ]
# [[tool.uv.index]]
# url = "https://download.pytorch.org/whl/cpu"
# ///
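
# Self-contained uv script; an assumed invocation (the filename is illustrative):
#   uv run vad_to_rttm.py --input-files speaker_a.wav speaker_b.wav --output-dir ./rttm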

import argparse
import os
import warnings
from pathlib import Path
from typing import Callable

import torch
import torchaudio
from loguru import logger
from tqdm import tqdm

with warnings.catch_warnings():
    # Silence SyntaxWarnings emitted while importing pyannote.audio.
    warnings.filterwarnings("ignore", category=SyntaxWarning)
    from pyannote.audio import Pipeline

# Annotation is instantiated at runtime in main(), so it cannot live behind a
# typing.TYPE_CHECKING guard; import it directly.
from pyannote.core import Annotation


def apply_vad_pipeline(
    audio_file: str | os.PathLike, duration_seconds: int | None = None, hook: Callable | None = None
) -> Annotation:
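    """Run the pretrained pyannote voice-activity-detection pipeline on one audio file.

    Optionally trims the waveform to the first `duration_seconds` seconds and forwards
    `hook` to the pipeline call. Returns an Annotation whose speech regions are
    labelled "SPEECH". Note: the pretrained pipeline is gated on Hugging Face, so a
    cached authentication token may be needed for from_pretrained() to succeed.
    """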
    pipeline = Pipeline.from_pretrained("pyannote/voice-activity-detection")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)

    logger.debug(f"Processing {audio_file} on {device}")
    waveform, sample_rate = torchaudio.load(audio_file)
    if duration_seconds is not None:
        logger.debug(f"Trimming waveform to {duration_seconds} seconds")
        max_samples = sample_rate * duration_seconds
        waveform = waveform[:, :max_samples]
    input_ = {"waveform": waveform, "sample_rate": sample_rate}
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, message=".*MPEG_LAYER_III.*")
        output = pipeline(input_, hook=hook)
    return output


def process_input_file(audio_input_file: str | os.PathLike) -> Annotation:
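    """Run VAD on a single file and relabel its "SPEECH" regions with the file's stem."""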
    label = Path(audio_input_file).stem
    logger.info(f"Processing audio file: {audio_input_file} (label: {label})")
    logger.info("Applying VAD pipeline...")
    annotation = apply_vad_pipeline(audio_input_file)
    annotation.rename_labels({"SPEECH": label}, copy=False)
    annotation.uri = label
    return annotation


def main(input_files: list[str], output_dir: str | os.PathLike):
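    """Run VAD on every input file, write one RTTM per file, and merge all results into mixed.rttm."""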
logger.info(f"Starting VAD processing for {len(input_files)} input files.") |
|
logger.info(f"Output directory: {output_dir}") |
|
output_dir = Path(output_dir) |
|
Path(output_dir).mkdir(parents=True, exist_ok=True) |
|
mixed_annotation = Annotation(uri="mixed") |
|
for audio_file in tqdm(input_files): |
|
annotation = process_input_file(audio_file, output_dir=output_dir) |
|
Path(output_dir / f"{annotation.uri}.rttm").write_text(annotation.to_rttm()) |
|
mixed_annotation.update(annotation) |
|
Path(output_dir / "mixed.rttm").write_text(mixed_annotation.to_rttm()) |
|
|
|
|
|


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
Generate ground-truth diarization data:
1. Apply VAD to each of the input audio files.
2. Mix the results into a single RTTM file.
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--input-files",
        "-i",
        help="Paths to the input audio files. For more than one file, use space-separated paths.",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument("--output-dir", "-o", type=str, help="Directory to save output RTTM files.", required=True)

    args = parser.parse_args()
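
    # Replace the default loguru sink with one that writes via tqdm.write,
    # so log lines do not break the progress bar.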
    logger.remove()
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
    main(args.input_files, args.output_dir)