Last active
March 6, 2025 13:51
-
-
Save haasr/7638a4e72056ba3108a6be171f8cb534 to your computer and use it in GitHub Desktop.
WhisperX Transcription GUI (frontend for m-bain/whisperX; does not include diarization)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tkinter as tk | |
from tkinter import ttk, scrolledtext, filedialog, messagebox | |
import threading | |
import whisperx | |
import os | |
import sys | |
import datetime | |
import shutil | |
from pathlib import Path | |
import queue | |
import torch | |
import platform | |
import subprocess | |
import importlib | |
import threading | |
import time | |
import sys | |
# Note: numpy is pinned to 1.26.4 because numpy 2.x removed the `numpy.NaN` alias (only `numpy.nan` remains), and whisperx raises an exception when that alias is missing.
class SplashScreen(tk.Toplevel):
    """Borderless, always-on-top loading window.

    Shows a title, a "Loading dependencies..." caption, an indeterminate
    progress bar and a small status line that ``update_status`` rewrites.
    """

    def __init__(self):
        super().__init__()

        background = '#2d2d2d'
        win_w, win_h = 300, 200

        # No title bar or borders.
        self.overrideredirect(True)

        # Place the fixed-size window in the middle of the screen.
        left = (self.winfo_screenwidth() - win_w) // 2
        top = (self.winfo_screenheight() - win_h) // 2
        self.geometry(f"{win_w}x{win_h}+{left}+{top}")

        self.configure(bg=background)

        # Static captions: application title, then the loading hint.
        tk.Label(
            self,
            text="WhisperX GUI",
            font=("Helvetica", 16, "bold"),
            bg=background,
            fg='white',
        ).pack(pady=(20, 10))
        tk.Label(
            self,
            text="Loading dependencies...",
            font=("Helvetica", 10),
            bg=background,
            fg='white',
        ).pack(pady=5)

        # Indeterminate bar, animated straight away.
        self.progress = ttk.Progressbar(self, mode='indeterminate', length=200)
        self.progress.pack(pady=10)
        self.progress.start()

        # Line of text naming whatever is being loaded right now.
        self.status_label = tk.Label(
            self,
            text="",
            font=("Helvetica", 8),
            bg=background,
            fg='#a0a0a0',
            wraplength=280,
        )
        self.status_label.pack(pady=5)

        # Flush pending geometry work, then keep the splash above other windows.
        self.update_idletasks()
        self.attributes('-topmost', True)

    def update_status(self, text):
        """Show ``text`` in the status line and redraw immediately."""
        self.status_label.config(text=text)
        self.update()
def import_with_status(splash_screen, modules_to_import=None, step_delay=0.1, error_delay=2):
    """Import the required modules one by one, reporting progress on the splash.

    Args:
        splash_screen: Object exposing ``update_status(text)``; it is called
            before each import and once more if an import fails.
        modules_to_import: Optional iterable of ``(module_name, display_name)``
            pairs. Defaults to the modules the GUI needs.
        step_delay: Seconds to pause after each successful import so the
            status message stays visible.
        error_delay: Seconds to keep a failure message on screen before
            exiting.

    Returns:
        Dict mapping each module name to the imported module object.

    Raises:
        SystemExit: With exit code 1 when any module fails to import.
    """
    if modules_to_import is None:
        modules_to_import = [
            ('tkinter.scrolledtext', 'GUI Components'),
            ('tkinter.filedialog', 'File Dialog'),
            ('tkinter.messagebox', 'Message Boxes'),
            ('whisperx', 'WhisperX Core'),
            ('torch', 'PyTorch'),
            ('datetime', 'DateTime Utils'),
            ('pathlib', 'Path Utils'),
            ('queue', 'Threading Queue'),
            ('platform', 'Platform Utils'),
            ('subprocess', 'Subprocess Utils'),
        ]

    imported_modules = {}
    for module_name, display_name in modules_to_import:
        splash_screen.update_status(f"Loading {display_name}...")
        try:
            imported_modules[module_name] = importlib.import_module(module_name)
        except Exception as e:
            # Deliberately broader than ImportError: a package can be present
            # yet still blow up while importing (an incompatible dependency
            # version, for instance). Any such failure should be shown on the
            # splash instead of surfacing as an unhandled traceback.
            splash_screen.update_status(f"Error loading {display_name}: {e}")
            time.sleep(error_delay)
            sys.exit(1)
        # Small pause so the loading message is readable.
        time.sleep(step_delay)
    return imported_modules
# Main application window. It receives the dict of modules returned by import_with_status().
class WhisperXGUI:
    """Tkinter window that batch-transcribes media files with WhisperX.

    The user picks input files, a model, a language and a compute type.
    Transcription runs on a background thread; that thread talks to the UI
    only through ``output_queue``, which the main thread polls in
    ``check_output``. Queue items are either plain strings (appended to the
    log pane) or ``(event, payload)`` tuples handled by
    ``_handle_worker_event``.
    """

    # Event tags the worker thread puts on the queue for the main thread.
    _EVENT_ERROR = "error"   # payload: message for an error dialog
    _EVENT_DONE = "done"     # payload unused; re-enables the button

    def __init__(self, root, modules):
        """Build the window.

        Args:
            root: The ``tk.Tk`` root window to populate.
            modules: Dict of imported modules keyed by module name, as
                returned by ``import_with_status``.
        """
        self.root = root
        self.root.title("WhisperX Transcription GUI")
        self.root.geometry("800x600")
        # Unpack needed modules
        self.scrolledtext = modules['tkinter.scrolledtext']
        self.filedialog = modules['tkinter.filedialog']
        self.messagebox = modules['tkinter.messagebox']
        self.whisperx = modules['whisperx']
        self.torch = modules['torch']
        self.Path = modules['pathlib'].Path
        self.queue = modules['queue'].Queue
        self.platform = modules['platform']
        self.subprocess = modules['subprocess']
        # Queue for communication between threads
        self.output_queue = self.queue()
        # Full paths of the selected files, index-aligned with the listbox rows
        self.file_list = []
        self.setup_ui()
        self.setup_bindings()
        # Start the output checking loop
        self.check_output()

    def setup_ui(self):
        """Create and lay out every widget of the main window."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # File selection
        ttk.Label(main_frame, text="Input Files:").grid(row=0, column=0, sticky=tk.W)
        self.files_listbox = tk.Listbox(main_frame, width=50, height=5)
        self.files_listbox.grid(row=0, column=1, padx=5, sticky=(tk.W, tk.E))

        # File buttons frame
        file_buttons_frame = ttk.Frame(main_frame)
        file_buttons_frame.grid(row=0, column=2, padx=5)
        ttk.Button(file_buttons_frame, text="➕ Add Files", command=self.browse_files).grid(row=0, column=0, pady=2)
        ttk.Button(file_buttons_frame, text="➖ Remove Selected", command=self.remove_selected_file).grid(row=1, column=0, pady=2)
        ttk.Button(file_buttons_frame, text="🆑 Clear All", command=self.clear_files).grid(row=2, column=0, pady=2)

        # Model selection
        ttk.Label(main_frame, text="Model:").grid(row=1, column=0, sticky=tk.W, pady=5)
        self.model_var = tk.StringVar(value="large-v2")
        models = ["tiny", "base", "small", "medium", "large-v2"]
        model_combo = ttk.Combobox(main_frame, textvariable=self.model_var, values=models, state="readonly")
        model_combo.grid(row=1, column=1, sticky=(tk.W, tk.E), pady=5)

        # Language selection: full names in the UI, mapped to language codes
        ttk.Label(main_frame, text="Language:").grid(row=2, column=0, sticky=tk.W, pady=5)
        self.language_mapping = {
            "English": "en",
            "French": "fr",
            "German": "de",
            "Spanish": "es",
            "Italian": "it",
            "Japanese": "ja",
            "Chinese": "zh",
            "Dutch": "nl",
            "Ukrainian": "uk",
            "Portuguese": "pt",
        }
        self.language_var = tk.StringVar(value="English")
        language_names = list(self.language_mapping)
        language_combo = ttk.Combobox(main_frame, textvariable=self.language_var, values=language_names, state="readonly")
        language_combo.grid(row=2, column=1, sticky=(tk.W, tk.E), pady=5)

        # Compute type selection
        ttk.Label(main_frame, text="Compute Type:").grid(row=3, column=0, sticky=tk.W, pady=5)
        self.compute_type_var = tk.StringVar(value="float16")
        compute_types = ["float32", "float16", "int8"]
        compute_type_combo = ttk.Combobox(main_frame, textvariable=self.compute_type_var, values=compute_types, state="readonly")
        compute_type_combo.grid(row=3, column=1, sticky=(tk.W, tk.E), pady=5)

        # Read-only log pane; update_output() unlocks it briefly to append.
        self.output_text = scrolledtext.ScrolledText(main_frame, height=20, width=70, wrap=tk.WORD)
        self.output_text.grid(row=4, column=0, columnspan=3, pady=10)
        self.output_text.config(state='disabled')

        # Transcribe button
        self.transcribe_btn = ttk.Button(main_frame, text="✨ Transcribe All", command=self.start_transcription)
        self.transcribe_btn.grid(row=5, column=0, columnspan=3, pady=10)

        # Configure grid weights
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)

    def browse_files(self):
        """Ask for media files and add those not already in the list."""
        filetypes = (
            ("Video/audio", "*.mp4 *.mkv *.mov *.wmv *.avi *.flv *.mp3 *.wav *.aac *.flac *.ogg"),
            ("Video", "*.mp4 *.mkv *.mov *.wmv *.avi *.flv"),
            ("Audio", "*.mp3 *.wav *.aac *.flac *.ogg"),
            ("All files", "*.*"),
        )
        for filename in filedialog.askopenfilenames(filetypes=filetypes):
            if filename not in self.file_list:
                self.file_list.append(filename)
                # The listbox shows only the base name; file_list keeps the path.
                self.files_listbox.insert(tk.END, os.path.basename(filename))

    def remove_selected_file(self):
        """Remove the currently selected row from the listbox and the file list."""
        selection = self.files_listbox.curselection()
        if selection:
            index = selection[0]
            self.files_listbox.delete(index)
            self.file_list.pop(index)

    def clear_files(self):
        """Remove every file from the listbox and the file list."""
        self.files_listbox.delete(0, tk.END)
        self.file_list.clear()

    def setup_bindings(self):
        """Route the window-close button through ``on_closing``."""
        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)

    def update_output(self, message):
        """Append ``message`` plus a newline to the log pane and scroll to it."""
        self.output_text.config(state='normal')
        self.output_text.insert(tk.END, message + '\n')
        self.output_text.see(tk.END)
        self.output_text.config(state='disabled')

    def check_output(self):
        """Drain ``output_queue`` on the main thread, then reschedule itself.

        Strings are appended to the log; ``(event, payload)`` tuples are
        dispatched to ``_handle_worker_event``. Runs again after 100 ms.
        """
        try:
            while True:
                message = self.output_queue.get_nowait()
                if isinstance(message, tuple):
                    self._handle_worker_event(*message)
                else:
                    self.update_output(message)
        except queue.Empty:
            pass
        finally:
            self.root.after(100, self.check_output)

    def _handle_worker_event(self, event, payload):
        """Apply a UI change requested by the worker thread (main thread only)."""
        if event == self._EVENT_ERROR:
            messagebox.showerror("Error", payload)
        elif event == self._EVENT_DONE:
            self.transcribe_btn.config(state='normal')

    def create_output_directory(self, input_file):
        """Create and return ``transcripts/<stem>_<YYYY-MM-DD_HH-MM>`` for a file.

        The directory is relative to the current working directory and is
        reused if it already exists.
        """
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
        base_name = Path(input_file).stem
        output_dir = Path("transcripts") / f"{base_name}_{timestamp}"
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    @staticmethod
    def _resolve_compute_type(device, requested):
        """Return a compute type usable on ``device``.

        ``float16`` is replaced by ``float32`` on the CPU, where the
        inference backend does not support half precision; every other
        combination is returned unchanged.
        """
        if device == "cpu" and requested == "float16":
            return "float32"
        return requested

    def transcribe(self, files=None, model_name=None, language=None, compute_type=None):
        """Transcribe and align every file, writing an SRT and a TXT per file.

        Intended to run on a worker thread. It never touches widgets
        directly: progress text, error dialogs and the final button re-enable
        are all posted to ``output_queue`` for the main thread.

        Args:
            files: Paths to process. Defaults to a copy of ``self.file_list``.
            model_name: Model name. Defaults to the UI selection.
            language: Language code. Defaults to the UI selection.
            compute_type: Compute type. Defaults to the UI selection.

        The three settings fall back to reading the Tk variables, which is
        only safe on the main thread; ``start_transcription`` always passes
        them explicitly.
        """
        try:
            # Work on a private copy so edits to the list in the UI cannot
            # change what is being iterated here.
            files = list(self.file_list) if files is None else list(files)
            if not files:
                self.output_queue.put((self._EVENT_ERROR, "Please select at least one input file"))
                return
            if model_name is None:
                model_name = self.model_var.get()
            if language is None:
                language = self.language_mapping[self.language_var.get()]
            if compute_type is None:
                compute_type = self.compute_type_var.get()

            device = "cuda" if torch.cuda.is_available() else "cpu"
            effective_type = self._resolve_compute_type(device, compute_type)
            if effective_type != compute_type:
                self.output_queue.put(
                    f"Compute type '{compute_type}' is not supported on CPU; using '{effective_type}' instead."
                )

            # Load model once for all files
            self.output_queue.put("Loading WhisperX model...")
            model = whisperx.load_model(model_name, device, compute_type=effective_type)

            for input_file in files:
                self._transcribe_file(model, input_file, device, language)

            # Clean up whisper model
            del model
            torch.cuda.empty_cache()
            self.output_queue.put("\nAll files processed successfully!")
        except Exception as e:
            # Top-level boundary of the worker thread: report, do not re-raise.
            self.output_queue.put(f"Error during transcription: {str(e)}")
            self.output_queue.put((self._EVENT_ERROR, f"Transcription failed: {str(e)}"))
        finally:
            self.output_queue.put((self._EVENT_DONE, None))

    def _transcribe_file(self, model, input_file, device, language):
        """Transcribe, align and save one file, then open its output folder."""
        self.output_queue.put(f"\nStarting transcription of: {input_file}")
        output_dir = self.create_output_directory(input_file)
        self.output_queue.put(f"Output directory: {output_dir}")

        audio = whisperx.load_audio(input_file)

        self.output_queue.put("Transcribing...")
        result = model.transcribe(audio, batch_size=16, language=language)

        self.output_queue.put("Aligning transcript...")
        model_a, metadata = whisperx.load_align_model(
            language_code=language,
            device=device,
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        stem = Path(input_file).stem
        self._write_transcripts(
            result["segments"],
            output_dir / f"{stem}.srt",
            output_dir / f"{stem}.txt",
        )
        self.output_queue.put(f"Transcription complete for: {input_file}")
        self.output_queue.put(f"Files saved in: {output_dir}")

        self.open_folder(output_dir)

        # Drop the alignment model before asking torch to release cached memory.
        del model_a
        torch.cuda.empty_cache()

    def _write_transcripts(self, segments, srt_path, txt_path):
        """Write ``segments`` as a numbered SRT file and a plain-text file.

        Each segment is a mapping with ``start``, ``end`` (seconds) and
        ``text`` keys.
        """
        with open(srt_path, 'w', encoding='utf-8') as f:
            for i, seg in enumerate(segments, 1):
                start = self.format_timestamp(seg["start"])
                end = self.format_timestamp(seg["end"])
                f.write(f"{i}\n{start} --> {end}\n{seg['text'].strip()}\n\n")
        with open(txt_path, 'w', encoding='utf-8') as f:
            for seg in segments:
                f.write(f"{seg['text'].strip()}\n")

    def open_folder(self, path):
        """Open ``path`` in the platform's file manager (best effort).

        A failure to launch the file manager is logged rather than raised,
        so it is not reported as a failed transcription.
        """
        path = os.path.realpath(path)
        try:
            system = platform.system()
            if system == "Windows":
                os.startfile(path)
            elif system == "Darwin":  # macOS
                subprocess.run(["open", path])
            else:  # Linux
                subprocess.run(["xdg-open", path])
        except OSError as e:
            self.output_queue.put(f"Could not open folder {path}: {e}")

    def format_timestamp(self, seconds):
        """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond first, so binary
        floating-point error cannot shave a millisecond off (``2.3`` is
        ``,300`` rather than ``,299``). Negative input is clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        whole_seconds, milliseconds = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d},{milliseconds:03d}"

    def start_transcription(self):
        """Validate the selection, snapshot the settings and start the worker."""
        if not self.file_list:
            messagebox.showerror("Error", "Please select at least one input file")
            return
        # Read every Tk variable here, on the main thread, and hand the worker
        # plain Python values.
        job = {
            "files": list(self.file_list),
            "model_name": self.model_var.get(),
            "language": self.language_mapping[self.language_var.get()],
            "compute_type": self.compute_type_var.get(),
        }
        self.transcribe_btn.config(state='disabled')
        threading.Thread(target=self.transcribe, kwargs=job, daemon=True).start()

    def on_closing(self):
        """Ask for confirmation, then stop the main loop."""
        if messagebox.askokcancel("Quit", "Do you want to quit?"):
            self.root.quit()
# Function to finish initialization and show main window | |
def finish_init():
    """Build the main window, reveal it, and dismiss the splash screen.

    Relies on the module-level ``root``, ``imported_modules`` and ``splash``
    set up by the script entry point.
    """
    # The instance is not kept in a variable; it only needs to be constructed.
    WhisperXGUI(root, imported_modules)
    # The root was withdrawn at start-up, so make it visible now.
    root.deiconify()
    splash.destroy()
if __name__ == "__main__":
    # The real root window exists from the start but stays hidden while the
    # splash screen is up.
    root = tk.Tk()
    root.withdraw()
    splash = SplashScreen()

    # Walk the dependency list, updating the splash for each entry.
    # NOTE(review): whisperx and torch are also imported at the top of this
    # file, so by this point they are already loaded and this step mostly
    # fetches them from the module cache.
    imported_modules = import_with_status(splash)

    splash.update_status("Initializing application...")

    # Build the main window one second after the event loop starts.
    root.after(1000, finish_init)

    root.mainloop()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
aiohappyeyeballs==2.4.4 | |
aiohttp==3.11.11 | |
aiosignal==1.3.2 | |
alembic==1.14.0 | |
altgraph==0.17.4 | |
antlr4-python3-runtime==4.9.3 | |
asteroid-filterbanks==0.4.0 | |
attrs==24.3.0 | |
audioread==3.0.1 | |
av==11.0.0 | |
certifi==2024.12.14 | |
cffi==1.17.1 | |
charset-normalizer==3.4.1 | |
click==8.1.8 | |
colorama==0.4.6 | |
coloredlogs==15.0.1 | |
colorlog==6.9.0 | |
contourpy==1.3.1 | |
ctranslate2==4.4.0 | |
cycler==0.12.1 | |
decorator==5.1.1 | |
docopt==0.6.2 | |
einops==0.8.0 | |
faster-whisper==1.0.0 | |
filelock==3.16.1 | |
flatbuffers==24.12.23 | |
fonttools==4.55.3 | |
frozenlist==1.5.0 | |
fsspec==2024.12.0 | |
greenlet==3.1.1 | |
huggingface-hub==0.27.0 | |
humanfriendly==10.0 | |
HyperPyYAML==1.2.2 | |
idna==3.10 | |
Jinja2==3.1.5 | |
joblib==1.4.2 | |
julius==0.2.7 | |
kiwisolver==1.4.8 | |
lazy_loader==0.4 | |
librosa==0.10.2.post1 | |
lightning==2.5.0.post0 | |
lightning-utilities==0.11.9 | |
llvmlite==0.43.0 | |
Mako==1.3.8 | |
markdown-it-py==3.0.0 | |
MarkupSafe==3.0.2 | |
matplotlib==3.10.0 | |
mdurl==0.1.2 | |
mpmath==1.3.0 | |
msgpack==1.1.0 | |
multidict==6.1.0 | |
networkx==3.4.2 | |
nltk==3.9.1 | |
numba==0.60.0 | |
numpy==1.26.4 | |
omegaconf==2.3.0 | |
onnxruntime==1.20.1 | |
optuna==4.1.0 | |
packaging==24.2 | |
pandas==2.2.3 | |
pefile==2023.2.7 | |
pillow==11.0.0 | |
platformdirs==4.3.6 | |
pooch==1.8.2 | |
primePy==1.3 | |
propcache==0.2.1 | |
protobuf==5.29.2 | |
pyannote.audio==3.1.1 | |
pyannote.core==5.0.0 | |
pyannote.database==5.1.0 | |
pyannote.metrics==3.2.1 | |
pyannote.pipeline==3.0.1 | |
pycparser==2.22 | |
Pygments==2.18.0 | |
pyinstaller-hooks-contrib==2024.11 | |
pyparsing==3.2.0 | |
pyreadline3==3.5.4 | |
python-dateutil==2.9.0.post0 | |
pytorch-lightning==2.5.0.post0 | |
pytorch-metric-learning==2.8.1 | |
pytz==2024.2 | |
pywin32-ctypes==0.2.3 | |
PyYAML==6.0.2 | |
regex==2024.11.6 | |
requests==2.32.3 | |
rich==13.9.4 | |
ruamel.yaml==0.18.6 | |
ruamel.yaml.clib==0.2.12 | |
safetensors==0.4.5 | |
scikit-learn==1.6.0 | |
scipy==1.14.1 | |
semver==3.0.2 | |
sentencepiece==0.2.0 | |
setuptools==75.6.0 | |
shellingham==1.5.4 | |
six==1.17.0 | |
sortedcontainers==2.4.0 | |
soundfile==0.12.1 | |
soxr==0.5.0.post1 | |
speechbrain==1.0.2 | |
SQLAlchemy==2.0.36 | |
sympy==1.13.1 | |
tabulate==0.9.0 | |
tensorboardX==2.6.2.2 | |
threadpoolctl==3.5.0 | |
tokenizers==0.15.2 | |
torch @ https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241227%2Bcu126-cp312-cp312-win_amd64.whl#sha256=f52e45693daa9f729c8063e2a31976a49b7bafac93c7b8f8815cfb6026b14183 | |
torch-audiomentations==0.11.1 | |
torch_pitch_shift==1.2.5 | |
torchaudio @ https://download.pytorch.org/whl/nightly/cu126/torchaudio-2.6.0.dev20241228%2Bcu126-cp312-cp312-win_amd64.whl#sha256=8a690788d534fadbafb31668c77b5c23c224bacf8caaa46ee436e38d29fb273a | |
torchmetrics==1.4.1 | |
tqdm==4.66.5 | |
transformers==4.39.3 | |
typer==0.12.5 | |
typing_extensions==4.12.2 | |
tzdata==2024.1 | |
urllib3==2.2.2 | |
whisperx @ git+https://github.com/m-bain/whisperx.git@7307306a9d8dd0d261e588cc933322454f853853 | |
yarl==1.18.3 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment