Skip to content

Instantly share code, notes, and snippets.

@nwatab
Created September 18, 2020 03:25
Show Gist options
  • Save nwatab/1e221a9fe200dd547bc0605d90835ae4 to your computer and use it in GitHub Desktop.
Save nwatab/1e221a9fe200dd547bc0605d90835ae4 to your computer and use it in GitHub Desktop.
Log-scale mel-spectrogram layer implemented in TensorFlow/Keras (based on https://keunwoochoi.wordpress.com/2019/09/28/log-melspectrogram-layer-using-tensorflow-keras/)
import tensorflow as tf
class LogMelspectrogramLayer(tf.keras.layers.Layer):
    """Keras layer that turns raw mono waveforms into log-scaled (dB) mel-spectrograms.

    Example::

        # librosa.load returns (waveform, sample_rate); build the layer for that rate.
        signals, sr = librosa.load('path/to/audio.mp3', sr=None)
        log_melspectrogram_layer = LogMelspectrogramLayer(sr=sr)
        logmelspectrogram = log_melspectrogram_layer(signals)

    Args:
        num_fft: STFT frame length and FFT size, in samples.
        hop_length: Number of samples between the starts of successive STFT frames.
        sr: Sample rate (Hz) of the input waveform; used only to build the mel filterbank.
        fmin: Lower edge of the mel filterbank, in Hz.
        fmax: Upper edge of the mel filterbank, in Hz.
        num_mel: Number of mel bins in the output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_fft=2048, hop_length=512, sr=24000, fmin=125., fmax=3800., num_mel=128, **kwargs):
        # The original referenced a non-existent class name here (NameError).
        super().__init__(**kwargs)
        self.num_fft = num_fft
        self.hop_length = hop_length
        # Keep every constructor argument so get_config() can round-trip the layer.
        self.sr = sr
        self.fmin = fmin
        self.fmax = fmax
        self.num_mel = num_mel
        # Weight matrix mapping the num_fft // 2 + 1 linear STFT bins onto num_mel
        # mel bins; built once here and reused on every call.
        self.lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=num_mel,
            num_spectrogram_bins=num_fft // 2 + 1,
            sample_rate=sr,
            lower_edge_hertz=fmin,
            upper_edge_hertz=fmax,
        )

    @staticmethod
    def _power_to_db(x):
        """Convert power values to decibels: 10 * log10(x)."""
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return 10. * numerator / denominator

    def call(self, input):
        """Compute log mel-spectrograms.

        Args:
            input (tensor): Mono waveform(s) with samples on the last axis,
                e.g. shape (N,) or (batch, N).

        Returns:
            log_melgrams (tensor): Log mel power in dB with shape
                (..., num_frame, num_mel). No channel axis is added.
        """
        eps = 1e-6  # keeps log() finite for silent frames
        # tf.signal.stft operates along the last axis. Pass fft_length explicitly:
        # the default rounds num_fft up to a power of two, which would produce a bin
        # count that no longer matches lin_to_mel_matrix (num_fft // 2 + 1 rows).
        stfts = tf.signal.stft(
            input,
            frame_length=self.num_fft,
            frame_step=self.hop_length,
            fft_length=self.num_fft,
        )
        power = tf.square(tf.abs(stfts))  # complex STFT -> real power spectrogram
        # Match the filterbank dtype to the spectrogram (e.g. float64 input).
        mel_matrix = tf.cast(self.lin_to_mel_matrix, power.dtype)
        # Contract the frequency axis; works for any number of leading dims.
        melgrams = tf.tensordot(power, mel_matrix, 1)
        return self._power_to_db(melgrams + eps)

    def get_config(self):
        """Return all constructor arguments so the layer can be rebuilt via from_config."""
        config = super().get_config()
        config.update({
            'num_fft': self.num_fft,
            'hop_length': self.hop_length,
            'sr': self.sr,
            'fmin': self.fmin,
            'fmax': self.fmax,
            'num_mel': self.num_mel,
        })
        return config
@nwatab
Copy link
Author

nwatab commented Sep 18, 2020

[Image: log mel-spectrogram of a glaucous-winged gull (washi-kamome) call]

# Example: plot the log mel-spectrogram of a gull call.
# Assumes LogMelspectrogramLayer (from the gist above) is importable/defined.
import librosa
import matplotlib.pyplot as plt
import numpy as np

# Load at the file's native sample rate and build the layer for that same rate.
# librosa.load resamples to 22050 Hz by default, while the layer's mel filterbank
# defaults to sr=24000, so the two must be kept in sync explicitly.
y, sr = librosa.load('./glaucous-winged_gull.mp3', sr=None)
log_melspectrogram_layer = LogMelspectrogramLayer(sr=sr)
s = log_melspectrogram_layer(y).numpy()  # shape (num_frame, num_mel)

# Min-max scale to [0, 1] for display; skip the division if the spectrogram is constant.
s -= np.min(s)
peak = np.max(s)
if peak > 0:
    s /= peak

# Transpose so time runs along x and mel frequency along y (low frequencies at the bottom).
plt.imshow(s.T, origin='lower', aspect='auto', vmin=0, vmax=1)
plt.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment