WhisperX transcription in Python: word-based timing, speaker diarization, and SRT output
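Usage sketch (the script name transcribe.py and the audio filename are assumptions; a CUDA GPU is required, and YOUR_HF_TOKEN must be set for the diarization step):

    python transcribe.py --file interview.mp3 --model medium.en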
import whisperx
import gc
import argparse
import os
from whisperx.utils import get_writer

YOUR_HF_TOKEN = ''  # Hugging Face token; the pyannote diarization models are gated and require one

parser = argparse.ArgumentParser()
parser.add_argument("--file", help="the filename to transcribe", type=str, required=True)
parser.add_argument("--model", help="the name of the model\n['tiny.en', 'tiny', 'base.en', 'base', 'small.en', "
                                    "'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large']", type=str,
                    required=True)
args = parser.parse_args()

device = "cuda"
audio_file = args.file
batch_size = 16  # reduce if low on GPU mem
compute_type = "float16"  # change to "int8" if low on GPU mem (may reduce accuracy)

# 1. Transcribe with original whisper (batched)
print(f"Loading model {args.model}")
model = whisperx.load_model(args.model, device, compute_type=compute_type)
print(f"Loading file {audio_file}")
audio = whisperx.load_audio(audio_file)
print("Transcribing...")
result = model.transcribe(audio, batch_size=batch_size)
print(result)
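# For reference, the transcription result is roughly shaped like this
# (a sketch, not exhaustive; values are illustrative):
#   {"segments": [{"text": " Hello world.", "start": 0.0, "end": 1.5}, ...],
#    "language": "en"}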
# print(result["segments"])  # before alignment
# delete model if low on GPU resources
# import torch; gc.collect(); torch.cuda.empty_cache(); del model

# 2. Align whisper output
print("Loading alignment model")
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
# whisperx.align drops the "language" key, so capture it first
language = result["language"]
print("Aligning data...")
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
# re-add the language
result["language"] = language
# print(result["segments"])  # after alignment
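# After alignment, each segment carries per-word timings, roughly
# (a sketch; values are illustrative):
#   {"segments": [{"text": "Hello world.", "start": 0.0, "end": 1.5,
#                  "words": [{"word": "Hello", "start": 0.0, "end": 0.6, "score": 0.9}, ...]}],
#    "word_segments": [...]}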
# delete model if low on GPU resources
# import torch; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
print("Load diarization pipeline")
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
# add min/max number of speakers if known
print("Diarize segments")
diarize_segments = diarize_model(audio, min_speakers=1, max_speakers=5)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
print("Assign speakers")
result = whisperx.assign_word_speakers(diarize_segments, result)
# print(diarize_segments)
# print(result["segments"])  # segments are now assigned speaker IDs
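# assign_word_speakers adds a "speaker" key (e.g. "SPEAKER_00") to each segment
# and, where a word overlaps a diarized turn, to each word (illustrative):
#   {"word": "Hello", "start": 0.0, "end": 0.6, "score": 0.9, "speaker": "SPEAKER_00"}
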
name_no_ext, _ = os.path.splitext(args.file)
print("Write SRT file")
writer = get_writer(output_format='srt', output_dir='.')
with open(f"{name_no_ext}-{args.model}.srt", 'w', encoding='utf-8') as srt_file:
    writer.write_result(result, file=srt_file,
                        options={"max_line_width": 55, "max_line_count": 2, "highlight_words": True})
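With highlight_words enabled, the SRT writer emits one cue per word position and underlines the word being spoken, so input.mp3 transcribed with medium.en yields input-medium.en.srt along these lines (timings and text are illustrative, not real output):

    1
    00:00:00,000 --> 00:00:00,420
    <u>Hello</u> world

    2
    00:00:00,420 --> 00:00:00,900
    Hello <u>world</u>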