Transcription (mkv video to txt using whisperx with diarization)
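The gist contains two files: a shell script that bootstraps a virtual environment (shown first below) and transcribe.py. A minimal end-to-end session might look like the following sketch; the setup script's filename (setup.sh), the video name, and the token value are illustrative placeholders, not part of the gist:

    bash setup.sh                        # one-time: create .venv, install torch + whisperx
    source .venv/bin/activate
    export HF_TOKEN=hf_xxxxxxxx          # optional: enables pyannote diarization for single-track files
    python3 transcribe.py meeting.mkv --language en
    # the transcript is written next to the input as meeting.transcript.txt
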
#!/bin/bash
# Setup script for transcript-app
set -e
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"
echo "Creating virtual environment..."
python3 -m venv .venv
echo "Activating virtual environment..."
source .venv/bin/activate
echo "Installing PyTorch (CPU for macOS)..."
pip install torch torchaudio
echo "Installing whisperx..."
pip install git+https://github.com/m-bain/whisperx.git
| echo "" | |
| echo "Setup complete! To use:" | |
| echo " source .venv/bin/activate" | |
| echo " python3 transcribe.py /path/to/video.mkv --hf-token YOUR_TOKEN" | |
| echo "" | |
| echo "Optional: set HF_TOKEN env var to avoid passing --hf-token every time:" | |
| echo " export HF_TOKEN=your_token_here" |
#!/usr/bin/env python3
"""
Transcribe a video/audio file with speaker diarization using whisperx.

Usage:
    # First-time setup:
    python3 -m venv .venv
    source .venv/bin/activate
    pip install -r requirements.txt

    # Run (diarization is skipped if no HF token is set):
    python3 transcribe.py /path/to/video.mkv
    python3 transcribe.py /path/to/video.mkv --language en --model large-v3 --output transcript.txt
"""
import argparse
import os
import subprocess
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor

def count_audio_streams(video_path: str) -> int:
    """Count the number of audio streams in a video file."""
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "a",
        "-show_entries", "stream=index",
        "-of", "csv=p=0",
        video_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        return 0
    return len([line for line in result.stdout.strip().split("\n") if line.strip()])

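# Note: the probe above can be run by hand to see how many per-speaker audio
# streams a recording has (and thus which path main() will take); the file name
# below is illustrative:
#   ffprobe -v error -select_streams a -show_entries stream=index -of csv=p=0 meeting.mkv
# One stream index is printed per line; two or more lines means the multi-track,
# track-energy path is used instead of pyannote diarization.
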
def extract_single_audio(video_path: str, audio_path: str) -> None:
    """Extract the first (or only) audio track from video into a mono WAV."""
    print(f" Extracting audio from: {video_path}")
    cmd = [
        "ffmpeg", "-i", video_path,
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        audio_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"ffmpeg error:\n{result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(" Audio extracted successfully.")

def _is_audio_silent(wav_path: str, threshold_db: float = -40.0) -> bool:
    """Check if a WAV file is effectively silent using RMS energy.

    Uses ffmpeg's volumedetect filter to measure the mean volume in dB.
    Returns True if the volume is below the threshold (i.e., silent).
    Default threshold of -40 dB filters out near-silent/crosstalk-only tracks.
    """
    cmd = [
        "ffmpeg", "-i", wav_path,
        "-af", "volumedetect",
        "-f", "null", "-",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    # Parse mean_volume from stderr, e.g. "mean_volume: -60.0 dB"
    for line in result.stderr.splitlines():
        if "mean_volume" in line:
            try:
                volume_db = float(line.split(":")[-1].strip().split()[0])
                return volume_db < threshold_db
            except (ValueError, IndexError):
                pass
    # If we can't parse volume, treat as not silent to be safe
    return False

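# Note: the same measurement can be reproduced from a shell if you want to tune
# threshold_db; the file name is illustrative, and the grep just surfaces the
# "mean_volume: ... dB" line that ffmpeg's volumedetect filter writes to stderr:
#   ffmpeg -i track.wav -af volumedetect -f null - 2>&1 | grep mean_volume
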
def _extract_track(video_path: str, track_index: int) -> str | None:
    """Extract a single audio track. Returns path, or None if extraction failed."""
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    track_path = tmp.name
    cmd_extract = [
        "ffmpeg", "-i", video_path,
        "-map", f"0:a:{track_index}",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        "-y",
        track_path,
    ]
    result = subprocess.run(cmd_extract, capture_output=True, text=True)
    if result.returncode != 0:
        os.unlink(track_path)
        return None
    return track_path

def extract_multi_track_audio(video_path: str) -> list[str]:
    """Extract each audio track from video into separate mono WAV files.

    Extracts tracks in parallel. Returns ALL tracks (even quiet ones),
    since each track represents a speaker.
    The caller is responsible for cleaning up these files.
    """
    num_streams = count_audio_streams(video_path)
    print(f" Found {num_streams} audio stream(s) — extracting in parallel.")
    track_paths = [None] * num_streams
    with ThreadPoolExecutor(max_workers=num_streams) as executor:
        futures = {
            executor.submit(_extract_track, video_path, i): i
            for i in range(num_streams)
        }
        for future in futures:
            i = futures[future]
            path = future.result()
            if path is not None:
                print(f" Track {i} extracted.")
                track_paths[i] = path
            else:
                print(f" Track {i} failed to extract.")
    track_paths = [p for p in track_paths if p is not None]
    if not track_paths:
        print("Error: no audio tracks could be extracted.", file=sys.stderr)
        sys.exit(1)
    print(f" {len(track_paths)} track(s) extracted successfully.")
    return track_paths

def _get_device(device: str) -> str:
    """Resolve 'auto' device to a concrete device string."""
    import torch
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device


def _rms_energy(audio: "numpy.ndarray", start_sample: int, end_sample: int) -> float:
    """Compute RMS energy of an audio segment."""
    import numpy as np
    segment = audio[start_sample:end_sample]
    if len(segment) == 0:
        return 0.0
    return float(np.sqrt(np.mean(segment ** 2)))

def assign_speakers_by_energy(
    segments: list[dict],
    track_audios: list["numpy.ndarray"],
    sample_rate: int = 16000,
) -> list[dict]:
    """Assign speaker labels to segments based on per-track audio energy.

    For each segment, computes the RMS energy on each individual track
    during the segment's time window. The track with the highest energy
    (i.e., the loudest microphone) is the speaker.
    """
    import numpy as np
    num_tracks = len(track_audios)
    for seg in segments:
        start_sample = int(seg.get("start", 0) * sample_rate)
        end_sample = int(seg.get("end", 0) * sample_rate)
        # Compute the RMS energy of each track over this segment's time window
        energies = np.array([
            _rms_energy(audio, start_sample, end_sample)
            for audio in track_audios
        ])
        best_track = int(np.argmax(energies))
        seg["speaker"] = f"SPEAKER_{best_track:02d}"
    return segments

def transcribe_with_diarization(
    audio_path: str,
    language: str | None = None,
    model_name: str = "large-v3",
    device: str = "auto",
    hf_token: str | None = None,
) -> list[dict]:
    """
    Run whisperx transcription + speaker diarization (single-track path).
    Returns a list of segments with start, end, text, and speaker.
    """
    import torch
    import whisperx

    asr_device = _get_device(device)
    print(f"[2/4] Loading whisperx model '{model_name}' on {asr_device}...")
    compute_type = "float16" if asr_device != "cpu" else "int8"
    model = whisperx.load_model(
        model_name,
        device=asr_device,
        compute_type=compute_type,
        language=language,
    )
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=16, language=language)
    language_detected = result.get("language", "unknown")
    print(f" Detected language: {language_detected}")

    # Align transcription for precise word-level timestamps
    align_device = device if device != "auto" else asr_device
    print(f"[3/4] Aligning transcription (language: {language_detected}) on {align_device}...")
    model_a, metadata = whisperx.load_align_model(
        language_code=language_detected, device=align_device
    )
    result = whisperx.align(
        result["segments"],
        model_a,
        metadata,
        audio,
        device=align_device,
    )

    # Speaker diarization
    if hf_token:
        print("[4/4] Running speaker diarization...")
        from whisperx.diarize import DiarizationPipeline
        diarize_model = DiarizationPipeline(
            token=hf_token, device=align_device
        )
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    else:
        print("[4/4] Skipping diarization (no HF token provided).")
        for seg in result["segments"]:
            seg["speaker"] = "SPEAKER_00"

    # Clean up
    del model_a
    if device != "cpu" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result["segments"]

def format_transcript(segments: list[dict]) -> str:
    """Format segments into a readable transcript grouped by speaker turns."""
    lines = []
    current_speaker = None
    for seg in segments:
        speaker = seg.get("speaker", "Unknown")
        text = seg.get("text", "").strip()
        if not text:
            continue
        start = seg.get("start", 0)
        end = seg.get("end", 0)
        timestamp = f"[{_fmt_time(start)} -> {_fmt_time(end)}]"
        # Group consecutive segments by the same speaker
        if speaker != current_speaker:
            if current_speaker is not None:
                lines.append("")  # blank line between speakers
            lines.append(f"{speaker}:")
            current_speaker = speaker
        lines.append(f" {timestamp} {text}")
    return "\n".join(lines) + "\n"


def _fmt_time(seconds: float) -> str:
    """Format seconds as HH:MM:SS."""
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"

def main():
    parser = argparse.ArgumentParser(
        description="Transcribe a video/audio file with speaker diarization."
    )
    parser.add_argument("input", help="Path to the video or audio file")
    parser.add_argument(
        "--language", "-l", default=None,
        help="Language code (e.g. 'en', 'es'). Auto-detected if omitted.",
    )
    parser.add_argument(
        "--model", "-m", default="large-v3",
        help="Whisper model size (default: large-v3). Options: tiny, base, small, medium, large-v3",
    )
    parser.add_argument(
        "--output", "-o", default=None,
        help="Output .txt file path (default: <input_name>.transcript.txt)",
    )
    parser.add_argument(
        "--hf-token", default=None,
        help="HuggingFace token for speaker diarization (or set HF_TOKEN env var). "
             "Get one at: https://huggingface.co/settings/tokens "
             "and accept the pyannote conditions at: "
             "https://huggingface.co/pyannote/speaker-diarization-3.1 "
             "https://huggingface.co/pyannote/segmentation-3.0",
    )
    parser.add_argument(
        "--device", default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
        help="Compute device (default: auto-detect)",
    )
    args = parser.parse_args()

    # Validate input file
    if not os.path.isfile(args.input):
        print(f"Error: file not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    # Determine output path
    if args.output is None:
        base, _ = os.path.splitext(args.input)
        args.output = f"{base}.transcript.txt"

    # Get HF token (only needed for single-track diarization fallback)
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")

    # Detect number of audio streams
    num_streams = count_audio_streams(args.input)
    print(f"[1/4] Found {num_streams} audio stream(s) in: {args.input}")
    resolved_device = _get_device(args.device)

    if num_streams > 1:
        # Multi-track path: use per-track energy to identify speakers.
        # 1. Extract tracks separately (for energy analysis)
        # 2. Mix into single audio (for transcription)
        # 3. Transcribe the mix once
        # 4. For each segment, the loudest track = the speaker
        print(" Multi-track detected — using track-energy speaker identification.")
        track_paths = extract_multi_track_audio(args.input)
        try:
            import numpy as np
            import whisperx

            # Load each individual track as numpy array (for energy comparison)
            print("[2/4] Loading individual tracks...")
            track_audios = [whisperx.load_audio(tp) for tp in track_paths]

            # Mix tracks in numpy (avoids another ffmpeg subprocess + temp file)
            print("[3/4] Transcribing mixed audio...")
            max_len = max(len(a) for a in track_audios)
            mixed = np.zeros(max_len, dtype=np.float32)
            for a in track_audios:
                mixed[:len(a)] += a
            mixed /= len(track_audios)  # normalize by number of tracks
            mixed = np.clip(mixed, -1.0, 1.0)

            compute_type = "float16" if resolved_device != "cpu" else "int8"
            model = whisperx.load_model(
                args.model,
                device=resolved_device,
                compute_type=compute_type,
                language=args.language,
            )
            asr_result = model.transcribe(mixed, batch_size=16, language=args.language)
            language_detected = asr_result.get("language", "unknown")
            print(f" Detected language: {language_detected}")

            model_a, metadata = whisperx.load_align_model(
                language_code=language_detected, device=resolved_device
            )
            asr_result = whisperx.align(
                asr_result["segments"],
                model_a,
                metadata,
                mixed,
                device=resolved_device,
            )
            del model_a

            # Assign speakers based on per-track energy
            print("[4/4] Assigning speakers by track energy...")
            segments = assign_speakers_by_energy(asr_result["segments"], track_audios)
        finally:
            # Clean up temp track files
            for tp in track_paths:
                if os.path.exists(tp):
                    os.unlink(tp)
    else:
        # Single-track path: use pyannote diarization
        if not hf_token:
            print(
                "WARNING: No HuggingFace token provided.\n"
                "Speaker diarization will be skipped (all text assigned to SPEAKER_00).\n"
                "To enable speaker identification:\n"
                " 1. Create a token: https://huggingface.co/settings/tokens\n"
                " 2. Accept conditions: https://huggingface.co/pyannote/speaker-diarization-3.1\n"
                " and https://huggingface.co/pyannote/segmentation-3.0\n"
                " 3. Pass it with --hf-token TOKEN or set HF_TOKEN env var.\n",
            )
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_audio_path = tmp.name
        try:
            extract_single_audio(args.input, tmp_audio_path)
            segments = transcribe_with_diarization(
                tmp_audio_path,
                language=args.language,
                model_name=args.model,
                device=args.device,
                hf_token=hf_token,
            )
        finally:
            if os.path.exists(tmp_audio_path):
                os.unlink(tmp_audio_path)

    # Format and write output
    transcript = format_transcript(segments)
    with open(args.output, "w", encoding="utf-8") as f:
        f.write(transcript)
    print(f"\nDone! Transcript saved to: {args.output}")
    print(f" {len(segments)} segments written.")


if __name__ == "__main__":
    main()
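
For reference, the transcript written by format_transcript groups consecutive segments into speaker turns, with one timestamped line per segment. The excerpt below is invented output for an imaginary two-speaker recording; only the layout reflects the code above, and meeting.transcript.txt is a placeholder name:

    $ cat meeting.transcript.txt
    SPEAKER_00:
     [00:00:01 -> 00:00:06] Hello everyone, thanks for joining.
     [00:00:06 -> 00:00:09] Let's get started with the agenda.

    SPEAKER_01:
     [00:00:10 -> 00:00:14] Sounds good, I'll share my screen.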