Skip to content

Instantly share code, notes, and snippets.

@orangepeelbeef
Last active October 11, 2023 20:43
Show Gist options
  • Save orangepeelbeef/6c776e782d2a2308baae63d0b54fa391 to your computer and use it in GitHub Desktop.
whisperx transcribe python word based timing SRT output
import whisperx
import gc
import argparse
import os
from whisperx.utils import get_writer
# Transcribe an audio file with whisperx, word-align the output, assign
# speaker labels via pyannote diarization, and write a word-highlighted
# SRT file next to the input ("<name>-<model>.srt").
#
# Hugging Face token for the gated pyannote diarization models. Read from
# the HF_TOKEN environment variable so the secret need not live in the
# script; the default ('') matches the original behaviour.
YOUR_HF_TOKEN = os.environ.get("HF_TOKEN", "")

parser = argparse.ArgumentParser(
    description="Transcribe, align and diarize audio with whisperx; "
                "writes word-highlighted SRT output.")
parser.add_argument("--file", help="the filename to transcribe", type=str, required=True)
parser.add_argument(
    "--model",
    help="the name of the model\n['tiny.en', 'tiny', 'base.en', 'base', 'small.en', "
         "'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large']",
    type=str,
    required=True,
)
# Optional knobs -- every default reproduces the previously hard-coded value,
# so existing invocations behave identically.
parser.add_argument("--device", type=str, default="cuda",
                    help="torch device to run on (default: cuda)")
parser.add_argument("--batch_size", type=int, default=16,
                    help="transcription batch size; reduce if low on GPU mem (default: 16)")
parser.add_argument("--compute_type", type=str, default="float16",
                    help='compute type; use "int8" if low on GPU mem, '
                         "may reduce accuracy (default: float16)")
parser.add_argument("--min_speakers", type=int, default=1,
                    help="minimum number of speakers for diarization (default: 1)")
parser.add_argument("--max_speakers", type=int, default=5,
                    help="maximum number of speakers for diarization (default: 5)")
args = parser.parse_args()

device = args.device
audio_file = args.file
batch_size = args.batch_size
compute_type = args.compute_type

# 1. Transcribe with original whisper (batched)
print(f"Loading model {args.model}")
model = whisperx.load_model(args.model, device, compute_type=compute_type)
print(f"Loading file {audio_file}")
audio = whisperx.load_audio(audio_file)
print("Transcribing...")
result = model.transcribe(audio, batch_size=batch_size)
print(result)

# Release the transcription model before loading the alignment model to keep
# peak GPU/host memory down (the model is not used again below).
del model
gc.collect()

# 2. Align whisper output to get word-level timings
print("Loading Alignment model")
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
# whisperx.align() returns a dict without the "language" key, so capture it
# here and re-add it afterwards -- the SRT writer needs it.
language = result["language"]
print("Aligning data...")
result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                        return_char_alignments=False)
result["language"] = language

# Alignment model is also done; free it before the diarization pipeline.
del model_a
gc.collect()

# 3. Assign speaker labels
print("Load diarization pipeline")
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
print("Diarize segments")
diarize_segments = diarize_model(audio,
                                 min_speakers=args.min_speakers,
                                 max_speakers=args.max_speakers)
print("Assign speakers")
result = whisperx.assign_word_speakers(diarize_segments, result)

# 4. Write the SRT next to the input file, suffixed with the model name.
name_no_ext, _ = os.path.splitext(args.file)
print("Write SRT file")
writer = get_writer(output_format='srt', output_dir='.')
with open(f"{name_no_ext}-{args.model}.srt", 'w', encoding='utf-8') as srt_file:
    writer.write_result(result, file=srt_file,
                        options={"max_line_width": 55, "max_line_count": 2,
                                 "highlight_words": True})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment