@aarondewindt
Created September 27, 2025 11:46
Aaron whisper setup
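
Two files follow: a Python script (mounted in the compose file as diarize_existing_captions.py) that adds speaker labels from pyannote diarization to existing SRT/VTT captions, and a docker-compose file that runs Whisper transcription plus the diarization step in containers on an AMD GPU via ROCm (hence the /dev/kfd and /dev/dri device mappings).
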
#!/usr/bin/env python3
import os, math, argparse
from typing import List, Tuple, Optional, Dict
import torch
from pyannote.audio import Pipeline
import srt
import datetime as dt
import webvtt
# ---------- time helpers ----------
def to_srt_ts(seconds: float) -> dt.timedelta:
    if seconds is None: seconds = 0.0
    ms = int(round((seconds - math.floor(seconds)) * 1000))
    return dt.timedelta(seconds=int(seconds), milliseconds=ms)
def from_vtt_timestamp(ts: str) -> float:
    # "HH:MM:SS.mmm" (WebVTT)
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)
def to_vtt_timestamp(seconds: float) -> str:
    if seconds < 0: seconds = 0.0
    h = int(seconds // 3600); m = int((seconds % 3600) // 60)
    s = seconds - (h * 3600 + m * 60)
    return f"{h:02d}:{m:02d}:{s:06.3f}"
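# Example (illustrative values): to_vtt_timestamp(3723.5) returns "01:02:03.500",
# and to_srt_ts(3723.5) returns timedelta(seconds=3723, milliseconds=500), which
# the srt library renders as "01:02:03,500".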
# ---------- caption I/O ----------
def read_srt(path: str) -> List[Dict]:
    with open(path, "r", encoding="utf-8") as f:
        subs = list(srt.parse(f.read()))
    items = []
    for sub in subs:
        start = sub.start.total_seconds()
        end = sub.end.total_seconds()
        items.append({"index": sub.index, "start": start, "end": end, "text": sub.content})
    return items
def write_srt(path: str, items: List[Dict]) -> None:
    subs = []
    for i, it in enumerate(items, start=1):
        subs.append(
            srt.Subtitle(
                index=i,
                start=to_srt_ts(it["start"]),
                end=to_srt_ts(it["end"]),
                content=it["text"],
            )
        )
    with open(path, "w", encoding="utf-8") as f:
        f.write(srt.compose(subs))
def read_vtt(path: str) -> List[Dict]:
    vtt = webvtt.read(path)
    items = []
    for i, cue in enumerate(vtt, start=1):
        items.append({
            "index": i,
            "start": from_vtt_timestamp(cue.start),
            "end": from_vtt_timestamp(cue.end),
            "text": cue.text
        })
    return items
def write_vtt(path: str, items: List[Dict]) -> None:
    vtt = webvtt.WebVTT()
    for it in items:
        cue = webvtt.Caption(
            start=to_vtt_timestamp(it["start"]),
            end=to_vtt_timestamp(it["end"]),
            text=it["text"]
        )
        vtt.captions.append(cue)
    vtt.save(path)
# ---------- diarization overlap ----------
def overlap(a: float, b: float, c: float, d: float) -> float:
    # length of intersection of [a, b] and [c, d]
    return max(0.0, min(b, d) - max(a, c))
def label_by_max_overlap(seg_start: float, seg_end: float, diar_segments: List[Tuple[float, float, str]]) -> Optional[str]:
    best, best_len = None, 0.0
    for (s, e, lab) in diar_segments:
        ol = overlap(seg_start, seg_end, s, e)
        if ol > best_len:
            best_len = ol
            best = lab
    return best
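# Example (hypothetical numbers): a caption spanning 10.0-14.0 s overlaps a turn
# (8.0, 11.0, "SPEAKER_00") by 1.0 s and a turn (11.0, 20.0, "SPEAKER_01") by
# 3.0 s, so it is labelled SPEAKER_01; a caption that overlaps no turn at all
# gets None and is left untagged downstream.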
# ---------- main ----------
def main():
    ap = argparse.ArgumentParser(description="Add speaker labels to existing SRT/VTT using pyannote diarization.")
    ap.add_argument("--audio_dir", required=True, help="Folder with .wav audio (names must match captions).")
    ap.add_argument("--captions_dir", required=True, help="Folder with .srt or .vtt files.")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--hf_token", required=True, help="Hugging Face token to download the pipeline the first time.")
    ap.add_argument("--num_speakers", type=int, default=None)
    ap.add_argument("--min_speakers", type=int, default=None)
    ap.add_argument("--max_speakers", type=int, default=None)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Instantiate diarization pipeline (GPU optional)
    pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=args.hf_token)
    if torch.cuda.is_available():
        pipe.to(torch.device("cuda"))

    diar_kwargs = {}
    if args.num_speakers is not None: diar_kwargs["num_speakers"] = args.num_speakers
    if args.min_speakers is not None: diar_kwargs["min_speakers"] = args.min_speakers
    if args.max_speakers is not None: diar_kwargs["max_speakers"] = args.max_speakers

    # Build index of captions by basename
    caption_map: Dict[str, Tuple[str, str]] = {}  # base -> (path, kind)
    for fn in os.listdir(args.captions_dir):
        base, ext = os.path.splitext(fn)
        ext = ext.lower()
        if ext in (".srt", ".vtt"):
            caption_map[base] = (os.path.join(args.captions_dir, fn), ext[1:])

    # Process each audio file whose basename has a caption
    for fn in sorted(os.listdir(args.audio_dir)):
        if not fn.lower().endswith(".wav"):
            continue
        base = os.path.splitext(fn)[0]
        if base not in caption_map:
            print(f"[skip] No caption found for {fn}")
            continue

        audio_path = os.path.join(args.audio_dir, fn)
        cap_path, kind = caption_map[base]
        print(f"[diarize] {fn} | captions: {os.path.basename(cap_path)} ({kind})")

        # Read captions
        items = read_srt(cap_path) if kind == "srt" else read_vtt(cap_path)

        # Run diarization (pyannote handles resampling + mono downmix automatically)
        diar = pipe(audio_path, **diar_kwargs)

        # Collect diarization segments
        diar_segments: List[Tuple[float, float, str]] = []
        for turn, _, label in diar.itertracks(yield_label=True):
            diar_segments.append((turn.start, turn.end, label))
        diar_segments.sort(key=lambda x: x[0])

        # Tag each caption line with max-overlap speaker
        tagged = []
        for it in items:
            speaker = label_by_max_overlap(it["start"], it["end"], diar_segments)
            prefix = f"[{speaker}] " if speaker else ""
            tagged.append({**it, "text": prefix + it["text"]})

        # Write outputs
        out_srt = os.path.join(args.out_dir, f"{base}.diarized.srt")
        out_vtt = os.path.join(args.out_dir, f"{base}.diarized.vtt")
        if kind == "srt":
            write_srt(out_srt, tagged)
            write_vtt(out_vtt, tagged)
        else:
            write_vtt(out_vtt, tagged)
            write_srt(out_srt, tagged)

        # Also dump RTTM (standard diarization format)
        with open(os.path.join(args.out_dir, f"{base}.rttm"), "w", encoding="utf-8") as rttm:
            diar.write_rttm(rttm)

        print(f"[done] {base}: wrote *.diarized.srt, *.diarized.vtt, *.rttm")


if __name__ == "__main__":
    main()
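
The script can also be run outside Docker. A minimal sketch, assuming the file is saved as diarize_existing_captions.py (the name used in the compose file below) and that torch, pyannote.audio, srt, and webvtt-py are installed:

python diarize_existing_captions.py \
  --audio_dir ./audio \
  --captions_dir ./captions \
  --out_dir ./out \
  --hf_token "$HF_TOKEN" \
  --min_speakers 2 --max_speakers 6

The docker-compose file below wires the same script into a two-service pipeline: whisper transcribes every .wav in ./audio with the large-v3 model, and diarize prefetches the pyannote models into a local cache and then tags the captions.
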
version: "3.9"

services:
  whisper:
    build: .
    container_name: whisper
    # Assuming the image ships the openai-whisper CLI, --output_format takes a
    # single value; "all" writes txt, srt, vtt, tsv, and json.
    command: >
      bash -lc '
      mkdir -p /out &&
      whisper /audio/*.wav
      --model large-v3
      --device cuda
      --task transcribe
      --output_dir /out
      --output_format all'
    volumes:
      - ./audio:/audio:ro
      - ./out:/out
      - ./cache:/cache
    environment:
      - XDG_CACHE_HOME=/cache
    devices:
      - /dev/kfd
      - /dev/dri
    security_opt:
      - seccomp=unconfined
    ipc: host
    user: "1000:1000"
  diarize:
    build: .
    container_name: diarize
    # Prefetch models on first run (optional), then run diarization
    command: >
      bash -lc '
      mkdir -p /out /cache/hfcache &&
      python -c "from huggingface_hub import snapshot_download;
      import os; t=os.environ.get(\"HF_TOKEN\");
      [snapshot_download(repo_id=r, token=t, cache_dir=\"/cache/hfcache\")
      for r in [\"pyannote/segmentation-3.0\",\"pyannote/speaker-diarization-3.1\"]]" &&
      export HF_HOME=/cache/hfcache &&
      python /work/diarize_existing_captions.py
      --audio_dir /audio
      --captions_dir /captions
      --out_dir /out
      --hf_token ${HF_TOKEN}
      --min_speakers 2 --max_speakers 6'
    volumes:
      - ./audio:/audio:ro
      - ./captions:/captions:ro   # your existing SRT/VTT go here
      - ./out:/out
      - ./cache:/cache
      - ./diarize_existing_captions.py:/work/diarize_existing_captions.py:ro
    environment:
      - HF_TOKEN=${HF_TOKEN}
      - HF_HOME=/cache/hfcache
      # set to 1 after first successful download to force *offline* loads:
      # - HF_HUB_OFFLINE=1
    devices:
      - /dev/kfd
      - /dev/dri
    security_opt:
      - seccomp=unconfined
    ipc: host
    user: "1000:1000"
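
A minimal way to run the stack, assuming this compose file sits next to audio/, captions/, out/, and cache/ directories and HF_TOKEN holds a Hugging Face token with access to the pyannote models:

# 1. Export the token that ${HF_TOKEN} in the compose file interpolates
export HF_TOKEN=...   # placeholder: your own token

# 2. Transcribe everything in ./audio (captions land in ./out)
docker compose up --build whisper

# 3. Copy the generated captions to where the diarize service expects them
cp out/*.srt out/*.vtt captions/

# 4. Add speaker labels (writes *.diarized.srt, *.diarized.vtt, *.rttm to ./out)
docker compose up diarize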