Created
November 23, 2024 09:41
-
-
Save fredriccliver/1ffe3e32820e57059721cfbce2db2d5a to your computer and use it in GitHub Desktop.
Fast and cost-efficient transcription with Whisper and Replicate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import logging | |
import replicate | |
from pathlib import Path | |
from time import time as get_time | |
from typing import Optional, Dict, Any | |
from dotenv import load_dotenv | |
from pydub import AudioSegment | |
import tempfile | |
import concurrent.futures | |
from concurrent.futures import ThreadPoolExecutor | |
import time | |
from collections import defaultdict | |
import curses | |
import io | |
import asyncio | |
from queue import Queue | |
from threading import Lock | |
# Debug: Print environment variable directly and after loading
# NOTE(review): even a masked prefix/suffix of a secret token is written to
# stdout here — consider removing these debug prints before production use.
token = os.environ.get('REPLICATE_API_TOKEN')
print(f"Direct env token: {token[:5]}...{token[-5:]}" if token else "Direct env token: Not found")
load_dotenv()  # pull variables from a local .env file into the environment
token = os.getenv('REPLICATE_API_TOKEN')
print(f"After load_dotenv token: {token[:5]}...{token[-5:]}" if token else "After load_dotenv token: Not found")

# Configure logging: timestamped INFO-level messages on the root logger
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Configure Replicate API — fail fast at import time if no token is available
replicate_api_token = os.getenv('REPLICATE_API_TOKEN')
if not replicate_api_token:
    raise ValueError("Replicate API token not found. Please set REPLICATE_API_TOKEN in .env file")

# Debug: Print API token prefix and suffix
# NOTE(review): `client` is created but never used below; the replicate.*
# module-level calls appear to rely on the token from the environment instead.
client = replicate.Client(api_token=replicate_api_token)
print(f"Using token: {replicate_api_token[:8]}...{replicate_api_token[-8:]}")
def get_file_count() -> int:
    """Prompt the user for how many files to transcribe.

    An empty answer defaults to 1; the prompt repeats until a positive
    integer is entered.
    """
    while True:
        raw = input("Enter number of files to transcribe (default=1): ").strip()
        if not raw:
            return 1
        try:
            requested = int(raw)
        except ValueError:
            print("Please enter a valid number")
            continue
        if requested >= 1:
            return requested
        print("Please enter a positive number")
def verify_audio_file(file_path: str) -> bool:
    """Check that *file_path* is a valid, readable audio file using ffprobe.

    Returns True when ffprobe exits cleanly and reports a duration for the
    first audio stream; False on any probe failure or unexpected error.
    """
    import subprocess

    probe_cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'a:0',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        file_path,
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True)
    except Exception as e:
        logging.error(f"Invalid audio file {os.path.basename(file_path)}: {str(e)}")
        return False
    if probe.returncode == 0 and probe.stdout.strip():
        return True
    logging.error(f"FFprobe validation failed for {os.path.basename(file_path)}")
    return False
async def transcribe_audio_worker(queue: Queue, results: dict, lock: Lock) -> None:
    """Drain transcription tasks from *queue*, recording outcomes in *results*.

    Each task is a ``(file_path, file_info)`` tuple; a ``None`` sentinel stops
    the worker.  On success ``results[file_path]`` holds a dict with ``text``
    and ``chunks`` keys; on any failure it is set to ``None``.  All writes to
    *results* are guarded by *lock*.
    """
    from queue import Empty  # local import: the `queue` parameter shadows the module

    while True:
        try:
            task = queue.get_nowait()
        except Empty:
            # BUG FIX: Queue.get_nowait() raises queue.Empty; the previous
            # `except asyncio.QueueEmpty` never matched it, so an empty queue
            # crashed the worker with an unhandled exception.
            break
        try:
            if task is None:  # sentinel value to stop worker
                break
            file_path, file_info = task
            start_time = get_time()
            try:
                deployment = replicate.deployments.get("[organisation]/[deployment id]")
                logging.info(f"Starting transcription for: {os.path.basename(file_path)}")
                with open(file_path, "rb") as audio_file:
                    prediction = deployment.predictions.create(
                        input={
                            "audio": audio_file,
                            "task": "transcribe",
                            "language": "english",
                            "batch_size": 64,
                            "timestamp": "chunk",
                            "diarise_audio": False
                        }
                    )
                # Poll with a long, non-blocking sleep instead of a built-in
                # wait so sibling workers keep making progress.
                while prediction.status not in ["succeeded", "failed", "canceled"]:
                    await asyncio.sleep(5)
                    prediction.reload()
                if prediction.status == "failed":
                    error_msg = prediction.error or "Unknown error occurred"
                    logging.error(f"❌ Prediction failed for {os.path.basename(file_path)}: {error_msg}")
                    with lock:
                        results[file_path] = None
                    continue
                if prediction.output is None:
                    logging.error(f"❌ No output received for {os.path.basename(file_path)}")
                    with lock:
                        results[file_path] = None
                    continue
                transcription = {
                    "text": prediction.output.get("text", ""),
                    "chunks": prediction.output.get("chunks", [])
                }
                logging.info(f"✅ Transcription completed for {os.path.basename(file_path)} in {get_time() - start_time:.2f} seconds")
                with lock:
                    results[file_path] = transcription
            except Exception as e:
                logging.error(f"❌ Error processing {os.path.basename(file_path)}: {str(e)}")
                with lock:
                    results[file_path] = None
        finally:
            # BUG FIX: acknowledge only items actually dequeued.  The old code
            # ran task_done() in a finally that also covered the Empty case,
            # raising ValueError("task_done() called too many times").
            queue.task_done()
async def process_files_concurrently(files_to_process: list, max_workers: int = 10) -> dict:
    """Transcribe multiple files concurrently using a queue of worker tasks.

    *files_to_process* is a list of ``(file_path, file_info)`` pairs where
    ``file_info`` is ``(audio_file, transcription_file, channel_id,
    episode_id, episode_title)``.  Successful transcriptions are written to
    ``transcription_file`` as they complete.  Returns a dict mapping
    ``file_path`` to the transcription dict on success or ``None`` on failure.
    """
    from queue import Empty  # local import: `queue` below shadows the module name

    queue = Queue()
    results = {}
    lock = Lock()
    total_files = len(files_to_process)
    completed = 0
    # Track active files: worker_id -> (file_name, start_time, title)
    active_files = {}

    for file_info in files_to_process:
        queue.put(file_info)
    # One sentinel per worker so every worker eventually stops.
    for _ in range(max_workers):
        queue.put(None)

    def clear_lines(n):
        """Clear n lines in console via ANSI cursor-up + erase."""
        print(f"\033[{n}A\033[J", end="")

    def update_status():
        """Redraw the overall progress bar and per-worker status lines."""
        # Clear previous status
        if active_files:
            clear_lines(len(active_files) + 4)
        # Print overall progress bar
        progress = (completed / total_files) * 100
        bar_width = 50
        filled = int(bar_width * completed // total_files)
        bar = '█' * filled + '░' * (bar_width - filled)
        print(f"\nProgress: [{bar}] {progress:.1f}%")
        print(f"Completed: {completed}/{total_files} | Queue: {max(0, queue.qsize() - max_workers)} remaining")
        print("\nActive Workers:")
        print("─" * 100)  # Increased width for longer titles
        # Sort workers by elapsed time, longest-running first
        sorted_workers = sorted(active_files.items(),
                                key=lambda x: x[1][1],
                                reverse=True)
        # Print active file status with visual indicators
        for worker_id, (_, start_time, title) in sorted_workers:
            elapsed = time.time() - start_time
            bar_width = 20  # Reduced bar width to accommodate titles
            progress_bar = '▰' * min(bar_width, int(elapsed/2)) + '▱' * (bar_width - min(bar_width, int(elapsed/2)))
            # Truncate title if too long
            display_title = title[:50] + '...' if len(title) > 50 else title
            print(f"Worker {worker_id:2d} │ {progress_bar} │ {display_title:<50} │ {elapsed:.1f}s")

    async def status_worker():
        # Refresh the display until every file has been accounted for.
        while completed < total_files:
            update_status()
            await asyncio.sleep(0.5)

    async def worker(worker_id: int, queue: Queue, results: dict, lock: Lock) -> None:
        nonlocal completed
        while True:
            try:
                task = queue.get_nowait()
            except Empty:
                # BUG FIX: Queue.get_nowait() raises queue.Empty; the old
                # `except asyncio.QueueEmpty` never matched, crashing workers.
                break
            try:
                if task is None:
                    break
                file_path, file_info = task
                audio_file, transcription_file, channel_id, episode_id, episode_title = file_info
                with lock:
                    active_files[worker_id] = (file_path, time.time(), episode_title)
                start_time = get_time()
                try:
                    deployment = replicate.deployments.get("[organisation]/[deployment id]")
                    logging.info(f"Starting transcription for: {os.path.basename(file_path)}")
                    with open(file_path, "rb") as audio_file:
                        prediction = deployment.predictions.create(
                            input={
                                "audio": audio_file,
                                "task": "transcribe",
                                "language": "english",
                                "batch_size": 24,
                                "timestamp": "chunk",
                                "diarise_audio": False
                            }
                        )
                    # Poll with a long, non-blocking sleep so siblings proceed.
                    while prediction.status not in ["succeeded", "failed", "canceled"]:
                        await asyncio.sleep(5)
                        prediction.reload()
                    # BUG FIX: the failure branches below no longer `continue`
                    # past the completed-counter update at the bottom; the old
                    # code skipped it, so `status_worker` (and gather) never
                    # finished whenever any file failed.
                    if prediction.status == "failed":
                        error_msg = prediction.error or "Unknown error occurred"
                        logging.error(f"❌ Prediction failed for {os.path.basename(file_path)}: {error_msg}")
                        with lock:
                            results[file_path] = None
                    elif prediction.output is None:
                        logging.error(f"❌ No output received for {os.path.basename(file_path)}")
                        with lock:
                            results[file_path] = None
                    else:
                        transcription = {
                            "text": prediction.output.get("text", ""),
                            "chunks": prediction.output.get("chunks", [])
                        }
                        # Verify we have actual content
                        if not transcription["text"] and not transcription["chunks"]:
                            logging.error(f"❌ Empty transcription received for {episode_id} - {episode_title}")
                            with lock:
                                results[file_path] = None
                        else:
                            # Save transcription immediately after receiving it
                            try:
                                transcription_file.parent.mkdir(parents=True, exist_ok=True)
                                with open(transcription_file, "w", encoding="utf-8") as f:
                                    json.dump(transcription, f, ensure_ascii=False, indent=2)
                                logging.info(f"✅ Transcription saved for {episode_id} - {episode_title}")
                                logging.info(f" Path: {transcription_file}")
                                logging.info(f" Text length: {len(transcription['text'])} chars")
                                logging.info(f" Chunks: {len(transcription['chunks'])}")
                                # BUG FIX: record the success.  The old code
                                # never stored successful transcriptions in
                                # `results`, so callers saw every file as failed.
                                with lock:
                                    results[file_path] = transcription
                            except Exception as e:
                                logging.error(f"❌ Error saving transcription: {str(e)}")
                                with lock:
                                    results[file_path] = None
                except Exception as e:
                    logging.error(f"❌ Error processing {os.path.basename(file_path)}: {str(e)}")
                    with lock:
                        results[file_path] = None
                with lock:
                    completed += 1
                    active_files.pop(worker_id, None)
            finally:
                # BUG FIX: only acknowledge items actually dequeued; calling
                # task_done() after an Empty exception raises ValueError.
                queue.task_done()

    workers = [asyncio.create_task(worker(i, queue, results, lock))
               for i in range(max_workers)]
    workers.append(asyncio.create_task(status_worker()))
    await asyncio.gather(*workers)
    print("\n")
    return results
def get_channel_selection(data_dir: Path) -> Optional[str]:
    """Ask the user which channel to process.

    Subdirectories of *data_dir* are offered as channels.  Returns the
    chosen channel name, or None to mean "all channels" (also returned
    when no channels exist or the user submits an empty answer).
    """
    channels = [entry.name for entry in data_dir.iterdir() if entry.is_dir()]
    if not channels:
        return None

    print("\nAvailable channels:")
    print("0. All channels")
    for i, channel in enumerate(channels, 1):
        print(f"{i}. {channel}")

    while True:
        choice = input("\nSelect channel number (0 for all) or paste channel ID: ").strip()
        if not choice:
            return None
        try:
            selected = int(choice)
        except ValueError:
            # Not numeric: accept the text only if it matches a channel ID.
            if choice in channels:
                return choice
        else:
            if selected == 0:
                return None
            if 1 <= selected <= len(channels):
                return channels[selected - 1]
        print("Invalid selection. Please try again.")
def process_untranscribed_files() -> None:
    """Find episodes with audio but no transcription and transcribe them.

    Walks ``data/<channel_id>/<episode_id>/`` under the script directory,
    collects episodes that have ``audio.mp3`` but no ``transcription.json``,
    asks the user how many to process and for confirmation, then fans the
    work out via ``process_files_concurrently`` and logs a summary.
    """
    script_dir = Path(__file__).parent
    data_dir = script_dir / "data"
    data_dir.mkdir(exist_ok=True)
    # Get channel selection from user (None means all channels)
    selected_channel = get_channel_selection(data_dir)
    processed_count = 0
    failed_count = 0
    failed_files = []
    # Find all untranscribed files
    pending_files = []
    channel_dirs = [data_dir / selected_channel] if selected_channel else data_dir.iterdir()
    for channel_dir in channel_dirs:
        if not channel_dir.is_dir():
            continue
        channel_id = channel_dir.name
        for episode_dir in channel_dir.iterdir():
            if not episode_dir.is_dir():
                continue
            episode_id = episode_dir.name
            metadata_file = episode_dir / "metadata.json"
            audio_file = episode_dir / "audio.mp3"
            transcription_file = episode_dir / "transcription.json"
            # Get episode title from metadata (best effort; keep the default on error)
            episode_title = "Unknown Title"
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)
                        episode_title = metadata.get('title', 'Unknown Title')
                except Exception as e:
                    logging.warning(f"Could not read metadata for episode {episode_id}: {e}")
            if audio_file.exists() and not transcription_file.exists():
                pending_files.append((audio_file, transcription_file, channel_id, episode_id, episode_title))
                logging.info(f"Found untranscribed episode: {episode_id} - {episode_title}")
    total_files = len(pending_files)
    if total_files == 0:
        logging.info("No untranscribed files found")
        return
    logging.info(f"Found {total_files} untranscribed files")
    files_to_process = min(get_file_count(), total_files)
    logging.info(f"Will process {files_to_process} file(s)")
    # Get user confirmation before spending API credits
    if not get_user_confirmation(pending_files, files_to_process):
        logging.info("Transcription cancelled by user")
        return
    # Process files concurrently; keys are stringified audio paths
    file_pairs = [(str(file_info[0]), file_info) for file_info in pending_files[:files_to_process]]
    results = asyncio.run(process_files_concurrently(file_pairs))
    # Process results
    processed_count = 0
    failed_count = 0
    failed_files = []
    # NOTE(review): any attempted file whose path never appears as a key in
    # `results` is counted neither as processed nor as failed — verify the
    # worker records an entry for every attempted file.
    for file_path, result in results.items():
        file_info = next(info for path, info in file_pairs if path == file_path)
        audio_file, transcription_file, channel_id, episode_id, episode_title = file_info  # Unpack all values
        if result is not None:
            try:
                # Ensure the directory exists
                transcription_file.parent.mkdir(parents=True, exist_ok=True)
                # Save transcription with proper path handling
                # NOTE(review): the worker in process_files_concurrently also
                # writes transcription.json on success, so this may write the
                # same content a second time — confirm this is intended.
                with open(transcription_file, "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                processed_count += 1
                logging.info(f"✓ Saved transcription for: {episode_id} - {episode_title}")
                logging.info(f" Path: {transcription_file}")
            except Exception as e:
                failed_count += 1
                failed_files.append({
                    'channel_id': channel_id,
                    'episode_id': episode_id,
                    'title': episode_title,
                    'error': f"Failed to save transcription: {str(e)}"
                })
                logging.error(f"❌ Error saving transcription: {str(e)}")
        else:
            failed_count += 1
            failed_files.append({
                'channel_id': channel_id,
                'episode_id': episode_id,
                'title': episode_title,
                'error': "Transcription failed"
            })
    # Summary reporting
    logging.info(f"\nProcessing Summary:")
    logging.info(f"------------------------")
    logging.info(f"Total files found: {total_files}")
    logging.info(f"Files attempted: {files_to_process}")
    logging.info(f"Successfully transcribed: {processed_count}")
    logging.info(f"Failed: {failed_count}")
    if failed_count > 0:
        logging.error("\nFailed Files:")
        logging.error("------------------------")
        for failed in failed_files:
            logging.error(f"- Channel: {failed['channel_id']}")
            logging.error(f" Episode: {failed['episode_id']}")
            logging.error(f" Title: {failed['title']}")
            logging.error(f" Error: {failed['error']}")
            logging.error(" ------------------------")
    logging.info("\nProcess Status:")
    logging.info("------------------------")
    if processed_count == files_to_process:
        logging.info("✓ All files processed successfully")
    elif processed_count == 0:
        logging.error("× No files were successfully processed")
        if failed_files:
            logging.error(f" Main error: {failed_files[0]['error']}")
    else:
        logging.warning(f"⚠ Partial success: {processed_count}/{files_to_process} files processed")
def get_user_confirmation(pending_files: list, files_to_process: int) -> bool:
    """Preview up to five pending episodes and ask the user to proceed.

    Returns True on 'y'/'yes', False on 'n'/'no'; re-prompts otherwise.
    """
    print("\nFiles to be processed:")
    print("------------------------")
    for position, (_, _, _, episode_id, title) in enumerate(pending_files[:files_to_process], 1):
        print(f"{position}. {episode_id} - {title}")
        if position >= 5 and files_to_process > 5:
            remaining = files_to_process - 5
            print(f"... and {remaining} more files")
            break
    while True:
        answer = input("\nProceed with transcription? (y/n): ").strip().lower()
        if answer in ('y', 'yes'):
            return True
        if answer in ('n', 'no'):
            return False
        print("Please enter 'y' or 'n'")
def main() -> None:
    """Entry point: run the interactive transcription workflow end to end."""
    logging.info("Starting transcription process")
    # Go straight to processing; no pre-flight status check is performed.
    process_untranscribed_files()
    logging.info("Transcription process completed")


# Standard script guard: importing this module runs only the top-level setup,
# not the interactive workflow.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment