Created September 27, 2025 11:46
Aaron whisper setup
diarize_existing_captions.py
#!/usr/bin/env python3
import os, math, argparse
from typing import List, Tuple, Optional, Dict
import torch
from pyannote.audio import Pipeline
import srt
import datetime as dt
import webvtt
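
# Third-party dependencies (assumed PyPI package names): torch, pyannote.audio,
# srt, webvtt-py.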

# ---------- time helpers ----------
def to_srt_ts(seconds: float) -> dt.timedelta:
    if seconds is None: seconds = 0.0
    ms = int(round((seconds - math.floor(seconds)) * 1000))
    return dt.timedelta(seconds=int(seconds), milliseconds=ms)

def from_vtt_timestamp(ts: str) -> float:
    # "HH:MM:SS.mmm" (WebVTT)
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)

def to_vtt_timestamp(seconds: float) -> str:
    if seconds < 0: seconds = 0.0
    h = int(seconds // 3600); m = int((seconds % 3600) // 60)
    s = seconds - (h * 3600 + m * 60)
    return f"{h:02d}:{m:02d}:{s:06.3f}"

# ---------- caption I/O ----------
def read_srt(path: str) -> List[Dict]:
    with open(path, "r", encoding="utf-8") as f:
        subs = list(srt.parse(f.read()))
    items = []
    for sub in subs:
        start = sub.start.total_seconds()
        end = sub.end.total_seconds()
        items.append({"index": sub.index, "start": start, "end": end, "text": sub.content})
    return items

def write_srt(path: str, items: List[Dict]) -> None:
    subs = []
    for i, it in enumerate(items, start=1):
        subs.append(
            srt.Subtitle(
                index=i,
                start=to_srt_ts(it["start"]),
                end=to_srt_ts(it["end"]),
                content=it["text"],
            )
        )
    with open(path, "w", encoding="utf-8") as f:
        f.write(srt.compose(subs))

def read_vtt(path: str) -> List[Dict]:
    vtt = webvtt.read(path)
    items = []
    for i, cue in enumerate(vtt, start=1):
        items.append({
            "index": i,
            "start": from_vtt_timestamp(cue.start),
            "end": from_vtt_timestamp(cue.end),
            "text": cue.text
        })
    return items

def write_vtt(path: str, items: List[Dict]) -> None:
    vtt = webvtt.WebVTT()
    for it in items:
        cue = webvtt.Caption(
            start=to_vtt_timestamp(it["start"]),
            end=to_vtt_timestamp(it["end"]),
            text=it["text"]
        )
        vtt.captions.append(cue)
    vtt.save(path)

# ---------- diarization overlap ----------
def overlap(a: float, b: float, c: float, d: float) -> float:
    # length of intersection of [a,b] and [c,d]
    return max(0.0, min(b, d) - max(a, c))

def label_by_max_overlap(seg_start: float, seg_end: float, diar_segments: List[Tuple[float, float, str]]) -> Optional[str]:
    best, best_len = None, 0.0
    for (s, e, lab) in diar_segments:
        ol = overlap(seg_start, seg_end, s, e)
        if ol > best_len:
            best_len = ol
            best = lab
    return best
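
# Illustrative example: a caption spanning 4.0-6.0 s against diarization turns
# [(0.0, 5.0, "SPEAKER_00"), (5.0, 9.0, "SPEAKER_01")] overlaps each turn for
# exactly 1.0 s; the strict ">" comparison keeps the first label seen, SPEAKER_00.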

# ---------- main ----------
def main():
    ap = argparse.ArgumentParser(description="Add speaker labels to existing SRT/VTT using pyannote diarization.")
    ap.add_argument("--audio_dir", required=True, help="Folder with .wav audio (names must match captions).")
    ap.add_argument("--captions_dir", required=True, help="Folder with .srt or .vtt files.")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--hf_token", required=True, help="Hugging Face token to download the pipeline the first time.")
    ap.add_argument("--num_speakers", type=int, default=None)
    ap.add_argument("--min_speakers", type=int, default=None)
    ap.add_argument("--max_speakers", type=int, default=None)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Instantiate diarization pipeline (GPU optional)
    pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=args.hf_token)
    if torch.cuda.is_available():
        pipe.to(torch.device("cuda"))

    diar_kwargs = {}
    if args.num_speakers is not None: diar_kwargs["num_speakers"] = args.num_speakers
    if args.min_speakers is not None: diar_kwargs["min_speakers"] = args.min_speakers
    if args.max_speakers is not None: diar_kwargs["max_speakers"] = args.max_speakers

    # Build index of captions by basename
    caption_map: Dict[str, Tuple[str, str]] = {}  # base -> (path, kind)
    for fn in os.listdir(args.captions_dir):
        base, ext = os.path.splitext(fn)
        ext = ext.lower()
        if ext in (".srt", ".vtt"):
            caption_map[base] = (os.path.join(args.captions_dir, fn), ext[1:])

    # Process each audio whose base has a caption
    for fn in sorted(os.listdir(args.audio_dir)):
        if not fn.lower().endswith(".wav"):
            continue
        base = os.path.splitext(fn)[0]
        if base not in caption_map:
            print(f"[skip] No caption found for {fn}")
            continue
        audio_path = os.path.join(args.audio_dir, fn)
        cap_path, kind = caption_map[base]
        print(f"[diarize] {fn} | captions: {os.path.basename(cap_path)} ({kind})")

        # Read captions
        items = read_srt(cap_path) if kind == "srt" else read_vtt(cap_path)

        # Run diarization (pyannote handles resampling + mono downmix automatically)
        diar = pipe(audio_path, **diar_kwargs)

        # Collect diarization segments
        diar_segments: List[Tuple[float, float, str]] = []
        for turn, _, label in diar.itertracks(yield_label=True):
            diar_segments.append((turn.start, turn.end, label))
        diar_segments.sort(key=lambda x: x[0])

        # Tag each caption line with max-overlap speaker
        tagged = []
        for it in items:
            speaker = label_by_max_overlap(it["start"], it["end"], diar_segments)
            prefix = f"[{speaker}] " if speaker else ""
            tagged.append({**it, "text": prefix + it["text"]})

        # Write outputs
        out_srt = os.path.join(args.out_dir, f"{base}.diarized.srt")
        out_vtt = os.path.join(args.out_dir, f"{base}.diarized.vtt")
        if kind == "srt":
            write_srt(out_srt, tagged)
            write_vtt(out_vtt, tagged)
        else:
            write_vtt(out_vtt, tagged)
            write_srt(out_srt, tagged)

        # Also dump RTTM (standard diarization format)
        with open(os.path.join(args.out_dir, f"{base}.rttm"), "w", encoding="utf-8") as rttm:
            diar.write_rttm(rttm)

        print(f"[done] {base}: wrote *.diarized.srt, *.diarized.vtt, *.rttm")

if __name__ == "__main__":
    main()
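
A direct invocation of the script, assuming the dependencies above are installed
and using the same folder layout as the compose file below (paths and the token
are placeholders):

python diarize_existing_captions.py \
  --audio_dir ./audio \
  --captions_dir ./captions \
  --out_dir ./out \
  --hf_token "$HF_TOKEN" \
  --min_speakers 2 --max_speakers 6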

docker-compose.yml
version: "3.9"

services:
  whisper:
    build: .
    container_name: whisper
    # NOTE: whisper's CLI accepts a single --output_format value; "all" writes
    # txt, vtt, srt, tsv and json (a comma-separated list would be rejected).
    command: >
      bash -lc '
      mkdir -p /out &&
      whisper /audio/*.wav
      --model large-v3
      --device cuda
      --task transcribe
      --output_dir /out
      --output_format all'
    volumes:
      - ./audio:/audio:ro
      - ./out:/out
      - ./cache:/cache
    environment:
      - XDG_CACHE_HOME=/cache
    devices:
      - /dev/kfd
      - /dev/dri
    security_opt:
      - seccomp=unconfined
    ipc: host
    user: "1000:1000"

  diarize:
    build: .
    container_name: diarize
    # Prefetch models on first run (optional), then run diarization
    command: >
      bash -lc '
      mkdir -p /out /cache/hfcache &&
      python -c "from huggingface_hub import snapshot_download;
      import os; t=os.environ.get(\"HF_TOKEN\");
      [snapshot_download(repo_id=r, token=t, cache_dir=\"/cache/hfcache\")
      for r in [\"pyannote/segmentation-3.0\",\"pyannote/speaker-diarization-3.1\"]]" &&
      export HF_HOME=/cache/hfcache &&
      python /work/diarize_existing_captions.py
      --audio_dir /audio
      --captions_dir /captions
      --out_dir /out
      --hf_token ${HF_TOKEN}
      --min_speakers 2 --max_speakers 6'
    volumes:
      - ./audio:/audio:ro
      - ./captions:/captions:ro  # your existing SRT/VTT go here
      - ./out:/out
      - ./cache:/cache
      - ./diarize_existing_captions.py:/work/diarize_existing_captions.py:ro
    environment:
      - HF_TOKEN=${HF_TOKEN}
      - HF_HOME=/cache/hfcache
      # set to 1 after first successful download to force *offline* loads:
      # - HF_HUB_OFFLINE=1
    devices:
      - /dev/kfd
      - /dev/dri
    security_opt:
      - seccomp=unconfined
    ipc: host
    user: "1000:1000"