Whisper to json
# based on https://github.com/ANonEntity/WhisperWithVAD
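# Batch-transcribes recordings from a folder: Silero VAD splits each file into
# speech-only chunks, Whisper transcribes (or translates) them, and the result is
# written as an .srt file next to the audio plus a segment_info.json debug dump.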
import torch
import whisper
import os
import ffmpeg
import srt
from tqdm import tqdm
import datetime
import urllib.request  # not used in this script
import json
from glob import glob
model_size = 'medium'
language = "english"  # @param {type:"string"}
translation_mode = "End-to-end Whisper (default)"  # @param ["End-to-end Whisper (default)", "Whisper -> DeepL", "No translation"]
# @markdown Advanced settings:
deepl_authkey = ""  # @param {type:"string"}
chunk_threshold = 3.0  # @param {type:"number"}
max_attempts = 1  # @param {type:"integer"}
# Configuration
assert max_attempts >= 1
assert chunk_threshold >= 0.1
# assert audio_path != ""
assert language != ""
if translation_mode == "End-to-end Whisper (default)":
    task = "translate"
    run_deepl = False
elif translation_mode == "Whisper -> DeepL":
    task = "transcribe"
    run_deepl = True
elif translation_mode == "No translation":
    task = "transcribe"
    run_deepl = False
else:
    raise ValueError("Invalid translation mode")
# Note: no DeepL call is made anywhere below, so run_deepl and deepl_authkey
# currently have no effect in this script.
inputs = glob('/Users/alexlyzhov/Documents/recordings/*.mp3') + glob('/Users/alexlyzhov/Documents/recordings/*.wav')
tmp_path = '/Users/alexlyzhov/Documents/recordings/tmp/vad_chunks'

# Skip recordings that already have a subtitle/transcript file next to them
todo_inputs = []
for path in inputs:
    wo_ext = os.path.splitext(path)[0]
    srt_file = wo_ext + '.srt'
    vtt_file = wo_ext + '.vtt'
    txt_file = wo_ext + '.txt'
    add_srt_file = path + '.srt'
    add_vtt_file = path + '.vtt'
    add_txt_file = path + '.txt'
    todo = not (os.path.exists(srt_file) or os.path.exists(vtt_file) or os.path.exists(add_srt_file) or os.path.exists(add_vtt_file)
                or os.path.exists(txt_file) or os.path.exists(add_txt_file))
    # print(srt_file, vtt_file, add_srt_file, add_vtt_file, todo)
    if todo:
        todo_inputs.append(path)
print(todo_inputs)
def encode(audio_path):
    # Re-encode the input to 16 kHz mono 16-bit PCM WAV for Silero VAD
    print("Encoding audio...")
    if not os.path.exists(tmp_path):
        os.makedirs(tmp_path)  # create intermediate directories too
    ffmpeg.input(audio_path).output(
        os.path.join(tmp_path, "silero_temp.wav"),
        ar="16000",
        ac="1",
        acodec="pcm_s16le",
        map_metadata="-1",
        fflags="+bitexact",
    ).overwrite_output().run(quiet=True)
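# NOTE: each encode() call overwrites the same silero_temp.wav, so the pipeline
# below ends up transcribing only the last pending recording; re-run the script to
# work through the remaining files one per run, or indent everything below under
# this loop to process all pending files in a single run.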
for audio_path in todo_inputs:
    encode(audio_path)
print("Running VAD...") | |
model, utils = torch.hub.load( | |
repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=False | |
) | |
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils | |
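# Of the Silero helper functions, only get_speech_timestamps, save_audio,
# read_audio and collect_chunks are used below.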
# Generate VAD timestamps
VAD_SR = 16000
wav = read_audio(os.path.join(tmp_path, "silero_temp.wav"), sampling_rate=VAD_SR)
t = get_speech_timestamps(wav, model, sampling_rate=VAD_SR)

# Add a bit of padding, and remove small gaps
for i in range(len(t)):
    t[i]["start"] = max(0, t[i]["start"] - 3200)  # 0.2s head
    t[i]["end"] = min(wav.shape[0] - 16, t[i]["end"] + 20800)  # 1.3s tail
    if i > 0 and t[i]["start"] < t[i - 1]["end"]:
        t[i]["start"] = t[i - 1]["end"]  # Remove overlap

# If breaks are longer than chunk_threshold seconds, split into a new audio file
# This'll effectively turn long transcriptions into many shorter ones
u = [[]]
for i in range(len(t)):
    if i > 0 and t[i]["start"] > t[i - 1]["end"] + (chunk_threshold * VAD_SR):
        u.append([])
    u[-1].append(t[i])

# Merge speech chunks
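# collect_chunks concatenates just the speech regions of each group into
# tmp_path/<i>.wav, so Whisper never sees the long silences between them.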
for i in range(len(u)):
    save_audio(
        os.path.join(tmp_path, str(i) + ".wav"),
        collect_chunks(u[i], wav),
        sampling_rate=VAD_SR,
    )
os.remove(os.path.join(tmp_path, "silero_temp.wav"))

# Convert timestamps to seconds
for i in range(len(u)):
    time = 0.0
    offset = 0.0
    for j in range(len(u[i])):
        u[i][j]["start"] /= VAD_SR
        u[i][j]["end"] /= VAD_SR
        u[i][j]["chunk_start"] = time
        time += u[i][j]["end"] - u[i][j]["start"]
        u[i][j]["chunk_end"] = time
        if j == 0:
            offset += u[i][j]["start"]
        else:
            offset += u[i][j]["start"] - u[i][j - 1]["end"]
        u[i][j]["offset"] = offset
# Run Whisper on each audio chunk
print("Running Whisper...")
model = whisper.load_model(model_size)
subs = []
segment_info = []
sub_index = 1
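# Phrases that Whisper tends to hallucinate around chunk boundaries (YouTube-style
# outros, "thanks for watching", etc.). Segments containing them get their average
# log-probability penalized below, which can push them under the -1.0 cutoff.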
suppress_low = [
    "Thank you",
    "Thanks for",
    "ike and ",
    "Bye.",
    "Bye!",
    "Bye bye!",
    "lease sub",
    "The end.",
    "視聴",
]
suppress_high = [
    "ubscribe",
    "my channel",
    "the channel",
    "our channel",
    "ollow me on",
    "for watching",
    "hank you for watching",
    "for your viewing",
    "r viewing",
    "Amara",
    "next video",
    "full video",
    "ranslation by",
    "ranslated by",
    "ee you next week",
    "ご視聴",
    "視聴ありがとうございました",
]
for i in tqdm(range(len(u))):
    line_buffer = []  # Used for DeepL
    for x in range(max_attempts):
        result = model.transcribe(
            os.path.join(tmp_path, str(i) + ".wav"), task=task, language=language
        )
        # Break if result doesn't end with severe hallucinations
        if len(result["segments"]) == 0:
            break
        elif result["segments"][-1]["end"] < u[i][-1]["chunk_end"] + 10.0:
            break
        elif x + 1 < max_attempts:
            print("Retrying chunk", i)
    for r in result["segments"]:
        # Skip audio timestamped after the chunk has ended
        if r["start"] > u[i][-1]["chunk_end"]:
            continue
        # Reduce log probability for certain words/phrases
        for s in suppress_low:
            if s in r["text"]:
                r["avg_logprob"] -= 0.15
        for s in suppress_high:
            if s in r["text"]:
                r["avg_logprob"] -= 0.35
        # Keep segment info for debugging
        del r["tokens"]
        segment_info.append(r)
        # Skip if log prob is low or no speech prob is high
        if r["avg_logprob"] < -1.0 or r["no_speech_prob"] > 0.7:
            continue
        # Set start timestamp
        start = r["start"] + u[i][0]["offset"]
        for j in range(len(u[i])):
            if (
                r["start"] >= u[i][j]["chunk_start"]
                and r["start"] <= u[i][j]["chunk_end"]
            ):
                start = r["start"] + u[i][j]["offset"]
                break
        # Prevent overlapping subs
        if len(subs) > 0:
            last_end = subs[-1].end.total_seconds()
            if last_end > start:
                subs[-1].end = datetime.timedelta(seconds=start)
        # Set end timestamp
        end = u[i][-1]["end"] + 0.5
        for j in range(len(u[i])):
            if r["end"] >= u[i][j]["chunk_start"] and r["end"] <= u[i][j]["chunk_end"]:
                end = r["end"] + u[i][j]["offset"]
                break
        # Add to SRT list
        subs.append(
            srt.Subtitle(
                index=sub_index,
                start=datetime.timedelta(seconds=start),
                end=datetime.timedelta(seconds=end),
                content=r["text"].strip(),
            )
        )
        sub_index += 1
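# Per-segment metadata (minus the token ids) is dumped for debugging. Note that it
# is written to the current working directory, not next to the audio file.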
with open("segment_info.json", "w", encoding="utf8") as f: | |
json.dump(segment_info, f, indent=4) | |
out_path = os.path.splitext(audio_path)[0] + ".srt" | |
with open(out_path, "w", encoding="utf8") as f: | |
f.write(srt.compose(subs)) | |
print("\nDone! Subs written to", out_path) |