@Getty · Created December 20, 2024 16:45
Script for splitting meetup recordings into per-speaker segments: pyannote speaker diarization on GPU 0, then German Whisper transcription round-robined across all available GPUs, writing a YAML file and a plain-text transcript next to each input WAV.
# --extra-index-url https://download.pytorch.org/whl/cu121
# torch==2.1.2+cu121
# torchaudio==2.1.2+cu121
# pyannote.audio==3.1.1
# transformers==4.36.2
# numpy==1.24.3
# tqdm==4.66.1
# PyYAML==6.0.1
# soundfile==0.12.1
# typing-extensions>=4.8.0
# librosa==0.10.1
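#
# The pins above are in requirements.txt form; assuming they are saved to a
# requirements.txt file (the filename here is just an example), the
# environment can be set up with:
#
#   pip install -r requirements.txt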
import torch
from pyannote.audio import Pipeline
import os
import sys
import yaml
from pathlib import Path
import glob
from transformers import pipeline, logging
from tqdm import tqdm
import torchaudio
from typing import List, Dict
import gc
import warnings
import time
from datetime import datetime, timedelta
from itertools import cycle
# Suppress various warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
logging.set_verbosity_error()


class ProgressPrinter:
    """Timestamped progress messages with elapsed time since startup."""

    def __init__(self):
        self.start_time = time.time()

    def print_progress(self, message: str, newline: bool = True):
        elapsed = time.time() - self.start_time
        timestamp = datetime.now().strftime("%H:%M:%S")
        if newline:
            print(f"[{timestamp}] ({timedelta(seconds=int(elapsed))}) {message}")
        else:
            print(f"[{timestamp}] ({timedelta(seconds=int(elapsed))}) {message}", end='', flush=True)


class ParallelProcessor:
    def __init__(self):
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count < 1:
            raise RuntimeError("No GPUs available")
        self.progress = ProgressPrinter()
        self.progress.print_progress(f"Found {self.gpu_count} GPUs:")
        for i in range(self.gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_mem = torch.cuda.get_device_properties(i).total_memory / (1024**3)
            self.progress.print_progress(f"  GPU {i}: {gpu_name} ({gpu_mem:.1f}GB)")
        # Print initial memory status
        self.print_gpu_memory("Initial GPU memory status")
        # Create temp directory for audio segments
        self.temp_dir = Path("temp_audio_segments")
        self.temp_dir.mkdir(exist_ok=True)
        # Initialize pipelines
        self.initialize_pipelines()

    def print_gpu_memory(self, message: str):
        """Print memory usage for all GPUs."""
        self.progress.print_progress(f"\n{message}:")
        for gpu_id in range(self.gpu_count):
            allocated = torch.cuda.memory_allocated(f'cuda:{gpu_id}') / (1024**3)
            reserved = torch.cuda.memory_reserved(f'cuda:{gpu_id}') / (1024**3)
            self.progress.print_progress(f"  GPU {gpu_id}:")
            self.progress.print_progress(f"    Allocated: {allocated:.1f}GB")
            self.progress.print_progress(f"    Reserved: {reserved:.1f}GB")

    def clear_gpu_memory(self, gpu_id: int = None):
        """Clear memory on the specified GPU, or on all GPUs."""
        if gpu_id is not None:
            # Clear a specific GPU
            with torch.cuda.device(f'cuda:{gpu_id}'):
                torch.cuda.empty_cache()
            gc.collect()
        else:
            # Clear all GPUs
            for i in range(self.gpu_count):
                with torch.cuda.device(f'cuda:{i}'):
                    torch.cuda.empty_cache()
            gc.collect()

    def initialize_pipelines(self):
        """Initialize the diarization pipeline and one ASR pipeline per GPU."""
        self.progress.print_progress("Initializing models...")
        # Initialize diarization on the primary GPU
        self.progress.print_progress("Initializing diarization pipeline...", False)
        device = torch.device('cuda:0')
        # Assumed checkpoint: the hyperparameters below match pyannote's
        # speaker-diarization 2.1 pipeline
        self.diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization@2.1",
            use_auth_token=True
        ).to(device)
        self.diarization.instantiate({
            "segmentation": {
                "min_duration_off": 0.5,  # Minimum duration of non-speech
                "threshold": 0.5          # Higher means more conservative segmentation
            },
            "clustering": {
                "min_cluster_size": 10,   # Minimum number of segments per speaker
                "threshold": 0.7          # Agglomerative threshold; tune to change how readily segments merge into one speaker
            }
        })
        print(" Done")
        # Print memory after diarization init
        self.print_gpu_memory("Memory after diarization init")
        # Initialize ASR models on all GPUs, with batch sizes scaled to VRAM
        self.asr_models = []
        self.batch_sizes = []  # Batch size chosen for each GPU
        for gpu_id in range(self.gpu_count):
            # Get GPU memory
            gpu_mem = torch.cuda.get_device_properties(gpu_id).total_memory / (1024**3)
            # Set batch size based on GPU memory
            batch_size = 4 if gpu_mem > 12 else 2
            self.batch_sizes.append(batch_size)
            self.progress.print_progress(f"Initializing Whisper model on GPU {gpu_id} (batch size: {batch_size})...", False)
            asr = pipeline(
                "automatic-speech-recognition",
                model="primeline/whisper-large-v3-german",
                device=f'cuda:{gpu_id}',
                chunk_length_s=30,   # Increased from 15
                stride_length_s=3,   # Adjusted accordingly
                batch_size=batch_size,
                generate_kwargs={
                    'task': 'transcribe',
                    'language': 'de',
                    'temperature': 0.0,                 # Remove randomness
                    'no_speech_threshold': 0.6,         # More conservative about silence
                    # transformers calls this condition_on_prev_tokens (the
                    # original openai-whisper name condition_on_previous_text
                    # is rejected); these long-form options may need a newer
                    # transformers than the 4.36 pin above
                    'condition_on_prev_tokens': True,
                    'compression_ratio_threshold': 2.4  # Helps with word repetition
                }
            )
            self.asr_models.append(asr)
            print(" Done")
            # Print memory after each ASR init
            self.print_gpu_memory(f"Memory after ASR init on GPU {gpu_id}")

    def run_diarization(self, audio_path: str):
        """Run diarization with explicit memory cleanup."""
        try:
            # Run diarization on GPU 0
            self.progress.print_progress("Running speaker diarization...", False)
            diarization = self.diarization(audio_path)
            print(" Done")
            # Convert to a list immediately so pipeline memory can be freed
            segments = list(diarization.itertracks(yield_label=True))
            # Explicit cleanup
            del diarization
            self.clear_gpu_memory(0)
            # Print memory status after diarization
            self.print_gpu_memory("Memory after diarization")
            return segments
        except Exception as e:
            self.progress.print_progress(f"Error in diarization: {str(e)}")
            return None

    def transcribe_batch(self, audio_path: str, segments: List[tuple], gpu_id: int) -> List[str]:
        """Transcribe a batch of segments using the specified GPU."""
        temp_files = []  # Defined before the try block so finally can always clean up
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            batch_audio = []
            # Prepare all segments in the batch
            for segment, _, _ in segments:
                start_sample = int(segment.start * sample_rate)
                end_sample = int(segment.end * sample_rate)
                segment_audio = waveform[:, start_sample:end_sample]
                # Save temporary file
                temp_path = self.temp_dir / f"segment_gpu{gpu_id}_{segment.start:.2f}.wav"
                torchaudio.save(str(temp_path), segment_audio, sample_rate)
                temp_files.append(temp_path)
                batch_audio.append(str(temp_path))
            # Process batch
            results = self.asr_models[gpu_id](batch_audio)
            # Extract texts (a single-item batch may come back as a dict)
            if isinstance(results, dict):
                texts = [results['text'].strip()]
            else:
                texts = [r['text'].strip() for r in results]
            return texts
        finally:
            # Clean up temp files
            for temp_path in temp_files:
                try:
                    temp_path.unlink()
                except OSError:
                    pass
            # Clear GPU memory after the batch
            self.clear_gpu_memory(gpu_id)

    def process_audio_file(self, audio_path: str) -> List[Dict]:
        """Process an audio file using all available GPUs in parallel."""
        try:
            # Get audio duration
            info = torchaudio.info(audio_path)
            duration = info.num_frames / info.sample_rate
            self.progress.print_progress(f"\nProcessing {audio_path}")
            self.progress.print_progress(f"Audio duration: {timedelta(seconds=int(duration))}")
            # Run diarization with memory management
            all_segments = self.run_diarization(audio_path)
            if not all_segments:
                return None
            self.progress.print_progress(f"Found {len(all_segments)} segments")
            # Create a cycle of GPU IDs for round-robin assignment
            gpu_cycle = cycle(range(self.gpu_count))
            segments = []
            # Process segments in batches, advancing by each GPU's own batch
            # size so no segment is skipped when batch sizes differ
            pbar = tqdm(
                total=len(all_segments),
                desc="Processing segments",
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
            )
            batch_start = 0
            while batch_start < len(all_segments):
                # Get next GPU in rotation
                gpu_id = next(gpu_cycle)
                batch_size = self.batch_sizes[gpu_id]
                # Get the batch for this GPU
                batch_end = min(batch_start + batch_size, len(all_segments))
                batch = all_segments[batch_start:batch_end]
                # Process batch
                transcriptions = self.transcribe_batch(audio_path, batch, gpu_id)
                # Add results
                for (segment, _, speaker), transcription in zip(batch, transcriptions):
                    segments.append({
                        'speaker': speaker,
                        'start': round(segment.start, 2),
                        'end': round(segment.end, 2),
                        'duration': round(segment.end - segment.start, 2),
                        'text': transcription,
                        'gpu_used': gpu_id
                    })
                batch_start = batch_end
                pbar.update(len(batch))
                # Print memory status periodically
                if len(segments) % 10 == 0:
                    self.print_gpu_memory("Current GPU memory status")
            pbar.close()
            return segments
        except Exception as e:
            self.progress.print_progress(f"Error processing {audio_path}: {str(e)}")
            return None
        finally:
            # Final cleanup
            self.clear_gpu_memory()
            self.print_gpu_memory("Final GPU memory status")


def save_yaml(data: Dict, yaml_path: Path) -> None:
    """Save diarization and transcription data to a YAML file."""
    try:
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(data, f, allow_unicode=True, sort_keys=False)
        print(f"Saved results to {yaml_path}")
    except Exception as e:
        print(f"Error saving YAML file {yaml_path}: {str(e)}")


def save_transcript(segments: List[Dict], txt_path: Path) -> None:
    """Save a human-readable transcript."""
    try:
        with open(txt_path, 'w', encoding='utf-8') as f:
            for segment in segments:
                f.write(f"[{segment['speaker']}] {segment['start']}s - {segment['end']}s:\n")
                f.write(f"{segment['text']}\n\n")
        print(f"Saved transcript to {txt_path}")
    except Exception as e:
        print(f"Error saving transcript file {txt_path}: {str(e)}")


def main():
    progress = ProgressPrinter()
    # Initialize processor
    try:
        processor = ParallelProcessor()
    except RuntimeError as e:
        progress.print_progress(f"Error initializing processor: {str(e)}")
        return
    # Get WAV files from command line arguments
    wav_files = []
    for arg in sys.argv[1:]:
        if '*' in arg:
            wav_files.extend(glob.glob(arg))
        else:
            wav_files.append(arg)
    # Filter out temporary files and ensure the files exist
    wav_files = [
        f for f in wav_files
        if f.lower().endswith('.wav')
        and os.path.exists(f)
        and not any(x in f for x in ["temp_segment", "segment_gpu"])
    ]
    if not wav_files:
        progress.print_progress("No WAV files found in arguments.")
        progress.print_progress("Usage: python script.py file1.wav file2.wav")
        progress.print_progress("   or: python script.py *.wav")
        return
    progress.print_progress(f"Found {len(wav_files)} WAV files to process")
    for wav_file in wav_files:
        wav_path = Path(wav_file)
        yaml_path = wav_path.with_suffix('.yaml')
        txt_path = wav_path.with_suffix('.txt')
        # Process the audio file
        segments = processor.process_audio_file(str(wav_path))
        if segments:
            # Save outputs
            yaml_data = {
                'audio_file': wav_path.name,
                'segments': segments,
                'processed_at': datetime.now().isoformat()
            }
            save_yaml(yaml_data, yaml_path)
            save_transcript(segments, txt_path)
    progress.print_progress("\nProcessing complete!")


if __name__ == "__main__":
    main()
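

# A minimal post-processing sketch (illustrative, not part of the script):
# assuming a finished run produced meetup.yaml next to meetup.wav (hypothetical
# filenames), per-speaker talk time can be read straight from the 'duration'
# fields written above.
#
#   import yaml
#   from collections import defaultdict
#
#   with open("meetup.yaml", encoding="utf-8") as f:
#       data = yaml.safe_load(f)
#   talk_time = defaultdict(float)
#   for seg in data["segments"]:
#       talk_time[seg["speaker"]] += seg["duration"]
#   for speaker, seconds in sorted(talk_time.items()):
#       print(f"{speaker}: {seconds:.1f}s")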