Skip to content

Instantly share code, notes, and snippets.

@fredriccliver
Created November 23, 2024 09:41
Show Gist options
  • Save fredriccliver/1ffe3e32820e57059721cfbce2db2d5a to your computer and use it in GitHub Desktop.
Save fredriccliver/1ffe3e32820e57059721cfbce2db2d5a to your computer and use it in GitHub Desktop.
Fast and cost-efficient transcription with Whisper and Replicate
import os
import json
import logging
import replicate
from pathlib import Path
from time import time as get_time
from typing import Optional, Dict, Any
from dotenv import load_dotenv
from pydub import AudioSegment
import tempfile
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import time
from collections import defaultdict
import curses
import io
import asyncio
from queue import Queue
from threading import Lock
# Debug: Print environment variable directly and after loading
# (helps diagnose whether the token comes from the shell or from .env)
token = os.environ.get('REPLICATE_API_TOKEN')
# Only the first/last 5 chars are shown; the conditional guards the slice when unset
print(f"Direct env token: {token[:5]}...{token[-5:]}" if token else "Direct env token: Not found")
load_dotenv()  # merge variables from a local .env file into the environment
token = os.getenv('REPLICATE_API_TOKEN')
print(f"After load_dotenv token: {token[:5]}...{token[-5:]}" if token else "After load_dotenv token: Not found")
# Configure logging: timestamped messages on the root logger at INFO level
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Configure Replicate API — fail fast at import time if no token is available
replicate_api_token = os.getenv('REPLICATE_API_TOKEN')
if not replicate_api_token:
    raise ValueError("Replicate API token not found. Please set REPLICATE_API_TOKEN in .env file")
# Debug: Print API token prefix and suffix
# NOTE(review): `client` is created but the rest of the script uses the
# module-level `replicate` API, which reads REPLICATE_API_TOKEN itself.
client = replicate.Client(api_token=replicate_api_token)
print(f"Using token: {replicate_api_token[:8]}...{replicate_api_token[-8:]}")
def get_file_count() -> int:
    """Prompt the user for how many files to transcribe.

    An empty answer defaults to 1; non-numeric or non-positive answers
    re-prompt until a valid count is entered.
    """
    while True:
        raw = input("Enter number of files to transcribe (default=1): ").strip()
        if not raw:
            return 1
        try:
            requested = int(raw)
        except ValueError:
            print("Please enter a valid number")
            continue
        if requested >= 1:
            return requested
        print("Please enter a positive number")
def verify_audio_file(file_path: str) -> bool:
    """Check that *file_path* looks like a valid, readable audio file.

    Runs ffprobe against the first audio stream and requires it to report a
    duration.  Returns False on any failure — bad/missing file, no audio
    stream, or ffprobe itself being unavailable.
    """
    import subprocess
    probe_cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'a:0',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        file_path,
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True)
    except Exception as e:
        # e.g. FileNotFoundError when ffprobe is not installed
        logging.error(f"Invalid audio file {os.path.basename(file_path)}: {str(e)}")
        return False
    if probe.returncode == 0 and probe.stdout.strip():
        return True
    logging.error(f"FFprobe validation failed for {os.path.basename(file_path)}")
    return False
async def transcribe_audio_worker(queue: Queue, results: dict, lock: Lock) -> None:
    """Worker coroutine: drain transcription tasks from *queue* into *results*.

    Each task is a ``(file_path, file_info)`` tuple; a ``None`` sentinel stops
    the worker.  On success ``results[file_path]`` is a dict with ``text`` and
    ``chunks`` keys; on any failure it is set to ``None``.

    Bug fixes vs. the original:
    - ``queue.Queue.get_nowait()`` raises ``queue.Empty``, not
      ``asyncio.QueueEmpty``; the old handler never matched, so an empty
      queue crashed the worker.
    - ``queue.task_done()`` is now only called after a successful get;
      previously the ``finally`` also ran when ``get_nowait()`` raised,
      producing a spurious "task_done() called too many times" error.
    """
    from queue import Empty  # the exception actually raised by queue.Queue

    while True:
        try:
            task = queue.get_nowait()
        except Empty:
            break
        try:
            if task is None:  # sentinel value to stop worker
                break
            file_path, file_info = task
            start_time = get_time()
            try:
                deployment = replicate.deployments.get("[organisation]/[deployment id]")
                logging.info(f"Starting transcription for: {os.path.basename(file_path)}")
                with open(file_path, "rb") as audio_file:
                    # NOTE(review): this HTTP call is synchronous and blocks the
                    # event loop; consider loop.run_in_executor for true overlap.
                    prediction = deployment.predictions.create(
                        input={
                            "audio": audio_file,
                            "task": "transcribe",
                            "language": "english",
                            "batch_size": 64,
                            "timestamp": "chunk",
                            "diarise_audio": False
                        }
                    )
                # Custom wait with longer intervals (non-blocking sleep)
                while prediction.status not in ["succeeded", "failed", "canceled"]:
                    await asyncio.sleep(5)
                    prediction.reload()
                if prediction.status == "failed":
                    error_msg = prediction.error or "Unknown error occurred"
                    logging.error(f"❌ Prediction failed for {os.path.basename(file_path)}: {error_msg}")
                    with lock:
                        results[file_path] = None
                    continue
                if prediction.output is None:
                    logging.error(f"❌ No output received for {os.path.basename(file_path)}")
                    with lock:
                        results[file_path] = None
                    continue
                transcription = {
                    "text": prediction.output.get("text", ""),
                    "chunks": prediction.output.get("chunks", [])
                }
                logging.info(f"✅ Transcription completed for {os.path.basename(file_path)} in {get_time() - start_time:.2f} seconds")
                with lock:
                    results[file_path] = transcription
            except Exception as e:
                logging.error(f"❌ Error processing {os.path.basename(file_path)}: {str(e)}")
                with lock:
                    results[file_path] = None
        finally:
            # Pairs 1:1 with the successful get_nowait() above (runs on
            # break, continue, and normal fall-through alike).
            queue.task_done()
async def process_files_concurrently(files_to_process: list, max_workers: int = 10) -> dict:
    """Transcribe files concurrently with a worker pool and a live status display.

    Args:
        files_to_process: list of ``(file_path, file_info)`` pairs where
            ``file_info`` is ``(audio_file, transcription_file, channel_id,
            episode_id, episode_title)``.
        max_workers: number of concurrent worker coroutines.

    Returns:
        dict mapping ``file_path`` to the transcription dict on success or
        ``None`` on failure.  Every attempted file gets an entry.

    Bug fixes vs. the original:
    - The worker caught ``asyncio.QueueEmpty`` but ``queue.Queue.get_nowait()``
      raises ``queue.Empty`` — the handler never matched.
    - The failure branches ``continue``d past the ``completed += 1``
      bookkeeping, so ``status_worker()`` never saw completion and
      ``asyncio.gather`` hung forever; completion is now counted in a
      ``finally``.
    - The success path never stored the transcription in ``results``, so the
      caller reported every success as a failure; it is recorded now.
    - ``queue.task_done()`` is only called after a successful get.
    """
    from queue import Empty  # the exception actually raised by queue.Queue

    queue = Queue()
    results = {}
    lock = Lock()
    total_files = len(files_to_process)
    completed = 0
    # Track active files: worker_id -> (file_name, start_time, title)
    active_files = {}

    for file_info in files_to_process:
        queue.put(file_info)
    for _ in range(max_workers):
        queue.put(None)  # one stop-sentinel per worker

    def clear_lines(n):
        """Move the cursor up n lines and clear to end of screen (ANSI)."""
        print(f"\033[{n}A\033[J", end="")

    def update_status():
        """Redraw the overall progress bar plus one row per active worker."""
        # Clear previous status block before redrawing
        if active_files:
            clear_lines(len(active_files) + 4)
        progress = (completed / total_files) * 100
        bar_width = 50
        filled = int(bar_width * completed // total_files)
        bar = '█' * filled + '░' * (bar_width - filled)
        print(f"\nProgress: [{bar}] {progress:.1f}%")
        print(f"Completed: {completed}/{total_files} | Queue: {max(0, queue.qsize() - max_workers)} remaining")
        print("\nActive Workers:")
        print("─" * 100)  # Increased width for longer titles
        # Longest-running workers first
        sorted_workers = sorted(active_files.items(),
                                key=lambda x: x[1][1],
                                reverse=True)
        for worker_id, (_, start_time, title) in sorted_workers:
            elapsed = time.time() - start_time
            bar_width = 20  # Reduced bar width to accommodate titles
            progress_bar = '▰' * min(bar_width, int(elapsed/2)) + '▱' * (bar_width - min(bar_width, int(elapsed/2)))
            # Truncate title if too long
            display_title = title[:50] + '...' if len(title) > 50 else title
            print(f"Worker {worker_id:2d} │ {progress_bar} │ {display_title:<50} │ {elapsed:.1f}s")

    async def status_worker():
        # Exits once every file is counted; relies on workers always
        # incrementing `completed` (see the finally block in worker()).
        while completed < total_files:
            update_status()
            await asyncio.sleep(0.5)

    async def worker(worker_id: int, queue: Queue, results: dict, lock: Lock) -> None:
        nonlocal completed
        while True:
            try:
                task = queue.get_nowait()
            except Empty:
                break
            if task is None:  # stop sentinel
                queue.task_done()
                break
            file_path, file_info = task
            audio_file, transcription_file, channel_id, episode_id, episode_title = file_info
            with lock:
                active_files[worker_id] = (file_path, time.time(), episode_title)
            start_time = get_time()
            try:
                deployment = replicate.deployments.get("[organisation]/[deployment id]")
                logging.info(f"Starting transcription for: {os.path.basename(file_path)}")
                with open(file_path, "rb") as audio_file:
                    # NOTE(review): this HTTP call is synchronous and blocks the
                    # event loop; consider loop.run_in_executor for true overlap.
                    prediction = deployment.predictions.create(
                        input={
                            "audio": audio_file,
                            "task": "transcribe",
                            "language": "english",
                            "batch_size": 24,
                            "timestamp": "chunk",
                            "diarise_audio": False
                        }
                    )
                # Custom wait with longer intervals (non-blocking sleep)
                while prediction.status not in ["succeeded", "failed", "canceled"]:
                    await asyncio.sleep(5)
                    prediction.reload()
                if prediction.status == "failed":
                    error_msg = prediction.error or "Unknown error occurred"
                    logging.error(f"❌ Prediction failed for {os.path.basename(file_path)}: {error_msg}")
                    with lock:
                        results[file_path] = None
                    continue
                if prediction.output is None:
                    logging.error(f"❌ No output received for {os.path.basename(file_path)}")
                    with lock:
                        results[file_path] = None
                    continue
                transcription = {
                    "text": prediction.output.get("text", ""),
                    "chunks": prediction.output.get("chunks", [])
                }
                # Verify we have actual content
                if not transcription["text"] and not transcription["chunks"]:
                    logging.error(f"❌ Empty transcription received for {episode_id} - {episode_title}")
                    with lock:
                        results[file_path] = None
                    continue
                # Record the success so the caller's summary sees it
                with lock:
                    results[file_path] = transcription
                # Save transcription immediately after receiving it
                # (best-effort early persist; the caller saves again from results)
                try:
                    transcription_file.parent.mkdir(parents=True, exist_ok=True)
                    with open(transcription_file, "w", encoding="utf-8") as f:
                        json.dump(transcription, f, ensure_ascii=False, indent=2)
                    logging.info(f"✅ Transcription saved for {episode_id} - {episode_title}")
                    logging.info(f" Path: {transcription_file}")
                    logging.info(f" Text length: {len(transcription['text'])} chars")
                    logging.info(f" Chunks: {len(transcription['chunks'])}")
                except Exception as e:
                    logging.error(f"❌ Error saving transcription: {str(e)}")
            except Exception as e:
                logging.error(f"❌ Error processing {os.path.basename(file_path)}: {str(e)}")
                with lock:
                    results[file_path] = None
            finally:
                # Count every attempted task — including the `continue`d
                # failure paths above — so status_worker() can terminate.
                with lock:
                    completed += 1
                    active_files.pop(worker_id, None)
                queue.task_done()

    workers = [asyncio.create_task(worker(i, queue, results, lock))
               for i in range(max_workers)]
    workers.append(asyncio.create_task(status_worker()))
    await asyncio.gather(*workers)
    print("\n")
    return results
def get_channel_selection(data_dir: Path) -> Optional[str]:
    """Ask the user to pick a channel directory under *data_dir*.

    Returns the chosen channel name, or None to mean "all channels"
    (also returned immediately when no channel directories exist).
    """
    available = [entry.name for entry in data_dir.iterdir() if entry.is_dir()]
    if not available:
        return None
    print("\nAvailable channels:")
    print("0. All channels")
    for idx, name in enumerate(available, start=1):
        print(f"{idx}. {name}")
    while True:
        answer = input("\nSelect channel number (0 for all) or paste channel ID: ").strip()
        if not answer:
            return None
        try:
            selected = int(answer)
        except ValueError:
            # Not a number — maybe the user pasted a channel ID directly
            if answer in available:
                return answer
        else:
            if selected == 0:
                return None
            if 1 <= selected <= len(available):
                return available[selected - 1]
        print("Invalid selection. Please try again.")
def process_untranscribed_files() -> None:
    """Find and process untranscribed audio files.

    Scans ``<script dir>/data/<channel>/<episode>/`` for episodes that have an
    ``audio.mp3`` but no ``transcription.json``, asks the user how many to
    process and for confirmation, runs them through
    ``process_files_concurrently``, saves each returned transcription, and
    logs a summary.
    """
    script_dir = Path(__file__).parent
    data_dir = script_dir / "data"
    data_dir.mkdir(exist_ok=True)
    # Get channel selection from user (None means scan all channels)
    selected_channel = get_channel_selection(data_dir)
    # NOTE(review): these three counters are re-initialized again below after
    # the concurrent run; this first initialization is redundant.
    processed_count = 0
    failed_count = 0
    failed_files = []
    # Find all untranscribed files
    pending_files = []
    channel_dirs = [data_dir / selected_channel] if selected_channel else data_dir.iterdir()
    for channel_dir in channel_dirs:
        if not channel_dir.is_dir():
            continue
        channel_id = channel_dir.name
        for episode_dir in channel_dir.iterdir():
            if not episode_dir.is_dir():
                continue
            episode_id = episode_dir.name
            metadata_file = episode_dir / "metadata.json"
            audio_file = episode_dir / "audio.mp3"
            transcription_file = episode_dir / "transcription.json"
            # Get episode title from metadata (best-effort; fall back to placeholder)
            episode_title = "Unknown Title"
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)
                        episode_title = metadata.get('title', 'Unknown Title')
                except Exception as e:
                    logging.warning(f"Could not read metadata for episode {episode_id}: {e}")
            # Pending = audio present but no transcription yet
            if audio_file.exists() and not transcription_file.exists():
                pending_files.append((audio_file, transcription_file, channel_id, episode_id, episode_title))
                logging.info(f"Found untranscribed episode: {episode_id} - {episode_title}")
    total_files = len(pending_files)
    if total_files == 0:
        logging.info("No untranscribed files found")
        return
    logging.info(f"Found {total_files} untranscribed files")
    files_to_process = min(get_file_count(), total_files)
    logging.info(f"Will process {files_to_process} file(s)")
    # Get user confirmation before spending API credits
    if not get_user_confirmation(pending_files, files_to_process):
        logging.info("Transcription cancelled by user")
        return
    # Process files concurrently; file_pairs maps str(path) -> full file_info tuple
    file_pairs = [(str(file_info[0]), file_info) for file_info in pending_files[:files_to_process]]
    results = asyncio.run(process_files_concurrently(file_pairs))
    # Process results
    # NOTE(review): the worker in process_files_concurrently only records
    # failures in `results`; successful files may be missing from it, in which
    # case the summary below under-counts successes — verify against the runner.
    processed_count = 0
    failed_count = 0
    failed_files = []
    for file_path, result in results.items():
        # Recover the original file_info tuple for this path
        file_info = next(info for path, info in file_pairs if path == file_path)
        audio_file, transcription_file, channel_id, episode_id, episode_title = file_info  # Unpack all values
        if result is not None:
            try:
                # Ensure the directory exists
                transcription_file.parent.mkdir(parents=True, exist_ok=True)
                # Save transcription with proper path handling
                with open(transcription_file, "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                processed_count += 1
                logging.info(f"✓ Saved transcription for: {episode_id} - {episode_title}")
                logging.info(f" Path: {transcription_file}")
            except Exception as e:
                failed_count += 1
                failed_files.append({
                    'channel_id': channel_id,
                    'episode_id': episode_id,
                    'title': episode_title,
                    'error': f"Failed to save transcription: {str(e)}"
                })
                logging.error(f"❌ Error saving transcription: {str(e)}")
        else:
            failed_count += 1
            failed_files.append({
                'channel_id': channel_id,
                'episode_id': episode_id,
                'title': episode_title,
                'error': "Transcription failed"
            })
    # Summary reporting
    logging.info(f"\nProcessing Summary:")
    logging.info(f"------------------------")
    logging.info(f"Total files found: {total_files}")
    logging.info(f"Files attempted: {files_to_process}")
    logging.info(f"Successfully transcribed: {processed_count}")
    logging.info(f"Failed: {failed_count}")
    if failed_count > 0:
        logging.error("\nFailed Files:")
        logging.error("------------------------")
        for failed in failed_files:
            logging.error(f"- Channel: {failed['channel_id']}")
            logging.error(f" Episode: {failed['episode_id']}")
            logging.error(f" Title: {failed['title']}")
            logging.error(f" Error: {failed['error']}")
            logging.error(" ------------------------")
    logging.info("\nProcess Status:")
    logging.info("------------------------")
    if processed_count == files_to_process:
        logging.info("✓ All files processed successfully")
    elif processed_count == 0:
        logging.error("× No files were successfully processed")
        if failed_files:
            logging.error(f" Main error: {failed_files[0]['error']}")
    else:
        logging.warning(f"⚠ Partial success: {processed_count}/{files_to_process} files processed")
def get_user_confirmation(pending_files: list, files_to_process: int) -> bool:
    """Preview up to five queued episodes, then ask the user to confirm.

    Returns True for yes, False for no; re-prompts on anything else.
    """
    print("\nFiles to be processed:")
    print("------------------------")
    for position, entry in enumerate(pending_files[:files_to_process], 1):
        episode_id, title = entry[3], entry[4]
        print(f"{position}. {episode_id} - {title}")
        if position >= 5 and files_to_process > 5:
            remaining = files_to_process - 5
            print(f"... and {remaining} more files")
            break
    while True:
        reply = input("\nProceed with transcription? (y/n): ").strip().lower()
        if reply in ['y', 'yes']:
            return True
        if reply in ['n', 'no']:
            return False
        print("Please enter 'y' or 'n'")
def main():
    """Entry point: log start/finish around the full transcription workflow."""
    logging.info("Starting transcription process")
    # Remove the status check and proceed directly to processing
    process_untranscribed_files()
    logging.info("Transcription process completed")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment