Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save pkpio/cea6554a3ec4c6b07eb8f3639235ae22 to your computer and use it in GitHub Desktop.

Select an option

Save pkpio/cea6554a3ec4c6b07eb8f3639235ae22 to your computer and use it in GitHub Desktop.
Speaker Diarization on Large Audio Files with pyannote/speaker-diarization-3.1
#!/usr/bin/env python3
import os
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from pyannote.audio import Pipeline
import librosa
import soundfile as sf
def plot_embeddings_in_2D(df: pd.DataFrame, embedding_dims: int = 256):
    """Scatter-plot the embeddings stored in *df* after reducing them to 2D with PCA.

    Args:
        df: Frame with one embedding per row, stored in the columns
            ``dim_0`` .. ``dim_{embedding_dims - 1}``.
        embedding_dims: Number of ``dim_*`` columns to read from *df*.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if len(df) < 2:
        # A 2-component PCA cannot be fitted on fewer than two samples, and an
        # empty frame may not even carry the dim_* columns.
        print(f'Need at least 2 embeddings to plot, got {len(df)}; skipping plot.')
        return

    embedding_cols = [f'dim_{i}' for i in range(embedding_dims)]
    vectors = df[embedding_cols].values

    # Fixed random_state keeps the projection reproducible between runs.
    reducer = PCA(n_components=2, random_state=42)
    vectors_2d = reducer.fit_transform(vectors)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.scatter(vectors_2d[:, 0], vectors_2d[:, 1], alpha=0.6)
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    # Title follows the actual dimensionality instead of a hard-coded "256D".
    plt.title(f'{embedding_dims}D Vectors Visualized in 2D')
    plt.grid(True, alpha=0.3)
    plt.show()
def _select_device():
    """Return the preferred available torch device: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _chunks_dir_for(audio_file_path):
    """Return the chunk directory (with trailing '/') derived from *audio_file_path*."""
    stem, ext = os.path.splitext(audio_file_path)
    if not ext:
        # Without an extension the stem is the audio file itself, which cannot
        # double as a directory name.
        stem = audio_file_path + '_chunks'
    return stem + '/'


def main():
    """Diarize large audio files chunk by chunk, then cluster speakers globally.

    Steps:
      1. Split every configured audio file into fixed-length WAV chunks stored
         in a directory named after the file.
      2. Run the pyannote speaker-diarization pipeline on each chunk and record
         every speech segment together with the embedding of its chunk-level
         speaker. Segment times are shifted to be relative to the full file.
      3. Write all segments to a timestamped CSV, optionally plot the
         embeddings, and cluster them with DBSCAN so that chunk-level speakers
         get a shared ``speaker`` id, written to a second CSV.

    Raises:
        RuntimeError: If the diarization pipeline could not be loaded.
    """
    #======================================================
    # Configuration
    #======================================================
    huggingface_token = 'YOUR-SECRET-TOKEN'
    audio_file_paths = [
        './path/to/large_file_one.wav',
        './path/to/large_file_two.wav',
        #...
    ]
    chunk_duration = 15 * 60  # in seconds
    embedding_dims = 256  # dimensions of the pyannote speaker embeddings

    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=huggingface_token)
    if pipeline is None:
        # Guard against a pipeline that failed to load (e.g. bad token or
        # unaccepted model conditions) instead of failing later on `.to()`.
        raise RuntimeError(
            'Could not load pyannote/speaker-diarization-3.1; '
            'check the Hugging Face token and model access.')

    # Use the best accelerator available rather than assuming Apple Silicon.
    device = _select_device()
    print(f'Running diarization on device: {device}')
    pipeline.to(device)

    embed_cols = [f'dim_{i}' for i in range(embedding_dims)]

    #======================================================
    # Make chunks and run speaker diarization per chunk
    #======================================================
    segments = []
    for audio_file_path in audio_file_paths:
        print(f'Processing {audio_file_path}')

        # Chunk the audio into pieces of `chunk_duration` seconds each
        print('Loading audio file...')
        y, sr = librosa.load(audio_file_path)
        print(f'Loaded. Sampling Rate is {sr}.')
        chunk_samples = int(chunk_duration * sr)
        chunks = [y[i:i + chunk_samples] for i in range(0, len(y), chunk_samples)]

        chunks_dir = _chunks_dir_for(audio_file_path)
        os.makedirs(chunks_dir, exist_ok=True)
        print(f'Saving {len(chunks)} chunks of the audio into {chunks_dir}...')
        chunk_file_paths = []
        for i, ch in enumerate(chunks):
            chunk_file_path = chunks_dir + f'chunk_{i:02d}.wav'
            chunk_file_paths.append(chunk_file_path)
            print(f' Saving chunk {i:02d} to {chunk_file_path}...')
            with sf.SoundFile(chunk_file_path, mode="w", samplerate=sr, channels=1,
                              format="WAV", subtype="PCM_16") as f:
                f.write(ch)

        print('Done with chunking. About to run speaker diarization on chunks.')
        for chunk_idx, chunk_file_path in enumerate(chunk_file_paths):
            print(f'Processing {chunk_file_path}...')
            diarization, embeddings = pipeline(chunk_file_path, return_embeddings=True)
            print('Done. Storing diarization and embeddings...')

            # Row s of `embeddings` belongs to the s-th entry of
            # diarization.labels(); look the row up by label instead of parsing
            # the trailing digits of the label text.
            embedding_row = {
                label: row for row, label in enumerate(diarization.labels())
            }
            # Shift chunk-relative times so they are relative to the whole file.
            start_offset = chunk_idx * chunk_duration
            for segment, _track, label in diarization.itertracks(yield_label=True):
                embed = embeddings[embedding_row[label]]
                start = segment.start + start_offset
                segments.append({
                    'audio_file_path': audio_file_path,
                    'chunk_path': chunk_file_path,
                    'chunk': chunk_idx,
                    'start': start,
                    'duration': segment.duration,
                    'end': start + segment.duration,
                    'chunk_level_label': label,
                    **{embed_cols[d]: embed[d] for d in range(embedding_dims)},
                })

        # Checkpoint holds every segment collected so far (all files processed
        # up to this point), so a later crash does not lose finished work.
        checkpoint_file_path = chunks_dir + 'segments.csv'
        print(f'Done with processing {audio_file_path}, saving checkpoint results to {checkpoint_file_path}')
        pd.DataFrame(segments).to_csv(checkpoint_file_path, index=False)

    # Dump all results to disk
    run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
    output_file_path = f'./diarization_results_run_{run_stamp}.csv'
    df = pd.DataFrame(segments)
    print(f'Saving all {len(df)} results to {output_file_path}...')
    df.to_csv(output_file_path, index=False)
    print('Done.')

    if df.empty:
        # No segments means no dim_* columns: nothing to plot or cluster.
        print('No speech segments were found; skipping plotting and clustering.')
        return

    #======================================================
    # (Optional) Plot embeddings in 2D
    #======================================================
    plot_embeddings_in_2D(df, embedding_dims)

    #======================================================
    # Cluster embeddings to identify speakers
    #======================================================
    print('Clustering embeddings to identify speakers...')
    speaker_embeddings = df[embed_cols].values

    # Standardize embeddings for better clustering
    scaler = StandardScaler()
    embeddings_scaled = scaler.fit_transform(speaker_embeddings)

    # DBSCAN assigns -1 to points it treats as noise.
    dbscan = DBSCAN(eps=0.5, min_samples=5, metric='cosine')
    df['speaker'] = dbscan.fit_predict(embeddings_scaled)

    output_file_path = f'./diarization_with_speakers_{run_stamp}.csv'
    print(f'Done. Saving results with speaker column to {output_file_path}...')
    df.to_csv(output_file_path, index=False)
    print('Done.')
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment