Skip to content

Instantly share code, notes, and snippets.

@haasr
Last active March 6, 2025 13:51
Show Gist options
  • Save haasr/7638a4e72056ba3108a6be171f8cb534 to your computer and use it in GitHub Desktop.
Save haasr/7638a4e72056ba3108a6be171f8cb534 to your computer and use it in GitHub Desktop.
WhisperX Transcription GUI (frontend for m-bain/whisperX; does not include diarization)
import tkinter as tk
from tkinter import ttk, scrolledtext, filedialog, messagebox
import threading
import whisperx
import os
import sys
import datetime
import shutil
from pathlib import Path
import queue
import torch
import platform
import subprocess
import importlib
import threading
import time
import sys
# Note: numpy==1.26.4 is pinned because NumPy 2.x removed the numpy.NaN alias (only numpy.nan remains), which raises an exception in whisperx.
class SplashScreen(tk.Toplevel):
    """Borderless, always-on-top splash window shown while dependencies load.

    Displays a title, a "Loading dependencies..." caption, an indeterminate
    progress bar and a small status line that callers refresh through
    ``update_status``.
    """

    def __init__(self):
        super().__init__()
        background = '#2d2d2d'
        win_w, win_h = 300, 200

        # No title bar or borders
        self.overrideredirect(True)

        # Put the fixed-size window in the middle of the screen
        left = (self.winfo_screenwidth() - win_w) // 2
        top = (self.winfo_screenheight() - win_h) // 2
        self.geometry(f"{win_w}x{win_h}+{left}+{top}")
        self.configure(bg=background)

        # Heading
        tk.Label(
            self,
            text="WhisperX GUI",
            font=("Helvetica", 16, "bold"),
            bg=background,
            fg='white',
        ).pack(pady=(20, 10))

        # Static caption
        tk.Label(
            self,
            text="Loading dependencies...",
            font=("Helvetica", 10),
            bg=background,
            fg='white',
        ).pack(pady=5)

        # Indeterminate progress bar, animated as soon as it is packed
        self.progress = ttk.Progressbar(self, mode='indeterminate', length=200)
        self.progress.pack(pady=10)
        self.progress.start()

        # Status line naming the module currently being loaded
        self.status_label = tk.Label(
            self,
            text="",
            font=("Helvetica", 8),
            bg=background,
            fg='#a0a0a0',
            wraplength=280,
        )
        self.status_label.pack(pady=5)

        # Flush pending geometry work, then keep the splash above other windows
        self.update_idletasks()
        self.attributes('-topmost', True)

    def update_status(self, text):
        """Show ``text`` in the status line and redraw the window immediately."""
        self.status_label.config(text=text)
        self.update()
def import_with_status(splash_screen):
    """Import the application's modules one by one, reporting each on the splash.

    Args:
        splash_screen: Object exposing ``update_status(text)``.

    Returns:
        Dict mapping each dotted module name to the imported module object.

    On ImportError the error is shown on the splash for two seconds and the
    process exits with status 1.
    """
    pending = (
        ('tkinter.scrolledtext', 'GUI Components'),
        ('tkinter.filedialog', 'File Dialog'),
        ('tkinter.messagebox', 'Message Boxes'),
        ('whisperx', 'WhisperX Core'),
        ('torch', 'PyTorch'),
        ('datetime', 'DateTime Utils'),
        ('pathlib', 'Path Utils'),
        ('queue', 'Threading Queue'),
        ('platform', 'Platform Utils'),
        ('subprocess', 'Subprocess Utils'),
    )
    loaded = {}
    for dotted_name, label in pending:
        splash_screen.update_status(f"Loading {label}...")
        try:
            module = importlib.import_module(dotted_name)
        except ImportError as exc:
            splash_screen.update_status(f"Error loading {label}: {str(exc)}")
            time.sleep(2)
            sys.exit(1)
        else:
            loaded[dotted_name] = module
            # Short pause so the loading message stays visible
            time.sleep(0.1)
    return loaded
# Here's where we define our WhisperXGUI class with all the imported modules
class WhisperXGUI:
    """Tkinter front end that batch-transcribes media files with WhisperX.

    Transcription runs on a background thread. That thread never touches
    widgets directly: log lines, error dialogs and the "finished" signal are
    all posted to ``output_queue`` and handled by ``check_output`` on the Tk
    main thread.
    """

    # Control markers placed on ``output_queue`` by the worker thread.
    # ``_DONE`` asks the main thread to re-enable the Transcribe button;
    # ``(_ERROR, text)`` asks it to show an error dialog.
    _DONE = object()
    _ERROR = object()

    def __init__(self, root, modules):
        """Build the UI inside ``root``.

        Args:
            root: The Tk root window hosting the application.
            modules: Dict of already-imported modules keyed by dotted name,
                as returned by ``import_with_status``.
        """
        self.root = root
        self.root.title("WhisperX Transcription GUI")
        self.root.geometry("800x600")
        # Unpack needed modules
        self.scrolledtext = modules['tkinter.scrolledtext']
        self.filedialog = modules['tkinter.filedialog']
        self.messagebox = modules['tkinter.messagebox']
        self.whisperx = modules['whisperx']
        self.torch = modules['torch']
        self.Path = modules['pathlib'].Path
        self.queue = modules['queue'].Queue
        self.platform = modules['platform']
        self.subprocess = modules['subprocess']
        # Queue for communication between threads
        self.output_queue = self.queue()
        # List to store multiple file paths (parallel to the listbox rows)
        self.file_list = []
        self.setup_ui()
        self.setup_bindings()
        # Start the output checking loop
        self.check_output()

    def setup_ui(self):
        """Create and lay out every widget of the main window."""
        # Main frame
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        # File selection
        ttk.Label(main_frame, text="Input Files:").grid(row=0, column=0, sticky=tk.W)
        self.files_listbox = tk.Listbox(main_frame, width=50, height=5)
        self.files_listbox.grid(row=0, column=1, padx=5, sticky=(tk.W, tk.E))
        # File buttons frame
        file_buttons_frame = ttk.Frame(main_frame)
        file_buttons_frame.grid(row=0, column=2, padx=5)
        ttk.Button(file_buttons_frame, text="➕ Add Files", command=self.browse_files).grid(row=0, column=0, pady=2)
        ttk.Button(file_buttons_frame, text="➖ Remove Selected", command=self.remove_selected_file).grid(row=1, column=0, pady=2)
        ttk.Button(file_buttons_frame, text="🆑 Clear All", command=self.clear_files).grid(row=2, column=0, pady=2)
        # Model selection
        ttk.Label(main_frame, text="Model:").grid(row=1, column=0, sticky=tk.W, pady=5)
        self.model_var = tk.StringVar(value="large-v2")
        models = ["tiny", "base", "small", "medium", "large-v2"]
        model_combo = ttk.Combobox(main_frame, textvariable=self.model_var, values=models, state="readonly")
        model_combo.grid(row=1, column=1, sticky=(tk.W, tk.E), pady=5)
        # Language selection with full names (display name -> language code)
        ttk.Label(main_frame, text="Language:").grid(row=2, column=0, sticky=tk.W, pady=5)
        self.language_mapping = {
            "English": "en",
            "French": "fr",
            "German": "de",
            "Spanish": "es",
            "Italian": "it",
            "Japanese": "ja",
            "Chinese": "zh",
            "Dutch": "nl",
            "Ukrainian": "uk",
            "Portuguese": "pt"
        }
        self.language_var = tk.StringVar(value="English")
        language_names = list(self.language_mapping.keys())
        language_combo = ttk.Combobox(main_frame, textvariable=self.language_var, values=language_names, state="readonly")
        language_combo.grid(row=2, column=1, sticky=(tk.W, tk.E), pady=5)
        # Compute type selection
        ttk.Label(main_frame, text="Compute Type:").grid(row=3, column=0, sticky=tk.W, pady=5)
        self.compute_type_var = tk.StringVar(value="float16")
        compute_types = ["float32", "float16", "int8"]
        compute_type_combo = ttk.Combobox(main_frame, textvariable=self.compute_type_var, values=compute_types, state="readonly")
        compute_type_combo.grid(row=3, column=1, sticky=(tk.W, tk.E), pady=5)
        # Progress and output (read-only; update_output unlocks it briefly)
        self.output_text = scrolledtext.ScrolledText(main_frame, height=20, width=70, wrap=tk.WORD)
        self.output_text.grid(row=4, column=0, columnspan=3, pady=10)
        self.output_text.config(state='disabled')
        # Transcribe button
        self.transcribe_btn = ttk.Button(main_frame, text="✨ Transcribe All", command=self.start_transcription)
        self.transcribe_btn.grid(row=5, column=0, columnspan=3, pady=10)
        # Configure grid weights
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)

    def browse_files(self):
        """Ask for media files and append the ones not already listed."""
        filetypes = (
            ("Video/audio", "*.mp4 *.mkv *.mov *.wmv *.avi *.flv *.mp3 *.wav *.aac *.flac *.ogg"),
            ("Video", "*.mp4 *.mkv *.mov *.wmv *.avi *.flv"),
            ("Audio", "*.mp3 *.wav *.aac *.flac *.ogg"),
            ("All files", "*.*")
        )
        filenames = filedialog.askopenfilenames(filetypes=filetypes)
        for filename in filenames or ():
            if filename not in self.file_list:
                self.file_list.append(filename)
                # The listbox shows only the file name; the full path is kept in file_list
                self.files_listbox.insert(tk.END, os.path.basename(filename))

    def remove_selected_file(self):
        """Remove the selected row from both the listbox and ``file_list``."""
        selection = self.files_listbox.curselection()
        if selection:
            index = selection[0]
            self.files_listbox.delete(index)
            self.file_list.pop(index)

    def clear_files(self):
        """Empty the listbox and ``file_list``."""
        self.files_listbox.delete(0, tk.END)
        self.file_list.clear()

    def setup_bindings(self):
        """Route the window-close button through ``on_closing``."""
        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)

    def update_output(self, message):
        """Append ``message`` plus a newline to the log pane and scroll to it."""
        self.output_text.config(state='normal')
        self.output_text.insert(tk.END, message + '\n')
        self.output_text.see(tk.END)
        self.output_text.config(state='disabled')

    def check_output(self):
        """Drain ``output_queue`` on the main thread, then reschedule in 100 ms.

        Plain items are appended to the log pane. The ``_DONE`` marker
        re-enables the Transcribe button and an ``(_ERROR, text)`` tuple opens
        an error dialog, so the worker thread never has to call into Tk.
        """
        try:
            while True:
                item = self.output_queue.get_nowait()
                if item is self._DONE:
                    self.transcribe_btn.config(state='normal')
                elif isinstance(item, tuple) and item and item[0] is self._ERROR:
                    messagebox.showerror("Error", item[1])
                else:
                    self.update_output(item)
        except queue.Empty:
            pass
        finally:
            self.root.after(100, self.check_output)

    def create_output_directory(self, input_file):
        """Create and return ``transcripts/<stem>_<YYYY-MM-DD_HH-MM>``.

        The path is relative to the current working directory; an existing
        directory of the same name is reused.
        """
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
        base_name = Path(input_file).stem
        output_dir = Path("transcripts") / f"{base_name}_{timestamp}"
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def transcribe(self, files=None, model_name=None, language=None, compute_type=None):
        """Transcribe and align every file, writing an .srt and a .txt for each.

        Intended to run on a worker thread. Progress, errors and completion
        are reported exclusively through ``output_queue``.

        Args:
            files: Paths to process. Defaults to a snapshot of ``file_list``.
            model_name: Whisper model name. Defaults to the combobox value.
            language: Language code. Defaults to the code mapped from the
                combobox value.
            compute_type: Compute type string. Defaults to the combobox value.

        ``start_transcription`` passes all four so that no Tk variable is read
        off the main thread; the defaults exist for callers that invoke
        ``transcribe()`` with no arguments.
        """
        model = None
        align_model = None
        try:
            files = list(self.file_list) if files is None else list(files)
            if not files:
                self.output_queue.put((self._ERROR, "Please select at least one input file"))
                return
            if model_name is None:
                model_name = self.model_var.get()
            if language is None:
                language = self.language_mapping[self.language_var.get()]
            if compute_type is None:
                compute_type = self.compute_type_var.get()

            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cpu" and compute_type == "float16":
                # The default compute type is rejected on CPU; degrade instead of failing.
                self.output_queue.put("float16 is not supported on CPU; using float32 instead.")
                compute_type = "float32"

            # Load model once for all files
            self.output_queue.put("Loading WhisperX model...")
            model = whisperx.load_model(model_name, device, compute_type=compute_type)

            metadata = None
            for input_file in files:
                self.output_queue.put(f"\nStarting transcription of: {input_file}")
                # Create output directory
                output_dir = self.create_output_directory(input_file)
                self.output_queue.put(f"Output directory: {output_dir}")
                # Load audio
                audio = whisperx.load_audio(input_file)
                # Transcribe
                self.output_queue.put("Transcribing...")
                result = model.transcribe(audio, batch_size=16, language=language)
                # Align
                self.output_queue.put("Aligning transcript...")
                if align_model is None:
                    # The language is fixed for the whole batch, so the
                    # alignment model is loaded once and reused.
                    align_model, metadata = whisperx.load_align_model(
                        language_code=language,
                        device=device
                    )
                result = whisperx.align(
                    result["segments"],
                    align_model,
                    metadata,
                    audio,
                    device,
                    return_char_alignments=False
                )
                # Save output files
                self._write_transcripts(result["segments"], output_dir, Path(input_file).stem)
                self.output_queue.put(f"Transcription complete for: {input_file}")
                self.output_queue.put(f"Files saved in: {output_dir}")
                # Open the output directory
                self.open_folder(output_dir)
            self.output_queue.put("\nAll files processed successfully!")
        except Exception as e:
            # Top-level boundary of the worker thread: report, never re-raise.
            self.output_queue.put(f"Error during transcription: {str(e)}")
            self.output_queue.put((self._ERROR, f"Transcription failed: {str(e)}"))
        finally:
            try:
                # Drop the model references on success and on failure alike
                model = align_model = None
                torch.cuda.empty_cache()
            finally:
                # Always let the main thread re-enable the button
                self.output_queue.put(self._DONE)

    def _write_transcripts(self, segments, output_dir, stem):
        """Write ``<stem>.srt`` and ``<stem>.txt`` for ``segments`` into ``output_dir``."""
        srt_path = output_dir / f"{stem}.srt"
        txt_path = output_dir / f"{stem}.txt"
        # Write SRT file (cues numbered from 1)
        with open(srt_path, 'w', encoding='utf-8') as f:
            for i, seg in enumerate(segments, 1):
                start = self.format_timestamp(seg["start"])
                end = self.format_timestamp(seg["end"])
                f.write(f"{i}\n{start} --> {end}\n{seg['text'].strip()}\n\n")
        # Write TXT file (one segment per line)
        with open(txt_path, 'w', encoding='utf-8') as f:
            for seg in segments:
                f.write(f"{seg['text'].strip()}\n")

    def open_folder(self, path):
        """Open ``path`` in the platform's file manager.

        Failure to launch the file manager is logged to the output pane
        rather than raised, so it cannot abort a running batch.
        """
        path = os.path.realpath(path)
        try:
            system = platform.system()
            if system == "Windows":
                os.startfile(path)
            elif system == "Darwin":  # macOS
                subprocess.run(["open", path])
            else:  # Linux
                subprocess.run(["xdg-open", path])
        except OSError as e:
            self.output_queue.put(f"Could not open folder {path}: {e}")

    def format_timestamp(self, seconds):
        """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond and then split with
        integer arithmetic, so inputs such as 1.001 are not truncated to
        ``,000`` by floating-point error. Negative inputs are clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def start_transcription(self):
        """Validate the selection, snapshot the settings and start the worker.

        Runs on the Tk main thread, so reading the Tk variables and showing
        the "no files" dialog here is safe.
        """
        if not self.file_list:
            messagebox.showerror("Error", "Please select at least one input file")
            return
        self.transcribe_btn.config(state='disabled')
        job = {
            "files": list(self.file_list),
            "model_name": self.model_var.get(),
            "language": self.language_mapping[self.language_var.get()],
            "compute_type": self.compute_type_var.get(),
        }
        threading.Thread(target=self.transcribe, kwargs=job, daemon=True).start()

    def on_closing(self):
        """Ask for confirmation, then stop the Tk main loop."""
        if messagebox.askokcancel("Quit", "Do you want to quit?"):
            self.root.quit()
# Function to finish initialization and show main window
def finish_init():
    """Build the main GUI, reveal the root window and dismiss the splash.

    Reads the module-level ``root``, ``imported_modules`` and ``splash`` set
    up by the ``__main__`` block.
    """
    WhisperXGUI(root, imported_modules)
    # Un-hide the main window, then remove the splash screen
    root.deiconify()
    splash.destroy()
if __name__ == "__main__":
    # The root window exists from the start but stays hidden during start-up
    root = tk.Tk()
    root.withdraw()
    # The splash screen is shown while the dependencies are imported
    splash = SplashScreen()
    imported_modules = import_with_status(splash)
    splash.update_status("Initializing application...")
    # Build the main window one second from now, then enter the event loop
    root.after(1000, finish_init)
    root.mainloop()
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
alembic==1.14.0
altgraph==0.17.4
antlr4-python3-runtime==4.9.3
asteroid-filterbanks==0.4.0
attrs==24.3.0
audioread==3.0.1
av==11.0.0
certifi==2024.12.14
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
colorama==0.4.6
coloredlogs==15.0.1
colorlog==6.9.0
contourpy==1.3.1
ctranslate2==4.4.0
cycler==0.12.1
decorator==5.1.1
docopt==0.6.2
einops==0.8.0
faster-whisper==1.0.0
filelock==3.16.1
flatbuffers==24.12.23
fonttools==4.55.3
frozenlist==1.5.0
fsspec==2024.12.0
greenlet==3.1.1
huggingface-hub==0.27.0
humanfriendly==10.0
HyperPyYAML==1.2.2
idna==3.10
Jinja2==3.1.5
joblib==1.4.2
julius==0.2.7
kiwisolver==1.4.8
lazy_loader==0.4
librosa==0.10.2.post1
lightning==2.5.0.post0
lightning-utilities==0.11.9
llvmlite==0.43.0
Mako==1.3.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
networkx==3.4.2
nltk==3.9.1
numba==0.60.0
numpy==1.26.4
omegaconf==2.3.0
onnxruntime==1.20.1
optuna==4.1.0
packaging==24.2
pandas==2.2.3
pefile==2023.2.7
pillow==11.0.0
platformdirs==4.3.6
pooch==1.8.2
primePy==1.3
propcache==0.2.1
protobuf==5.29.2
pyannote.audio==3.1.1
pyannote.core==5.0.0
pyannote.database==5.1.0
pyannote.metrics==3.2.1
pyannote.pipeline==3.0.1
pycparser==2.22
Pygments==2.18.0
pyinstaller-hooks-contrib==2024.11
pyparsing==3.2.0
pyreadline3==3.5.4
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pytorch-metric-learning==2.8.1
pytz==2024.2
pywin32-ctypes==0.2.3
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
ruamel.yaml==0.18.6
ruamel.yaml.clib==0.2.12
safetensors==0.4.5
scikit-learn==1.6.0
scipy==1.14.1
semver==3.0.2
sentencepiece==0.2.0
setuptools==75.6.0
shellingham==1.5.4
six==1.17.0
sortedcontainers==2.4.0
soundfile==0.12.1
soxr==0.5.0.post1
speechbrain==1.0.2
SQLAlchemy==2.0.36
sympy==1.13.1
tabulate==0.9.0
tensorboardX==2.6.2.2
threadpoolctl==3.5.0
tokenizers==0.15.2
torch @ https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241227%2Bcu126-cp312-cp312-win_amd64.whl#sha256=f52e45693daa9f729c8063e2a31976a49b7bafac93c7b8f8815cfb6026b14183
torch-audiomentations==0.11.1
torch_pitch_shift==1.2.5
torchaudio @ https://download.pytorch.org/whl/nightly/cu126/torchaudio-2.6.0.dev20241228%2Bcu126-cp312-cp312-win_amd64.whl#sha256=8a690788d534fadbafb31668c77b5c23c224bacf8caaa46ee436e38d29fb273a
torchmetrics==1.4.1
tqdm==4.66.5
transformers==4.39.3
typer==0.12.5
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
whisperx @ git+https://github.com/m-bain/whisperx.git@7307306a9d8dd0d261e588cc933322454f853853
yarl==1.18.3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment