Skip to content

Instantly share code, notes, and snippets.

@victormurcia
Created October 22, 2024 17:25
Show Gist options
  • Save victormurcia/e2c9fa0d092b80af3f3db938d74069af to your computer and use it in GitHub Desktop.
Save victormurcia/e2c9fa0d092b80af3f3db938d74069af to your computer and use it in GitHub Desktop.
def estimate_n_speakers_nmf(
    features,
    max_speakers,
    min_speakers,
    *,
    n_runs=3,
    max_iter=300,
    random_state=None,
    stability_weight=0.7,
    reconstruction_weight=0.3,
):
    """Estimate the number of speakers from NMF reconstruction error and stability.

    For every candidate ``n`` in ``[min_speakers, max_speakers]`` the absolute
    value of ``features`` is factorised ``n_runs`` times with randomly
    initialised NMF (``X ~= W @ H``).  Two scores are computed per candidate:

    * reconstruction score -- mean squared reconstruction error averaged over
      the runs, min-max normalised across candidates and inverted so that a
      lower error gives a higher score;
    * stability score -- how well the columns of ``W`` reproduce across runs.
      NMF returns components in arbitrary order, so for every pair of runs the
      columns are matched one-to-one (assignment on absolute Pearson
      correlation) before the matched correlations are averaged.

    The candidate with the highest weighted sum of the two scores is returned.

    Args:
        features: Array-like matrix to factorise.  Its absolute value is used
            because NMF requires non-negative input.  NOTE(review): which axis
            is time vs. frequency is decided by the caller -- stability is
            measured on the columns of ``W``, i.e. per-row activations.
        max_speakers: Largest candidate number of components (inclusive).
        min_speakers: Smallest candidate number of components (inclusive, >= 1).
        n_runs: NMF restarts per candidate; at least 2 are needed to measure
            stability.
        max_iter: Iteration cap passed to each NMF fit.
        random_state: ``None`` (default) passes ``random_state=None`` to every
            fit, as before.  An int seeds a generator that derives a distinct
            seed for every fit, making the estimate reproducible while keeping
            the runs independent of each other.
        stability_weight: Weight of the stability score in the combined score.
        reconstruction_weight: Weight of the reconstruction score.

    Returns:
        int: The estimated number of speakers.

    Raises:
        ValueError: If ``min_speakers < 1``, ``max_speakers < min_speakers`` or
            ``n_runs < 2``.
    """
    if min_speakers < 1:
        raise ValueError(f"min_speakers must be >= 1, got {min_speakers}")
    if max_speakers < min_speakers:
        raise ValueError(
            f"max_speakers ({max_speakers}) must be >= min_speakers ({min_speakers})"
        )
    if n_runs < 2:
        raise ValueError(f"n_runs must be >= 2 to measure stability, got {n_runs}")

    # NMF only accepts non-negative data; compute the magnitude once, not per fit.
    X = np.abs(features)

    # With an explicit seed, derive a different seed per fit: reusing one seed
    # for every run would make all runs identical and stability trivially 1.
    seed_rng = None if random_state is None else np.random.RandomState(random_state)

    reconstruction_errors = []
    stability_scores = []
    for n in range(min_speakers, max_speakers + 1):
        errors = []
        activations = []
        for _ in range(n_runs):
            seed = None if seed_rng is None else int(seed_rng.randint(2**31 - 1))
            model = NMF(
                n_components=n,
                init='random',
                random_state=seed,
                max_iter=max_iter,
            )
            W = model.fit_transform(X)
            H = model.components_
            errors.append(np.mean((X - W @ H) ** 2))
            activations.append(W)
        reconstruction_errors.append(np.mean(errors))
        stability_scores.append(_matched_component_stability(activations, n))

    reconstruction_errors = np.asarray(reconstruction_errors, dtype=float)
    stability_scores = np.asarray(stability_scores, dtype=float)

    error_min = reconstruction_errors.min()
    error_range = reconstruction_errors.max() - error_min
    if error_range > 0:
        reconstruction_scores = 1 - (reconstruction_errors - error_min) / error_range
    else:
        # Single candidate or identical errors: reconstruction cannot
        # discriminate, so score all candidates equally instead of 0/0 -> NaN.
        reconstruction_scores = np.ones_like(reconstruction_errors)

    combined_scores = (
        stability_weight * stability_scores
        + reconstruction_weight * reconstruction_scores
    )
    return int(np.argmax(combined_scores)) + min_speakers


def _matched_component_stability(activations, n_components):
    """Return the mean one-to-one matched |correlation| of W columns over all run pairs."""
    from itertools import combinations

    # scipy is a required dependency of scikit-learn, which presumably provides NMF.
    from scipy.optimize import linear_sum_assignment

    pair_scores = []
    for W_a, W_b in combinations(activations, 2):
        # Zero-variance (e.g. all-zero) columns yield NaN correlations; treat
        # them as uncorrelated so NaN cannot reach argmax in the caller.
        with np.errstate(invalid="ignore", divide="ignore"):
            corr = np.corrcoef(W_a.T, W_b.T)
        cross = np.nan_to_num(np.abs(corr[:n_components, n_components:]), nan=0.0)
        # Component order is arbitrary per run: pair each column of W_a with
        # its best distinct partner in W_b (maximise total |corr|).
        rows, cols = linear_sum_assignment(-cross)
        pair_scores.append(cross[rows, cols].mean())
    return float(np.mean(pair_scores))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment