Script for splitting meetup recordings into different speakers
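Usage, per the script's own help text: python script.py file1.wav file2.wav, or python script.py *.wav. The script needs at least one CUDA GPU, and the pyannote diarization model it loads is gated on Hugging Face, so a token cached via huggingface-cli login must be available.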
# --extra-index-url https://download.pytorch.org/whl/cu121
# torch==2.1.2+cu121
# torchaudio==2.1.2+cu121
# pyannote.audio==3.1.1
# transformers==4.36.2
# numpy==1.24.3
# tqdm==4.66.1
# PyYAML==6.0.1
# soundfile==0.12.1
# typing-extensions>=4.8.0
# librosa==0.10.1
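#
# (Saved as requirements.txt, the pins above install with
#  pip install -r requirements.txt; the extra index URL points pip at the
#  CUDA 12.1 builds of torch and torchaudio.)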

import torch
from pyannote.audio import Pipeline
import os
import sys
import yaml
from pathlib import Path
import glob
from transformers import pipeline, logging
from tqdm import tqdm
import torchaudio
from typing import List, Dict, Optional
import gc
import warnings
import time
from datetime import datetime, timedelta
from itertools import cycle

# Suppress various warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
logging.set_verbosity_error()


class ProgressPrinter:
    def __init__(self):
        self.start_time = time.time()

    def print_progress(self, message: str, newline: bool = True):
        elapsed = time.time() - self.start_time
        timestamp = datetime.now().strftime("%H:%M:%S")
        if newline:
            print(f"[{timestamp}] ({timedelta(seconds=int(elapsed))}) {message}")
        else:
            print(f"[{timestamp}] ({timedelta(seconds=int(elapsed))}) {message}", end='', flush=True)


class ParallelProcessor:
    def __init__(self):
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count < 1:
            raise RuntimeError("No GPUs available")

        self.progress = ProgressPrinter()
        self.progress.print_progress(f"Found {self.gpu_count} GPUs:")
        for i in range(self.gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_mem = torch.cuda.get_device_properties(i).total_memory / (1024**3)
            self.progress.print_progress(f"  GPU {i}: {gpu_name} ({gpu_mem:.1f}GB)")

        # Print initial memory status
        self.print_gpu_memory("Initial GPU memory status")

        # Create temp directory
        self.temp_dir = Path("temp_audio_segments")
        self.temp_dir.mkdir(exist_ok=True)

        # Initialize pipelines
        self.initialize_pipelines()

    def print_gpu_memory(self, message: str):
        """Print memory usage for all GPUs."""
        self.progress.print_progress(f"\n{message}:")
        for gpu_id in range(self.gpu_count):
            allocated = torch.cuda.memory_allocated(f'cuda:{gpu_id}') / (1024**3)
            reserved = torch.cuda.memory_reserved(f'cuda:{gpu_id}') / (1024**3)
            self.progress.print_progress(f"  GPU {gpu_id}:")
            self.progress.print_progress(f"    Allocated: {allocated:.1f}GB")
            self.progress.print_progress(f"    Reserved: {reserved:.1f}GB")

    def clear_gpu_memory(self, gpu_id: Optional[int] = None):
        """Clear memory on the specified GPU, or on all GPUs if none is given."""
        if gpu_id is not None:
            # Clear specific GPU
            with torch.cuda.device(f'cuda:{gpu_id}'):
                torch.cuda.empty_cache()
            gc.collect()
        else:
            # Clear all GPUs
            for i in range(self.gpu_count):
                with torch.cuda.device(f'cuda:{i}'):
                    torch.cuda.empty_cache()
            gc.collect()

    def initialize_pipelines(self):
        """Initialize pipelines for each GPU."""
        self.progress.print_progress("Initializing models...")

        # Initialize diarization on primary GPU
        self.progress.print_progress("Initializing diarization pipeline...", False)
        device = torch.device('cuda:0')
        self.diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization@2.1",
            use_auth_token=True
        ).to(device)
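        # Note: the pyannote diarization model is gated on Hugging Face, so the
        # model card's terms must be accepted first; use_auth_token=True makes
        # from_pretrained use the token cached by `huggingface-cli login`.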
        self.diarization.instantiate({
            "segmentation": {
                "min_duration_off": 0.5,  # Minimum duration of non-speech
                "threshold": 0.5          # Adjust this - higher means more conservative segmentation
            },
            "clustering": {
                "min_cluster_size": 10,   # Minimum number of segments per speaker
                "threshold": 0.7          # Adjust this - higher means more speakers
            }
        })
        print(" Done")

        # Print memory after diarization init
        self.print_gpu_memory("Memory after diarization init")

        # Initialize ASR models on all GPUs with different batch sizes based on VRAM
        self.asr_models = []
        self.batch_sizes = []  # Store batch size for each GPU
        for gpu_id in range(self.gpu_count):
            # Get GPU memory
            gpu_mem = torch.cuda.get_device_properties(gpu_id).total_memory / (1024**3)
            # Set batch size based on GPU memory
            batch_size = 4 if gpu_mem > 12 else 2
            self.batch_sizes.append(batch_size)

            self.progress.print_progress(f"Initializing Whisper model on GPU {gpu_id} (batch size: {batch_size})...", False)
            asr = pipeline(
                "automatic-speech-recognition",
                model="primeline/whisper-large-v3-german",
                device=f'cuda:{gpu_id}',
                chunk_length_s=30,  # Increased from 15
                stride_length_s=3,  # Adjusted accordingly
                batch_size=batch_size,
                generate_kwargs={
                    'task': 'transcribe',
                    'language': 'de',
                    'temperature': 0.0,  # Remove randomness
                    'no_speech_threshold': 0.6,  # More conservative about silence
                    'condition_on_previous_text': True,  # Use context from previous segments
                    'compression_ratio_threshold': 2.4  # Help with word repetition
                }
            )
            self.asr_models.append(asr)
            print(" Done")

            # Print memory after each ASR init
            self.print_gpu_memory(f"Memory after ASR init on GPU {gpu_id}")

    def run_diarization(self, audio_path: str):
        """Run diarization with explicit memory cleanup."""
        try:
            # Run diarization on GPU 0
            self.progress.print_progress("Running speaker diarization...", False)
            diarization = self.diarization(audio_path)
            print(" Done")

            # Convert to list immediately to free pipeline memory
            segments = list(diarization.itertracks(yield_label=True))

            # Explicit cleanup
            del diarization
            self.clear_gpu_memory(0)

            # Print memory status after diarization
            self.print_gpu_memory("Memory after diarization")
            return segments
        except Exception as e:
            self.progress.print_progress(f"Error in diarization: {str(e)}")
            return None

    def transcribe_batch(self, audio_path: str, segments: List[tuple], gpu_id: int) -> List[str]:
        """Transcribe a batch of segments using the specified GPU."""
        temp_files = []  # defined before the try so the finally block can always reference it
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            batch_audio = []

            # Prepare all segments in batch
            for segment, _, _ in segments:
                start_sample = int(segment.start * sample_rate)
                end_sample = int(segment.end * sample_rate)
                segment_audio = waveform[:, start_sample:end_sample]

                # Save temporary file
                temp_path = self.temp_dir / f"segment_gpu{gpu_id}_{segment.start:.2f}.wav"
                torchaudio.save(temp_path, segment_audio, sample_rate)
                temp_files.append(temp_path)
                batch_audio.append(str(temp_path))

            # Process batch
            results = self.asr_models[gpu_id](batch_audio)

            # Extract texts
            if isinstance(results, dict):
                texts = [results['text'].strip()]
            else:
                texts = [r['text'].strip() for r in results]
            return texts
        finally:
            # Cleanup temp files
            for temp_path in temp_files:
                try:
                    temp_path.unlink()
                except OSError:
                    pass
            # Clear GPU memory after batch
            self.clear_gpu_memory(gpu_id)

    def process_audio_file(self, audio_path: str) -> Optional[List[Dict]]:
        """Process audio file using all available GPUs in rotation."""
        try:
            # Get audio duration
            info = torchaudio.info(audio_path)
            duration = info.num_frames / info.sample_rate
            self.progress.print_progress(f"\nProcessing {audio_path}")
            self.progress.print_progress(f"Audio duration: {timedelta(seconds=int(duration))}")

            # Run diarization with memory management
            all_segments = self.run_diarization(audio_path)
            if not all_segments:
                return None
            self.progress.print_progress(f"Found {len(all_segments)} segments")

            # Create a cycle of GPU IDs for round-robin assignment
            gpu_cycle = cycle(range(self.gpu_count))
            segments = []

            # Process segments in batches, advancing by each GPU's own batch
            # size so no segment is skipped when the GPUs have different sizes
            with tqdm(
                total=len(all_segments),
                desc="Processing segments",
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
            ) as pbar:
                batch_start = 0
                while batch_start < len(all_segments):
                    # Get next GPU in rotation
                    gpu_id = next(gpu_cycle)
                    batch_size = self.batch_sizes[gpu_id]

                    # Get batch for this GPU
                    batch_end = min(batch_start + batch_size, len(all_segments))
                    batch = all_segments[batch_start:batch_end]

                    # Process batch
                    transcriptions = self.transcribe_batch(audio_path, batch, gpu_id)

                    # Add results
                    for (segment, _, speaker), transcription in zip(batch, transcriptions):
                        segments.append({
                            'speaker': speaker,
                            'start': round(segment.start, 2),
                            'end': round(segment.end, 2),
                            'duration': round(segment.end - segment.start, 2),
                            'text': transcription,
                            'gpu_used': gpu_id
                        })

                    # Print memory status periodically
                    if len(segments) % 10 == 0:
                        self.print_gpu_memory("Current GPU memory status")

                    pbar.update(len(batch))
                    batch_start = batch_end

            return segments
        except Exception as e:
            self.progress.print_progress(f"Error processing {audio_path}: {str(e)}")
            return None
        finally:
            # Final cleanup
            self.clear_gpu_memory()
            self.print_gpu_memory("Final GPU memory status")

def save_yaml(data: Dict, yaml_path: Path) -> None:
    """Save diarization and transcription data to a YAML file."""
    try:
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(data, f, allow_unicode=True, sort_keys=False)
        print(f"Saved results to {yaml_path}")
    except Exception as e:
        print(f"Error saving YAML file {yaml_path}: {str(e)}")


def save_transcript(segments: List[Dict], txt_path: Path) -> None:
    """Save a human-readable transcript."""
    try:
        with open(txt_path, 'w', encoding='utf-8') as f:
            for segment in segments:
                f.write(f"[{segment['speaker']}] {segment['start']}s - {segment['end']}s:\n")
                f.write(f"{segment['text']}\n\n")
        print(f"Saved transcript to {txt_path}")
    except Exception as e:
        print(f"Error saving transcript file {txt_path}: {str(e)}")


def main():
    progress = ProgressPrinter()

    # Initialize processor
    try:
        processor = ParallelProcessor()
    except RuntimeError as e:
        progress.print_progress(f"Error initializing processor: {str(e)}")
        return

    # Get WAV files from command line arguments
    wav_files = []
    for arg in sys.argv[1:]:
        if '*' in arg:
            wav_files.extend(glob.glob(arg))
        else:
            wav_files.append(arg)

    # Filter out temporary files and ensure files exist
    wav_files = [
        f for f in wav_files
        if f.lower().endswith('.wav')
        and os.path.exists(f)
        and not any(x in f for x in ["temp_segment", "segment_gpu"])
    ]

    if not wav_files:
        progress.print_progress("No WAV files found in arguments.")
        progress.print_progress("Usage: python script.py file1.wav file2.wav")
        progress.print_progress("   or: python script.py *.wav")
        return

    progress.print_progress(f"Found {len(wav_files)} WAV files to process")
    for wav_file in wav_files:
        wav_path = Path(wav_file)
        yaml_path = wav_path.with_suffix('.yaml')
        txt_path = wav_path.with_suffix('.txt')

        # Process the audio file
        segments = processor.process_audio_file(str(wav_path))
        if segments:
            # Save outputs
            yaml_data = {
                'audio_file': wav_path.name,
                'segments': segments,
                'processed_at': datetime.now().isoformat()
            }
            save_yaml(yaml_data, yaml_path)
            save_transcript(segments, txt_path)

    progress.print_progress("\nProcessing complete!")


if __name__ == "__main__":
    main()
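
For each input WAV, the script writes a .yaml file next to it (audio_file, processed_at, and the segments list with speaker, start, end, duration, text, and gpu_used per entry) plus a .txt transcript with one "[SPEAKER] start s - end s:" block per segment.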