Skip to content

Instantly share code, notes, and snippets.

@twobob
Last active October 13, 2024 10:36
Show Gist options
  • Save twobob/8586d2005d2303a766bbc7540bf80054 to your computer and use it in GitHub Desktop.
wav classifier (sort by reduced features revision)
import os
import argparse
import librosa
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
from flask import send_file
# Folder scanned for .wav files. The raw string keeps the Windows backslash
# literal (identical to 'd:\\wav').
AUDIO_DIR = r'd:\wav'

# Feature families offered as checkboxes in the UI. Each name is the prefix of
# the "<name>_mean" / "<name>_std" entries produced by extract_features;
# 'tempo' is a single entry.
FEATURES_LIST = [
    'mfcc',
    'chroma',
    'spec_contrast',
    'tonnetz',
    'zcr',
    'spec_centroid',
    'spec_rolloff',
    'rms',
    'spec_bw',
    'tempo',
]
def extract_features(file_path):
    """Summarise one WAV file as a dict of spectral/temporal statistics.

    At most the first 47 seconds are loaded, at the file's native sample rate,
    downmixed to mono.

    * Multi-band descriptors (MFCC, chroma, spectral contrast, tonnetz) are
      reduced per row (``axis=1``), so ``<name>_mean`` / ``<name>_std`` are
      1-D arrays.
    * The remaining descriptors are reduced over the whole matrix, so their
      ``<name>_mean`` / ``<name>_std`` values are scalars.
    * ``tempo`` is the first value returned by ``librosa.beat.beat_track``.

    Returns:
        dict mapping feature column name to its summary value.
    """
    signal, rate = librosa.load(file_path, sr=None, mono=True, duration=47)

    # (column prefix, deferred computation), listed in output-column order.
    banded = (
        ('mfcc', lambda: librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=13)),
        ('chroma', lambda: librosa.feature.chroma_stft(y=signal, sr=rate)),
        ('spec_contrast', lambda: librosa.feature.spectral_contrast(y=signal, sr=rate)),
        # Tonnetz is computed from the harmonic component of the signal only.
        ('tonnetz', lambda: librosa.feature.tonnetz(y=librosa.effects.harmonic(signal), sr=rate)),
    )
    whole = (
        ('zcr', lambda: librosa.feature.zero_crossing_rate(signal)),
        ('spec_centroid', lambda: librosa.feature.spectral_centroid(y=signal, sr=rate)),
        ('spec_rolloff', lambda: librosa.feature.spectral_rolloff(y=signal, sr=rate)),
        ('rms', lambda: librosa.feature.rms(y=signal)),
        ('spec_bw', lambda: librosa.feature.spectral_bandwidth(y=signal, sr=rate)),
    )

    summary = {}
    for prefix, compute in banded:
        matrix = compute()
        summary[f'{prefix}_mean'] = np.mean(matrix, axis=1)
        summary[f'{prefix}_std'] = np.std(matrix, axis=1)
    for prefix, compute in whole:
        matrix = compute()
        summary[f'{prefix}_mean'] = np.mean(matrix)
        summary[f'{prefix}_std'] = np.std(matrix)

    summary['tempo'] = librosa.beat.beat_track(y=signal, sr=rate)[0]
    return summary
def process_audio_directory(directory, batch_size, df_features, *, extractor=None, backup_every=500):
    """Extract features for every not-yet-processed WAV file in ``directory``.

    Args:
        directory: Folder scanned (non-recursively) for names ending in
            ``.wav`` (case-insensitive).
        batch_size: If truthy, process at most ``max(batch_size, 30)`` *new*
            files in this run; falsy means process every new file.
        df_features: Previously extracted features. It must have a
            ``filename`` column; files listed there are skipped.
        extractor: Callable ``path -> dict`` producing one file's features.
            Defaults to ``extract_features``.
        backup_every: Each time the running total of records (existing plus
            new) passes another multiple of this value, pickle ``df_features``
            plus *all* new records so far to ``audio_features_<n>k.pkl`` in
            the current directory.

    Returns:
        DataFrame holding only the newly extracted features, with a
        ``filename`` column. Files whose extraction raised are reported on
        stdout and omitted.
    """
    if extractor is None:
        extractor = extract_features

    # Build the lookup once: O(1) membership per file instead of a scan.
    already_done = set(df_features['filename'])
    # Sort so successive batches walk the directory in a stable order.
    pending = sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith('.wav') and name not in already_done
    )
    if batch_size:
        # Drop known files *before* slicing. Otherwise an --append run keeps
        # re-slicing the same, already-processed head of the listing and
        # never reaches new files. The floor of 30 keeps enough samples for
        # the downstream t-SNE / k-means step.
        pending = pending[:max(batch_size, 30)]

    feature_list = []
    filenames = []
    total_processed_records = len(df_features)  # existing records count toward the total
    last_backup_threshold = total_processed_records
    total_to_process = total_processed_records + len(pending)

    for filename in pending:
        file_path = os.path.join(directory, filename)
        print(f"Processing {total_processed_records + 1} of about {total_to_process} files \n {filename}")
        try:
            features = extractor(file_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"Error processing {filename}: {e}")
            continue
        feature_list.append(features)
        filenames.append(filename)
        total_processed_records += 1

        if total_processed_records >= last_backup_threshold + backup_every:
            try:
                _write_backup(df_features, feature_list, filenames, total_processed_records)
            except Exception as e:
                print(f"Backup failed at {total_processed_records} records: {e}")
            # Align to the next multiple so backups land on round totals.
            last_backup_threshold = (total_processed_records // backup_every) * backup_every

    new_features_df = pd.DataFrame(feature_list)
    new_features_df['filename'] = filenames
    return new_features_df


def _write_backup(df_features, feature_list, filenames, total_processed_records):
    """Pickle existing features plus every new record so far as a checkpoint file."""
    # Use the full list of new records. Checkpointing only the records since
    # the previous backup would silently lose earlier ones.
    new_features_df = pd.DataFrame(feature_list)
    new_features_df['filename'] = filenames
    df_features_temp = pd.concat([df_features, new_features_df], ignore_index=True)
    df_features_temp.drop_duplicates(subset=['filename'], inplace=True)
    # e.g. 1500 -> "1.5", 2000 -> "2"
    n_thousands = total_processed_records / 1000.0
    n_formatted = f"{n_thousands:.1f}".rstrip('0').rstrip('.')
    backup_filename = f"audio_features_{n_formatted}k.pkl"
    df_features_temp.to_pickle(backup_filename)
    print(f"Backup saved to {backup_filename}")
def truncate_filename(filename, head=7, tail=15):
    """Shorten a long filename for display as ``<first head chars>…<last tail chars>``.

    Args:
        filename: Name to shorten.
        head: Number of leading characters kept (default 7).
        tail: Number of trailing characters kept (default 15, enough to keep
            the extension and the end of the stem).

    Returns:
        The shortened name. A name of at most ``head + 1 + tail`` characters
        is returned unchanged, because "truncating" it would not make it
        shorter. The result is therefore never longer than the input.
    """
    if len(filename) <= head + 1 + tail:
        return filename
    # Explicit start index: filename[-0:] would return the whole string when tail == 0.
    return f"{filename[:head]}…{filename[len(filename) - tail:]}"
def main():
    """Build (or extend) the feature dataset, then serve the interactive Dash explorer.

    Command-line options:
        -a/--append  Load ``audio_features.pkl`` if it exists and only extract
                     features for WAV files not already in it.
        -b/--batch   Number of files to process in this run (0 = all).

    The combined dataset is written back to ``audio_features.pkl``. The app
    embeds a chosen slice of samples with PCA + t-SNE, colours the points by
    k-means cluster, and streams the source WAV files from AUDIO_DIR.
    """
    # Function-scope imports keep this block self-contained.
    from urllib.parse import quote
    from flask import send_from_directory

    parser = argparse.ArgumentParser(description="Audio Feature Extraction and Visualization")
    parser.add_argument('-a', '--append', action='store_true', help="Append new features to existing dataset")
    parser.add_argument('-b', '--batch', type=int, default=0, help="Number of audio files to process in this batch (minimum 30)")
    args = parser.parse_args()

    # Load or process features
    features_pkl = 'audio_features.pkl'
    if args.append and os.path.exists(features_pkl):
        # NOTE(review): read_pickle can execute code; only load files this script wrote.
        df_features = pd.read_pickle(features_pkl)
    else:
        # Initialize with 'filename' column to prevent KeyError
        df_features = pd.DataFrame(columns=['filename'])
    new_features_df = process_audio_directory(AUDIO_DIR, args.batch, df_features)
    df_features = pd.concat([df_features, new_features_df], ignore_index=True)
    df_features.drop_duplicates(subset=['filename'], inplace=True)
    df_features.to_pickle(features_pkl)

    # Short display names for hover labels (added after saving, so not persisted).
    df_features['truncated_filename'] = df_features['filename'].apply(truncate_filename)

    app = dash.Dash(__name__, suppress_callback_exceptions=True)
    server = app.server

    def audio_url(filename):
        """Return the /audio/ URL for ``filename``, percent-encoded."""
        # Raw names containing spaces, '#', '?' or '%' would otherwise produce a broken URL.
        return '/audio/' + quote(filename, safe='')

    @server.route('/audio/<filename>')
    def serve_audio(filename):
        # send_from_directory refuses names that escape AUDIO_DIR (the "<filename>"
        # converter still admits backslashes, e.g. "..\\secret" on Windows) and
        # answers 404 for missing files.
        return send_from_directory(AUDIO_DIR, filename)

    # Define color schemes
    DARK_COLORS = {
        'background': '#1a1a1a',
        'paper': '#2d2d2d',
        'text': '#ffffff',
        'secondary_text': '#a0a0a0',
        'button': '#404040',
        'button_hover': '#505050',
        'accent': '#bb86fc'
    }
    LIGHT_COLORS = {
        'background': '#ffffff',
        'paper': '#f5f5f5',
        'text': '#000000',
        'secondary_text': '#505050',
        'button': '#e0e0e0',
        'button_hover': '#d0d0d0',
        'accent': '#1976d2'
    }

    # Client-side state stores plus a container that render_layout fills in.
    app.layout = html.Div([
        dcc.Store(id='theme', data='dark'),  # Default theme is dark
        dcc.Store(id='audio-enabled', data=False),
        dcc.Store(id='selected-points', data=[]),
        dcc.Store(id='stored-features', data=df_features.to_dict('records')),
        dcc.Store(id='status-message', data=''),
        html.Div(id='app-content')
    ])

    @app.callback(
        Output('app-content', 'children'),
        [Input('theme', 'data'),
         Input('stored-features', 'data')]
    )
    def render_layout(theme, stored_features):
        """Build the themed page: header controls, status area, plot and sidebar."""
        COLORS = DARK_COLORS if theme == 'dark' else LIGHT_COLORS
        total_num_files = len(stored_features)
        # Keep the slider range valid (max >= min) even when no files were found.
        slider_max = max(total_num_files - 1, 0)
        styles = {
            'header': {
                'padding': '10px',
                'backgroundColor': COLORS['paper'],
                'color': COLORS['text'],
                'display': 'flex'
            },
            'headerTitle': {
                'margin': '0',
                'display': 'flex',
                'alignItems': 'center'
            },
            'buttonContainer': {
                'marginLeft': '10px',
                'display': 'flex',
                'alignItems': 'center'
            },
            'controlButton': {
                'marginRight': '10px',
                'padding': '10px',
                'backgroundColor': COLORS['button'],
                'color': COLORS['text'],
                'border': 'none',
                'cursor': 'pointer'
            },
            'controlButtonHover': {
                'backgroundColor': COLORS['button_hover']
            },
            'rangeSlider': {
                'width': '220px',
            },
            'controlInput': {
                'marginRight': '10px',
                'padding': '5px',
                'backgroundColor': COLORS['paper'],
                'color': COLORS['text'],
                'border': '1px solid',
                'borderColor': COLORS['secondary_text'],
                'borderRadius': '5px',
                'width': '80px',
                'height': '40px',
                'fontSize': '16px',
                'textAlign': 'center'
            },
            'mainContent': {
                'display': 'flex',
                'backgroundColor': COLORS['background'],
                'color': COLORS['text']
            },
            'plotContainer': {
                'flex': 3
            },
            'sidebar': {
                'flex': 1,
                'padding': '20px',
                'backgroundColor': COLORS['paper'],
                'color': COLORS['text']
            },
            'selectedList': {
                'maxHeight': '400px',
                'overflowY': 'auto'
            },
            'graph': {
                'backgroundColor': COLORS['background']
            },
            'checkbox': {
                'marginRight': '10px',
                'color': COLORS['text'],
                'display': 'flex',
                'marginBottom': '5px',
                'marginTop': '-15px'
            },
            'featureSelection': {
                'paddingLeft': '5px',
                'backgroundColor': COLORS['background']
            },
            'featureLabel': {
                'display': 'flex',
                'alignItems': 'center',
                'cursor': 'pointer',
                'fontSize': '14px'
            },
            'featureCheckbox': {
                'marginRight': '5px'
            },
            'statusArea': {
                'padding': '10px',
                'backgroundColor': COLORS['paper'],
                'color': COLORS['text'],
                'fontSize': '14px'
            }
        }
        # Feature selection checkboxes
        feature_checkboxes = html.Div([
            html.H4("Select Features to Include", style=styles['featureSelection']),
            dcc.Checklist(
                id='feature-selection',
                options=[{'label': feature, 'value': feature} for feature in FEATURES_LIST],
                value=FEATURES_LIST.copy(),
                style=styles['checkbox'],
                inputStyle={"marginRight": "5px"}
            )
        ])
        return html.Div([
            html.Div([
                html.H2("Audio Sample Explorer", style=styles['headerTitle']),
                html.Div([
                    html.Button('Toggle Theme', id='toggle-theme-button', style=styles['controlButton']),
                    html.Button('Enable Audio Playback', id='enable-audio-button', style=styles['controlButton']),
                    html.Button('Play Selected', id='play-selected-button', style=styles['controlButton']),
                    html.Button('Stop All', id='stop-button', style=styles['controlButton']),
                    html.Span([
                        dcc.RangeSlider(
                            id='subset-slider',
                            min=0,
                            max=slider_max,
                            value=[0, min(100, slider_max)],
                            marks={
                                0: {'label': '0'},
                                slider_max: {'label': str(slider_max)}
                            },
                            tooltip={'always_visible': False, 'placement': 'bottom'},
                            allowCross=False,
                            updatemode='mouseup',
                        )
                    ], style=styles['rangeSlider']),
                    dcc.Input(
                        id='n-clusters-input',
                        type='number',
                        min=2,
                        value=8,
                        step=1,
                        debounce=True,
                        style=styles['controlInput']
                    ),
                ], style=styles['buttonContainer']),
            ], style=styles['header']),
            html.Div([
                # React style keys must be camelCase ('margin-bottom' triggers a warning).
                html.Div(id='initial-status-message', style={'marginBottom': '10px'}),
                html.Div(id='final-status-message')
            ], id='status-area', style=styles['statusArea']),
            html.Div([
                html.Div([
                    feature_checkboxes,
                    dcc.Graph(
                        id='scatter-plot',
                        config={'doubleClick': 'reset', 'modeBarButtonsToAdd': ['lasso2d']},
                        style=styles['graph']
                    ),
                ], style=styles['plotContainer']),
                html.Div([
                    html.H3("Selected Samples"),
                    html.Div(id='selected-samples', style=styles['selectedList']),
                    html.H3("Audio Visualization"),
                    html.Div(id='audio-player'),
                    dcc.Graph(id='waveform-plot', style={'backgroundColor': COLORS['paper']})
                ], style=styles['sidebar']),
            ], style=styles['mainContent']),
            html.Div(id='audio-players')
        ])

    @app.callback(
        Output('theme', 'data'),
        Input('toggle-theme-button', 'n_clicks'),
        State('theme', 'data'),
        prevent_initial_call=True
    )
    def toggle_theme(n_clicks, current_theme):
        """Flip between the dark and light colour schemes."""
        return 'light' if current_theme == 'dark' else 'dark'

    @app.callback(
        Output('audio-enabled', 'data'),
        [Input('enable-audio-button', 'n_clicks'),
         Input('stop-button', 'n_clicks')],
        [State('audio-enabled', 'data')],
        prevent_initial_call=True
    )
    def toggle_audio(enable_clicks, stop_clicks, audio_enabled):
        """'Enable Audio Playback' toggles playback; 'Stop All' always disables it."""
        ctx = dash.callback_context
        if not ctx.triggered:
            raise dash.exceptions.PreventUpdate
        button_id = ctx.triggered[0]['prop_id'].split('.')[0]
        if button_id == 'enable-audio-button':
            return not audio_enabled
        elif button_id == 'stop-button':
            return False
        else:
            return audio_enabled

    @app.callback(
        [Output('scatter-plot', 'figure'),
         Output('final-status-message', 'children')],
        [Input('feature-selection', 'value'),
         Input('n-clusters-input', 'value'),
         Input('subset-slider', 'value')],
        State('stored-features', 'data'),
        State('theme', 'data')
    )
    def update_scatter_plot(selected_features, n_clusters, subset_slider_value, stored_features, theme):
        """Embed the selected slice with PCA + t-SNE and colour the points by k-means cluster.

        Returns the scatter figure and a status line for 'final-status-message'.
        """
        COLORS = DARK_COLORS if theme == 'dark' else LIGHT_COLORS
        df_features = pd.DataFrame(stored_features)
        # The numeric input yields None when cleared and may yield a float; KMeans needs an int.
        n_clusters = 8 if n_clusters is None else int(n_clusters)
        start_idx, end_idx = subset_slider_value
        subset_df_features = df_features.iloc[start_idx:end_idx + 1].reset_index(drop=True)
        if subset_df_features.empty:
            return go.Figure(), "No samples in the selected range."

        # Expand each selected feature family into its stored columns.
        feature_columns = []
        for feature in selected_features or []:
            if feature == 'tempo':
                feature_columns.append('tempo')
            elif feature in FEATURES_LIST:
                feature_columns.extend([f'{feature}_mean', f'{feature}_std'])
        if not feature_columns:
            return go.Figure(), "No features selected."

        # .copy() so the conversions below write to an independent frame, not a
        # slice of subset_df_features (pandas SettingWithCopyWarning).
        X = subset_df_features[feature_columns].copy()
        for col in X.columns:
            if X[col].apply(lambda x: isinstance(x, list)).any():
                X[col] = X[col].apply(np.array)

        # Vector-valued features (stored as lists/arrays) are spread into one
        # numeric column per component; scalar features are copied as-is.
        X_flat = pd.DataFrame()
        for col in X.columns:
            if X[col].apply(lambda x: isinstance(x, (np.ndarray, list))).all():
                arr = np.stack(X[col].apply(np.array).values)
                arr_df = pd.DataFrame(arr, columns=[f"{col}_{i}" for i in range(arr.shape[1])])
                X_flat = pd.concat([X_flat.reset_index(drop=True), arr_df.reset_index(drop=True)], axis=1)
            else:
                X_flat[col] = X[col]
        if not np.isfinite(X_flat.to_numpy()).all():
            status_message = "Data contains NaNs or Infs after flattening features."
            return go.Figure(), status_message

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X_flat)
        n_samples, n_features = X_scaled.shape
        n_components = min(n_samples, n_features, 30)
        if n_components >= 2:
            pca = PCA(n_components=n_components, random_state=1337)
            X_pca = pca.fit_transform(X_scaled)
            perplexity = 5 if n_samples < 50 else len(X) * .2
            # t-SNE requires perplexity < n_samples; clamp so small subsets don't raise.
            perplexity = min(perplexity, n_samples - 1)
            tsne = TSNE(n_components=2, perplexity=perplexity, random_state=1337)
            X_embedded = tsne.fit_transform(X_pca)
            subset_df_features['x'] = X_embedded[:, 0]
            subset_df_features['y'] = X_embedded[:, 1]
        else:
            return go.Figure(), "Insufficient data for dimensionality reduction."

        # Cap the cluster count at ~10% of the samples (never below 2).
        sample_count = len(subset_df_features)
        max_clusters = int(max(2, sample_count * 0.1))
        n_clusters = max(2, min(n_clusters, max_clusters))
        kmeans = KMeans(n_clusters=n_clusters, random_state=1337)
        subset_df_features['cluster'] = kmeans.fit_predict(X_pca)

        colormap = px.colors.qualitative.Safe
        fig = go.Figure()
        for i in range(n_clusters):
            cluster_data = subset_df_features[subset_df_features['cluster'] == i]
            fig.add_trace(go.Scatter(
                x=cluster_data['x'],
                y=cluster_data['y'],
                mode='markers',
                name=f'Cluster {i}',
                marker=dict(size=3, color=colormap[i % len(colormap)]),
                # customdata[0] (full filename) is read back by the selection,
                # waveform and playback callbacks.
                customdata=cluster_data[['filename', 'truncated_filename', 'cluster', 'tempo', 'rms_mean']].values,
                hovertemplate=(
                    "<b>Filename:</b> %{customdata[1]}<br>"
                    "<b>Cluster:</b> %{customdata[2]}<br>"
                    "<b>Tempo:</b> %{customdata[3]:.4f}<br>"
                    "<b>RMS Mean:</b> %{customdata[4]:.4f}<extra></extra>"
                )
            ))
        fig.update_layout(
            title=None,
            xaxis_title=None,
            yaxis_title=None,
            margin=dict(t=30, l=0, r=0, b=0),
            hovermode='closest',
            height=800,
            plot_bgcolor=COLORS['paper'],
            paper_bgcolor=COLORS['paper'],
            font_color=COLORS['text'],
            legend_font_color=COLORS['text'],
            xaxis=dict(color=COLORS['text']),
            yaxis=dict(color=COLORS['text'])
        )
        status_message = f"Processing completed for subset {start_idx}-{end_idx}. Number of samples: {len(subset_df_features)}. Number of clusters: {n_clusters}."
        return fig, status_message

    # Immediate browser-side status line shown while the server callback runs.
    app.clientside_callback(
        """
        function(selected_features, n_clusters, subset_slider_value) {
            var start_idx = subset_slider_value[0];
            var end_idx = subset_slider_value[1];
            var status_message = `Processing subset ${start_idx}-${end_idx}. Selected features: ${selected_features.join(', ')}. Clustering into ${n_clusters} clusters.`;
            return status_message;
        }
        """,
        Output('initial-status-message', 'children'),
        [Input('feature-selection', 'value'),
         Input('n-clusters-input', 'value'),
         Input('subset-slider', 'value')]
    )

    @app.callback(
        Output('selected-points', 'data'),
        Input('scatter-plot', 'selectedData'),
        prevent_initial_call=True
    )
    def update_selected_points(selectedData):
        """Store the full filenames (customdata[0]) of lasso/box-selected points."""
        if selectedData is None:
            return []
        return [point['customdata'][0] for point in selectedData['points']]

    @app.callback(
        Output('selected-samples', 'children'),
        Input('selected-points', 'data'),
        State('stored-features', 'data'),
        State('theme', 'data')
    )
    def update_selected_samples(selected_points, stored_features, theme):
        """List the selected filenames (shortened) in the sidebar."""
        COLORS = DARK_COLORS if theme == 'dark' else LIGHT_COLORS
        if not selected_points:
            return html.P("No samples selected", style={'color': COLORS['text']})
        truncated_filenames = [truncate_filename(fn) for fn in selected_points]
        return [html.Div(filename, style={'color': COLORS['text']}) for filename in truncated_filenames]

    @app.callback(
        [Output('waveform-plot', 'figure'),
         Output('audio-player', 'children')],
        [Input('scatter-plot', 'hoverData'),
         Input('audio-enabled', 'data')],
        State('theme', 'data')
    )
    def update_waveform_plot(hoverData, audio_enabled, theme):
        """For the hovered sample, plot waveform + mel spectrogram and autoplay it (only when audio is enabled)."""
        COLORS = DARK_COLORS if theme == 'dark' else LIGHT_COLORS
        if hoverData is None or not audio_enabled:
            return go.Figure(), []
        filename = hoverData['points'][0]['customdata'][0]
        file_path = os.path.join(AUDIO_DIR, filename)
        clusterinfo = hoverData['points'][0]['customdata'][2]
        try:
            y, sr = librosa.load(file_path, sr=None, mono=True, duration=47)
            fig = make_subplots(rows=2, cols=1, subplot_titles=(f"Cluster {clusterinfo}", "Spectrogram"))
            t = np.linspace(0, len(y) / sr, len(y))
            fig.add_trace(go.Scatter(x=t, y=y, mode='lines', line=dict(color=COLORS['accent'])), row=1, col=1)
            S = librosa.feature.melspectrogram(y=y, sr=sr)
            S_dB = librosa.power_to_db(S, ref=np.max)
            fig.add_trace(go.Heatmap(
                z=S_dB,
                x=np.linspace(0, len(y) / sr, S_dB.shape[1]),
                y=np.arange(S_dB.shape[0]),
                colorscale='Viridis',
                showscale=False
            ), row=2, col=1)
            fig.update_layout(
                height=800,
                title_text=f"{truncate_filename(filename)}",
                plot_bgcolor=COLORS['paper'],
                paper_bgcolor=COLORS['paper'],
                font_color=COLORS['text']
            )
            fig.update_xaxes(title_text="Time (s)", row=1, col=1, color=COLORS['text'])
            fig.update_xaxes(title_text="Time (s)", row=2, col=1, color=COLORS['text'])
            fig.update_yaxes(title_text="Amplitude", row=1, col=1, color=COLORS['text'])
            fig.update_yaxes(title_text="Mel Frequency", row=2, col=1, color=COLORS['text'])
            audio_player = html.Audio(src=audio_url(filename), controls=True, autoPlay=True, style={'width': '100%'})
            return fig, audio_player
        except Exception as e:
            # Best effort: a bad file clears the panel instead of breaking the app.
            print(f"Error loading audio file {filename}: {e}")
            return go.Figure(), []

    @app.callback(
        Output('audio-players', 'children'),
        [Input('play-selected-button', 'n_clicks'),
         Input('stop-button', 'n_clicks')],
        [State('audio-enabled', 'data'),
         State('selected-points', 'data')],
    )
    def update_audio_players(play_n_clicks, stop_n_clicks, audio_enabled, selected_points):
        """'Play Selected' autoplays every selected file; 'Stop All' removes all players."""
        ctx = dash.callback_context
        if not ctx.triggered:
            raise dash.exceptions.PreventUpdate
        button_id = ctx.triggered[0]['prop_id'].split('.')[0]
        if button_id == 'stop-button':
            return []
        if button_id == 'play-selected-button':
            if not audio_enabled or not selected_points:
                return []
            return [
                html.Audio(
                    src=audio_url(filename),
                    controls=True,
                    autoPlay=True,
                    style={'width': '100%'}
                )
                for filename in selected_points
            ]
        return []

    if __name__ == '__main__':
        # Dash 3 removed ``run_server``; fall back to it only if ``run`` is missing.
        run = getattr(app, 'run', None) or app.run_server
        run(debug=True)
# Script entry point: extract/load features, then start the Dash explorer.
if __name__ == "__main__":
    main()
@twobob
Copy link
Author

twobob commented Oct 12, 2024

V2: dark mode is the default theme.
V3: samples are sorted by their reduced (PCA/t-SNE) features.
V4: range slider, better logging and batch management, and automatic .pkl backups every 500 records (counted over the whole dataset).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment