@Rastrian
Created April 13, 2026 12:44
Transcription (mkv video to txt using whisperx with diarization)
#!/bin/bash
# Setup script for transcript-app
set -e
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"
echo "Creating virtual environment..."
python3 -m venv .venv
echo "Activating virtual environment..."
source .venv/bin/activate
echo "Installing PyTorch (CPU for macOS)..."
pip install torch torchaudio
echo "Installing whisperx..."
pip install git+https://github.com/m-bain/whisperx.git
echo ""
echo "Setup complete! To use:"
echo " source .venv/bin/activate"
echo " python3 transcribe.py /path/to/video.mkv --hf-token YOUR_TOKEN"
echo ""
echo "Optional: set HF_TOKEN env var to avoid passing --hf-token every time:"
echo " export HF_TOKEN=your_token_here"
#!/usr/bin/env python3
"""
Transcribe a video/audio file with speaker diarization using whisperx.

Usage:
    # First-time setup:
    python3 -m venv .venv
    source .venv/bin/activate
    pip install -r requirements.txt

    # Run (diarization is skipped with a warning if no HF token is set):
    python3 transcribe.py /path/to/video.mkv
    python3 transcribe.py /path/to/video.mkv --language en --model large-v3 --output transcript.txt
"""

import argparse
import os
import subprocess
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor

def count_audio_streams(video_path: str) -> int:
    """Count the number of audio streams in a video file."""
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "a",
        "-show_entries", "stream=index",
        "-of", "csv=p=0",
        video_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        return 0
    return len([line for line in result.stdout.strip().split("\n") if line.strip()])

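# With "-of csv=p=0", ffprobe prints one bare stream index per line (e.g.
# "1" and "2" for a file whose two audio streams sit at container indexes
# 1 and 2), so counting the non-empty lines counts the audio streams.
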
def extract_single_audio(video_path: str, audio_path: str) -> None:
    """Extract the first (or only) audio track from video into a mono WAV."""
    print(f" Extracting audio from: {video_path}")
    cmd = [
        "ffmpeg", "-i", video_path,
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        audio_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"ffmpeg error:\n{result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(" Audio extracted successfully.")

def _is_audio_silent(wav_path: str, threshold_db: float = -40.0) -> bool:
    """Check if a WAV file is effectively silent using RMS energy.

    Uses ffmpeg's volumedetect filter to measure the mean volume in dB.
    Returns True if the volume is below the threshold (i.e., silent).
    The default threshold of -40 dB filters out near-silent/crosstalk-only
    tracks. (Not called anywhere in this script, since
    extract_multi_track_audio deliberately keeps every track, but
    available as a filter.)
    """
    cmd = [
        "ffmpeg", "-i", wav_path,
        "-af", "volumedetect",
        "-f", "null", "-",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    # Parse mean_volume from stderr, e.g. "mean_volume: -60.0 dB"
    for line in result.stderr.splitlines():
        if "mean_volume" in line:
            try:
                volume_db = float(line.split(":")[-1].strip().split()[0])
                return volume_db < threshold_db
            except (ValueError, IndexError):
                pass
    # If we can't parse the volume, treat the track as not silent to be safe
    return False

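# Hypothetical usage (this helper is currently unused by the script):
#   track_paths = [p for p in track_paths if not _is_audio_silent(p)]
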
def _extract_track(video_path: str, track_index: int) -> str | None:
    """Extract a single audio track. Returns the path, or None if extraction failed."""
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    track_path = tmp.name
    cmd_extract = [
        "ffmpeg", "-i", video_path,
        "-map", f"0:a:{track_index}",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        track_path,
    ]
    result = subprocess.run(cmd_extract, capture_output=True, text=True)
    if result.returncode != 0:
        os.unlink(track_path)
        return None
    return track_path

def extract_multi_track_audio(video_path: str) -> list[str]:
    """Extract each audio track from video into separate mono WAV files.

    Extracts tracks in parallel. Returns ALL tracks (even quiet ones),
    since each track represents a speaker.
    The caller is responsible for cleaning up these files.
    """
    num_streams = count_audio_streams(video_path)
    print(f" Found {num_streams} audio stream(s) — extracting in parallel.")
    track_paths = [None] * num_streams
    with ThreadPoolExecutor(max_workers=num_streams) as executor:
        futures = {
            executor.submit(_extract_track, video_path, i): i
            for i in range(num_streams)
        }
        for future in futures:
            i = futures[future]
            path = future.result()
            if path is not None:
                print(f" Track {i} extracted.")
                track_paths[i] = path
            else:
                print(f" Track {i} failed to extract.")
    track_paths = [p for p in track_paths if p is not None]
    if not track_paths:
        print("Error: no audio tracks could be extracted.", file=sys.stderr)
        sys.exit(1)
    print(f" {len(track_paths)} track(s) extracted successfully.")
    return track_paths

def _get_device(device: str) -> str:
    """Resolve 'auto' device to a concrete device string."""
    import torch
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device

def _rms_energy(audio: "numpy.ndarray", start_sample: int, end_sample: int) -> float:
    """Compute the RMS energy of an audio segment."""
    import numpy as np
    segment = audio[start_sample:end_sample]
    if len(segment) == 0:
        return 0.0
    return float(np.sqrt(np.mean(segment ** 2)))

def assign_speakers_by_energy(
    segments: list[dict],
    track_audios: list["numpy.ndarray"],
    sample_rate: int = 16000,
) -> list[dict]:
    """Assign speaker labels to segments based on per-track audio energy.

    For each segment, computes the RMS energy on each individual track
    during the segment's time window. The track with the highest energy
    (i.e., the loudest microphone) is taken to be the speaker.
    """
    import numpy as np
    for seg in segments:
        start_sample = int(seg.get("start", 0) * sample_rate)
        end_sample = int(seg.get("end", 0) * sample_rate)
        # Compute the RMS energy of each track over the segment window
        energies = np.array([
            _rms_energy(audio, start_sample, end_sample)
            for audio in track_audios
        ])
        best_track = int(np.argmax(energies))
        seg["speaker"] = f"SPEAKER_{best_track:02d}"
    return segments

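# Worked example (hypothetical numbers): a segment spanning 2.0-3.5 s maps to
# samples 32000..56000 at 16 kHz; if track 1's RMS over that window is 0.12
# versus 0.03 on track 0, the segment is labeled SPEAKER_01.
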
def transcribe_with_diarization(
    audio_path: str,
    language: str | None = None,
    model_name: str = "large-v3",
    device: str = "auto",
    hf_token: str | None = None,
) -> list[dict]:
    """
    Run whisperx transcription + speaker diarization (single-track path).
    Returns a list of segments with start, end, text, and speaker.
    """
    import torch
    import whisperx

    asr_device = _get_device(device)
    print(f"[2/4] Loading whisperx model '{model_name}' on {asr_device}...")
    compute_type = "float16" if asr_device != "cpu" else "int8"
    model = whisperx.load_model(
        model_name,
        device=asr_device,
        compute_type=compute_type,
        language=language,
    )
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=16, language=language)
    language_detected = result.get("language", "unknown")
    print(f" Detected language: {language_detected}")

    # Align transcription for precise word-level timestamps
    align_device = device if device != "auto" else asr_device
    print(f"[3/4] Aligning transcription (language: {language_detected}) on {align_device}...")
    model_a, metadata = whisperx.load_align_model(
        language_code=language_detected, device=align_device
    )
    result = whisperx.align(
        result["segments"],
        model_a,
        metadata,
        audio,
        device=align_device,
    )

    # Speaker diarization
    if hf_token:
        print("[4/4] Running speaker diarization...")
        from whisperx.diarize import DiarizationPipeline
        diarize_model = DiarizationPipeline(
            token=hf_token, device=align_device
        )
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    else:
        print("[4/4] Skipping diarization (no HF token provided).")
        for seg in result["segments"]:
            seg["speaker"] = "SPEAKER_00"

    # Clean up
    del model_a
    if device != "cpu" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result["segments"]

def format_transcript(segments: list[dict]) -> str:
    """Format segments into a readable transcript grouped by speaker turns."""
    lines = []
    current_speaker = None
    for seg in segments:
        speaker = seg.get("speaker", "Unknown")
        text = seg.get("text", "").strip()
        if not text:
            continue
        start = seg.get("start", 0)
        end = seg.get("end", 0)
        timestamp = f"[{_fmt_time(start)} -> {_fmt_time(end)}]"
        # Group consecutive segments from the same speaker under one heading
        if speaker != current_speaker:
            if current_speaker is not None:
                lines.append("")  # blank line between speakers
            lines.append(f"{speaker}:")
            current_speaker = speaker
        lines.append(f" {timestamp} {text}")
    return "\n".join(lines) + "\n"

def _fmt_time(seconds: float) -> str:
    """Format seconds as HH:MM:SS."""
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"

def main():
    parser = argparse.ArgumentParser(
        description="Transcribe a video/audio file with speaker diarization."
    )
    parser.add_argument("input", help="Path to the video or audio file")
    parser.add_argument(
        "--language", "-l", default=None,
        help="Language code (e.g. 'en', 'es'). Auto-detected if omitted.",
    )
    parser.add_argument(
        "--model", "-m", default="large-v3",
        help="Whisper model size (default: large-v3). Options: tiny, base, small, medium, large-v3",
    )
    parser.add_argument(
        "--output", "-o", default=None,
        help="Output .txt file path (default: <input_name>.transcript.txt)",
    )
    parser.add_argument(
        "--hf-token", default=None,
        help="HuggingFace token for speaker diarization (or set HF_TOKEN env var). "
             "Get one at: https://huggingface.co/settings/tokens "
             "and accept the pyannote conditions at: "
             "https://huggingface.co/pyannote/speaker-diarization-3.1 "
             "https://huggingface.co/pyannote/segmentation-3.0",
    )
    parser.add_argument(
        "--device", default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
        help="Compute device (default: auto-detect)",
    )
    args = parser.parse_args()

    # Validate input file
    if not os.path.isfile(args.input):
        print(f"Error: file not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    # Determine output path
    if args.output is None:
        base, _ = os.path.splitext(args.input)
        args.output = f"{base}.transcript.txt"

    # Get HF token (only needed for the single-track diarization path)
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")

    # Detect number of audio streams
    num_streams = count_audio_streams(args.input)
    print(f"[1/4] Found {num_streams} audio stream(s) in: {args.input}")
    resolved_device = _get_device(args.device)

    if num_streams > 1:
        # Multi-track path: use per-track energy to identify speakers.
        #   1. Extract tracks separately (for energy analysis)
        #   2. Mix into a single audio signal (for transcription)
        #   3. Transcribe the mix once
        #   4. For each segment, the loudest track = the speaker
        print(" Multi-track detected — using track-energy speaker identification.")
        track_paths = extract_multi_track_audio(args.input)
        try:
            import numpy as np
            import whisperx

            # Load each individual track as a numpy array (for energy comparison)
            print("[2/4] Loading individual tracks...")
            track_audios = [whisperx.load_audio(tp) for tp in track_paths]

            # Mix tracks in numpy (avoids another ffmpeg subprocess + temp file)
            print("[3/4] Transcribing mixed audio...")
            max_len = max(len(a) for a in track_audios)
            mixed = np.zeros(max_len, dtype=np.float32)
            for a in track_audios:
                mixed[:len(a)] += a
            mixed /= len(track_audios)  # normalize by the number of tracks
            mixed = np.clip(mixed, -1.0, 1.0)

            compute_type = "float16" if resolved_device != "cpu" else "int8"
            model = whisperx.load_model(
                args.model,
                device=resolved_device,
                compute_type=compute_type,
                language=args.language,
            )
            asr_result = model.transcribe(mixed, batch_size=16, language=args.language)
            language_detected = asr_result.get("language", "unknown")
            print(f" Detected language: {language_detected}")

            model_a, metadata = whisperx.load_align_model(
                language_code=language_detected, device=resolved_device
            )
            asr_result = whisperx.align(
                asr_result["segments"],
                model_a,
                metadata,
                mixed,
                device=resolved_device,
            )
            del model_a

            # Assign speakers based on per-track energy
            print("[4/4] Assigning speakers by track energy...")
            segments = assign_speakers_by_energy(asr_result["segments"], track_audios)
        finally:
            # Clean up temp track files
            for tp in track_paths:
                if os.path.exists(tp):
                    os.unlink(tp)
    else:
        # Single-track path: use pyannote diarization
        if not hf_token:
            print(
                "WARNING: No HuggingFace token provided.\n"
                "Speaker diarization will be skipped (all text assigned to SPEAKER_00).\n"
                "To enable speaker identification:\n"
                " 1. Create a token: https://huggingface.co/settings/tokens\n"
                " 2. Accept conditions: https://huggingface.co/pyannote/speaker-diarization-3.1\n"
                " and https://huggingface.co/pyannote/segmentation-3.0\n"
                " 3. Pass it with --hf-token TOKEN or set HF_TOKEN env var.\n",
            )
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_audio_path = tmp.name
        try:
            extract_single_audio(args.input, tmp_audio_path)
            segments = transcribe_with_diarization(
                tmp_audio_path,
                language=args.language,
                model_name=args.model,
                device=args.device,
                hf_token=hf_token,
            )
        finally:
            if os.path.exists(tmp_audio_path):
                os.unlink(tmp_audio_path)

    # Format and write output
    transcript = format_transcript(segments)
    with open(args.output, "w", encoding="utf-8") as f:
        f.write(transcript)
    print(f"\nDone! Transcript saved to: {args.output}")
    print(f" {len(segments)} segments written.")


if __name__ == "__main__":
    main()