Created
January 23, 2020 14:33
-
-
Save padoremu/8288b47ce76e9530eb288d4eec2e0b4d to your computer and use it in GitHub Desktop.
Tensorflow and Tensorflow Lite code in the context of audio processing (MFCC, RNN)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file contains a collection of workarounds for missing TFLite support from: | |
# https://github.com/tensorflow/magenta/tree/master/magenta/music | |
# as posted in https://github.com/tensorflow/tensorflow/issues/27303 | |
# Thanks a lot to github.com/rryan for his support! | |
# The function for testing MFCC computation given PCM input is: | |
# - test_mfcc_tflite | |
# Please note that the output has not yet been compared to the one produced by the respective TF functions.
# This file also contains test code for other problems in the context of audio processing with TF and TFLite: | |
# - test_rnn_gru_save_load (saving/loading of RNNs with more than one cell): https://github.com/tensorflow/tensorflow/issues/36093 | |
# - test_rnn_gru_tflite (missing TFLite support for some functions needed by RNN): https://github.com/tensorflow/tensorflow/issues/21526#issuecomment-577586202 | |
# This code has been tested with tf-nightly-2.0-preview (2.0.0.dev20190731) on Ubuntu 18.04 with python 3.6.9. | |
import fractions
import math

import numpy as np
import tensorflow as tf
# constants
# Sampling rate in Hz; handed to tf.signal.linear_to_mel_weight_matrix.
sample_rate = 16000
# NOTE(review): not referenced by the code shown in this file; presumably the
# expected clip length in samples - confirm before relying on it.
desired_samples = 16000
# STFT frame length and hop, both in samples.
window_size_samples = 640
window_stride_samples = 640
# Number of leading MFCC coefficients kept from the last axis.
dct_coefficient_count = 10
# Zero-padding mode passed to _naive_rdft by tflite_mfcc.
padding = 'right'
# line of code for saving plot of model | |
#tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True) | |
# some converter flags that might help sometime | |
#converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] | |
#converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] | |
#converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] | |
#converter.experimental_new_converter = True | |
def _dft_matrix(dft_length): | |
"""Calculate the full DFT matrix in numpy.""" | |
omega = (0 + 1j) * 2.0 * np.pi / float(dft_length) | |
# Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it. | |
return np.exp(omega * np.outer(np.arange(dft_length), np.arange(dft_length))) | |
def _naive_rdft(signal_tensor, fft_length, padding='center'): | |
"""Implement real-input Fourier Transform by matmul.""" | |
# We are right-multiplying by the DFT matrix, and we are keeping | |
# only the first half ("positive frequencies"). | |
# So discard the second half of rows, but transpose the array for | |
# right-multiplication. | |
# The DFT matrix is symmetric, so we could have done it more | |
# directly, but this reflects our intention better. | |
complex_dft_matrix_kept_values = _dft_matrix(fft_length)[:(fft_length // 2 + 1), :].transpose() | |
real_dft_tensor = tf.constant(np.real(complex_dft_matrix_kept_values).astype(np.float32), name='real_dft_matrix') | |
imag_dft_tensor = tf.constant(np.imag(complex_dft_matrix_kept_values).astype(np.float32), name='imaginary_dft_matrix') | |
signal_frame_length = signal_tensor.shape[-1]#.value | |
half_pad = (fft_length - signal_frame_length) // 2 | |
if padding == 'center': | |
# Center-padding. | |
pad_values = tf.concat([ | |
tf.zeros([tf.rank(signal_tensor) - 1, 2], tf.int32), | |
[[half_pad, fft_length - signal_frame_length - half_pad]] | |
], axis=0) | |
elif padding == 'right': | |
# Right-padding. | |
pad_values = tf.concat([ | |
tf.zeros([tf.rank(signal_tensor) - 1, 2], tf.int32), | |
[[0, fft_length - signal_frame_length]] | |
], axis=0) | |
padded_signal = tf.pad(signal_tensor, pad_values) | |
result_real_part = tf.matmul(padded_signal, real_dft_tensor) | |
result_imag_part = tf.matmul(padded_signal, imag_dft_tensor) | |
return result_real_part, result_imag_part | |
def _fixed_frame(signal, frame_length, frame_step, first_axis=False): | |
"""tflite-compatible tf.signal.frame for fixed-size input. | |
Args: | |
signal: Tensor containing signal(s). | |
frame_length: Number of samples to put in each frame. | |
frame_step: Sample advance between successive frames. | |
first_axis: If true, framing is applied to first axis of tensor; otherwise, | |
it is applied to last axis. | |
Returns: | |
A new tensor where the last axis (or first, if first_axis) of input | |
signal has been replaced by a (num_frames, frame_length) array of individual | |
frames where each frame is drawn frame_step samples after the previous one. | |
Raises: | |
ValueError: if signal has an undefined axis length. This routine only | |
supports framing of signals whose shape is fixed at graph-build time. | |
""" | |
signal_shape = signal.shape.as_list() | |
if first_axis: | |
length_samples = signal_shape[0] | |
else: | |
length_samples = signal_shape[-1] | |
if length_samples <= 0: | |
raise ValueError('fixed framing requires predefined constant signal length') | |
num_frames = max(0, 1 + (length_samples - frame_length) // frame_step) | |
if first_axis: | |
inner_dimensions = signal_shape[1:] | |
result_shape = [num_frames, frame_length] + inner_dimensions | |
gather_axis = 0 | |
else: | |
outer_dimensions = signal_shape[:-1] | |
result_shape = outer_dimensions + [num_frames, frame_length] | |
# Currently tflite's gather only supports axis==0, but that may still | |
# work if we want the last of 1 axes. | |
gather_axis = len(outer_dimensions) | |
subframe_length = fractions.gcd(frame_length, frame_step) # pylint: disable=deprecated-method | |
subframes_per_frame = frame_length // subframe_length | |
subframes_per_hop = frame_step // subframe_length | |
num_subframes = length_samples // subframe_length | |
if first_axis: | |
trimmed_input_size = [num_subframes * subframe_length] + inner_dimensions | |
subframe_shape = [num_subframes, subframe_length] + inner_dimensions | |
else: | |
trimmed_input_size = outer_dimensions + [num_subframes * subframe_length] | |
subframe_shape = outer_dimensions + [num_subframes, subframe_length] | |
subframes = tf.reshape( | |
tf.slice( | |
signal, | |
begin=np.zeros(len(signal_shape), np.int32), | |
size=trimmed_input_size), subframe_shape) | |
# frame_selector is a [num_frames, subframes_per_frame] tensor | |
# that indexes into the appropriate frame in subframes. For example: | |
# [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]] | |
frame_selector = np.reshape(np.arange(num_frames) * subframes_per_hop, [num_frames, 1]) | |
# subframe_selector is a [num_frames, subframes_per_frame] tensor | |
# that indexes into the appropriate subframe within a frame. For example: | |
# [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] | |
subframe_selector = np.reshape(np.arange(subframes_per_frame), [1, subframes_per_frame]) | |
# Adding the 2 selector tensors together produces a [num_frames, | |
# subframes_per_frame] tensor of indices to use with tf.gather to select | |
# subframes from subframes. We then reshape the inner-most subframes_per_frame | |
# dimension to stitch the subframes together into frames. For example: | |
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]. | |
selector = frame_selector + subframe_selector | |
frames = tf.reshape(tf.gather(subframes, selector.astype(np.int32), axis=gather_axis), result_shape) | |
return frames | |
def _stft_tflite(signal, frame_length, frame_step, fft_length):
    """tflite-compatible implementation of tf.signal.stft.

    Compute the short-time Fourier transform of the input while avoiding tf ops
    that are not currently supported in tflite (Rfft, Range, SplitV).

    fft_length must be fixed. A Hann window of frame_length is always
    applied.

    Since fixed (precomputed) framing must be used, signal.shape[-1] must be a
    specific value (so "?"/None is not supported).

    Args:
      signal: tensor containing the time-domain waveform to be transformed;
        framing is applied to its last axis.
      frame_length: int, the number of points in each Fourier frame.
      frame_step: int, the number of samples to advance between successive frames.
      fft_length: int, the size of the Fourier transform to apply.

    Returns:
      Two tensors containing the real and imaginary parts of the short-time
      Fourier transform of the input signal. Their last two axes are
      (num_frames, fft_length // 2 + 1), because _naive_rdft keeps only the
      non-negative frequencies.
    """
    # Periodic Hann window sampled at k / frame_length for k = 0..frame_length-1.
    # An integer arange divided by frame_length always yields exactly
    # frame_length points; np.arange with a fractional step can yield a
    # different count because of rounding, which would break the reshape below.
    window_phase = 2 * np.pi * np.arange(frame_length) / frame_length
    hann = (0.5 - 0.5 * np.cos(window_phase)).astype(np.float32)
    # Make the window be shape (1, frame_length) instead of just frame_length
    # in an effort to help the tflite broadcast logic.
    window = tf.reshape(tf.constant(hann, name='window'), [1, frame_length])

    framed_signal = _fixed_frame(signal, frame_length, frame_step, first_axis=False)
    framed_signal *= window

    # _naive_rdft is called with its default padding mode ('center').
    real_spectrogram, imag_spectrogram = _naive_rdft(framed_signal, fft_length)
    return real_spectrogram, imag_spectrogram
def _stft_magnitude_tflite(waveform_input, window_length_samples, hop_length_samples, fft_length):
    """Compute the STFT magnitude using only the tflite-compatible helpers.

    Args:
      waveform_input: waveform tensor forwarded to _stft_tflite.
      window_length_samples: int, frame length in samples.
      hop_length_samples: int, frame step in samples.
      fft_length: int, size of the Fourier transform.

    Returns:
      A (magnitude, shape) pair: the magnitude tensor and the static shape of
      the real part of the STFT.
    """
    re_part, im_part = _stft_tflite(
        waveform_input,
        frame_length=window_length_samples,
        frame_step=hop_length_samples,
        fft_length=fft_length)
    squared_norm = tf.add(re_part * re_part, im_part * im_part)
    magnitude = tf.sqrt(squared_norm, name='magnitude_spectrogram')
    return magnitude, re_part.shape
def _stft_magnitude_full_tf(waveform_input, window_length_samples, hop_length_samples, fft_length):
    """Compute the STFT magnitude with the stock tf.signal.stft op.

    Args:
      waveform_input: waveform tensor passed to tf.signal.stft.
      window_length_samples: int, frame length in samples.
      hop_length_samples: int, frame step in samples.
      fft_length: int, size of the Fourier transform.

    Returns:
      A (magnitude, shape) pair: the absolute value of the complex STFT and
      the static shape of that complex STFT.
    """
    complex_stft = tf.signal.stft(
        waveform_input,
        frame_length=window_length_samples,
        frame_step=hop_length_samples,
        fft_length=fft_length)
    magnitude = tf.abs(complex_stft, name='magnitude_spectrogram')
    return magnitude, complex_stft.shape
def tflite_mfcc(log_mel_spectrogram):
    """Compute MFCCs from a log-mel spectrogram via a matmul-based DCT-II.

    The DCT-II is obtained from a real DFT of length 2 * N (N being the size
    of the last axis) multiplied by the factor 2 * exp(-1j * pi * k / (2 * N)),
    and is then scaled by rsqrt(2 * N).

    The steps are intended to mirror tf.signal.rfft(x, fft_length=[2 * N])[..., :N]
    and tf.signal.mfccs_from_log_mel_spectrograms(x); np.testing.assert_allclose
    with atol=1e-4, rtol=1e-4 against those ops is the suggested check.

    Args:
      log_mel_spectrogram: rank-3 float tensor (the slices below index exactly
        three axes) whose last axis has a static size.

    Returns:
      A tensor with the same shape as the input holding the MFCCs.
    """
    num_bins = log_mel_spectrogram.shape[-1]
    num_bins_f = float(num_bins)

    # Real and imaginary parts of 2 * exp(1j * twiddle_phase).
    twiddle_phase = -tf.range(num_bins_f) * np.pi * 0.5 / num_bins_f
    twiddle_real = 2.0 * tf.cos(twiddle_phase)
    twiddle_imag = 2.0 * tf.sin(twiddle_phase)

    # The zero-padding mode comes from the module-level `padding` constant.
    dft_real, dft_imag = _naive_rdft(
        log_mel_spectrogram, fft_length=2 * num_bins, padding=padding)

    # Keep the first N bins. Explicit per-axis slices are used here; the
    # "..." form was left disabled in the original code.
    dft_real = dft_real[:, :, :num_bins]
    # Conjugate to match tf.signal convention.
    dft_imag = -dft_imag[:, :, :num_bins]

    # Real part of the complex product (dft * twiddle).
    dct2 = dft_real * twiddle_real - dft_imag * twiddle_imag
    return dct2 * tf.math.rsqrt(2 * num_bins_f)
def preprocess_audio(audio):
    """Turn a single waveform into one flat row of MFCC features.

    Pipeline: STFT magnitude (tflite-compatible) -> power spectrogram ->
    mel filterbank -> stabilized log -> MFCC -> keep the first
    dct_coefficient_count coefficients -> flatten.

    Args:
      audio: tensor that is reshaped to (1, audio.shape[1]), i.e. a batch of
        exactly one waveform.

    Returns:
      A tensor of shape (1, total number of kept MFCC values).
    """
    waveform = tf.reshape(audio, [1, audio.shape[1]])

    # fft_length is fixed to 1024 here. tf.signal.stft would pick it on its own
    # as the smallest power of 2 enclosing frame_length (e.g. 640 => 1024).
    magnitudes, stft_shape = _stft_magnitude_tflite(
        waveform_input=waveform,
        window_length_samples=window_size_samples,
        hop_length_samples=window_stride_samples,
        fft_length=1024)
    power_spectrograms = tf.square(magnitudes)  # can this be done better?

    # warp the linear scale spectrograms into the mel-scale
    num_mel_bins = 80
    lower_edge_hertz = 80.0
    upper_edge_hertz = 7600.0
    mel_weights = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, stft_shape[-1], sample_rate, lower_edge_hertz, upper_edge_hertz)
    mel_spectrograms = tf.tensordot(power_spectrograms, mel_weights, 1)
    # tensordot may lose the static shape, so restore it explicitly.
    mel_spectrograms.set_shape(
        power_spectrograms.shape[:-1].concatenate(mel_weights.shape[-1:]))

    # compute a stabilized log to get log-magnitude mel-scale spectrograms
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)

    # compute MFCCs and keep the first dct_coefficient_count coefficients
    kept_mfcc = tflite_mfcc(log_mel_spectrograms)[:, :, :dct_coefficient_count]

    # reshape to (1, x)
    return tf.reshape(kept_mfcc, [1, tf.size(kept_mfcc)])
# tests model conversion to tensorflow lite and invocation
def test_tflite(model):
    """Convert a Keras model to TFLite and run one inference on zero input.

    Args:
      model: Keras model accepted by tf.lite.TFLiteConverter.from_keras_model.

    Returns:
      The numpy array read from the first output tensor after invocation.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    #converter.allow_custom_ops = True # does the trick
    #converter.target_spec.supported_ops = set([tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS])
    #converter.experimental_new_converter = True # does not help
    tflite_model = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # Feed explicit zeros to every input so the run is computed from a defined
    # input instead of whatever the input buffers hold after allocation.
    for input_detail in interpreter.get_input_details():
        interpreter.set_tensor(
            input_detail['index'],
            np.zeros(input_detail['shape'], dtype=input_detail['dtype']))

    interpreter.invoke()

    output_details = interpreter.get_output_details()
    return interpreter.get_tensor(output_details[0]['index'])
# tests using RNN with GRU cell in tensorflow lite
def test_rnn_gru_tflite():
    """Build a one-layer GRU RNN and push it through the TFLite round trip."""
    gru_layer = tf.keras.layers.RNN(tf.keras.layers.GRUCell(10))

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Input(shape=(1, 1,)))
    model.add(gru_layer)

    test_tflite(model)
# tests saving/loading RNN with more than one GRU cell
def test_rnn_gru_save_load(save_format):
    """Save and reload a Keras RNN built from several stacked GRU cells.

    Args:
      save_format: "h5" writes to "rnn.h5"; any other value writes to "rnn.tf".
        The value is also forwarded to model.save.
    """
    # saving succeeds for number_of_cells = 1, but fails for number_of_cells > 1
    number_of_cells = 2
    stacked_cells = [tf.keras.layers.GRUCell(10) for _ in range(number_of_cells)]

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Input(shape=(1, 1,)))
    model.add(tf.keras.layers.RNN(stacked_cells))

    if save_format == "h5":
        model_path = "rnn.h5"
    else:
        model_path = "rnn.tf"

    model.save(model_path, save_format=save_format)
    # Loading is the step under test; the reloaded model itself is not used.
    tf.keras.models.load_model(model_path)
# tests audio preprocessing (stft / abs / mfcc) in tensorflow lite
def test_mfcc_tflite():
    """Wrap preprocess_audio in a Keras model and run it through TFLite."""
    model = tf.keras.Sequential()
    # Pass the shape as a tuple, as the Keras Input documentation describes,
    # and take the length from the module constant (16000) instead of
    # repeating the literal.
    model.add(tf.keras.layers.Input(shape=(desired_samples,), dtype=tf.float32))
    model.add(tf.keras.layers.Lambda(preprocess_audio))
    test_tflite(model)
def print_version():
    """Print TensorFlow's git version string followed by its release version."""
    version_info = tf.version
    print(version_info.GIT_VERSION, version_info.VERSION)
# Script entry: report the TensorFlow build, then run the MFCC conversion test.
print_version()
test_mfcc_tflite()
# The RNN checks below are disabled by default; uncomment to run them.
#test_rnn_gru_tflite()
#test_rnn_gru_save_load("h5")
#test_rnn_gru_save_load("tf")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment