|
""" |
|
Train a neural network to learn an FIR filter. |
|
|
|
Created on Fri Aug 3 15:00:40 2018 |
|
""" |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.layers import Dense |
|
from tensorflow.keras.callbacks import Callback |
|
import numpy as np |
|
from scipy import signal |
|
import matplotlib.pyplot as plt |
|
from soundfile import read |
|
|
|
""" |
|
Generate or load a signal to use as input data |
|
""" |
|
# Only learns at the frequencies present in the signal
# https://en.wikipedia.org/wiki/File:Short-beaked_Echidna.ogg
sig, fs = read('echidna.wav')

# Learns at all frequencies with white noise
# sig, fs = np.random.randn(10000), 10000

# soundfile returns a 2-D (frames, channels) array for multichannel files,
# but everything below filters and windows along a single time axis, so mix
# any multichannel input down to one mono channel.
if sig.ndim > 1:
    sig = sig.mean(axis=1)
|
|
|
""" |
|
Create the FIR filter for the ANN to copy |
|
""" |
|
numtaps = 51

# Other target filters to experiment with (uncomment one to swap it in):
# b = signal.firwin(numtaps, 1, fs=fs)
# b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
#                   pass_zero=False)

# Band-pass target passing 6000-11000 (cutoffs are in the same units as fs,
# i.e. Hz for the loaded audio), designed with a Blackman-Harris window.
b = signal.firwin(numtaps, [6000, 11000], window='blackmanharris',
                  pass_zero='bandpass', fs=fs)

# TODO: Use an IIR filter and have ANN approximate it as best it can
|
|
|
""" |
|
Training data is chunks of input and output of FIR filter |
|
""" |
|
# Alternative: lfilter returns one output per input sample (len(sig) values,
# including the start-up transient), so it would not line up one-to-one with
# the windows built by rolling_window below.
# filtered = signal.lfilter(b, 1.0, sig)

# 'valid' mode keeps only outputs whose full numtaps-long input span lies
# inside `sig`: len(sig) - numtaps + 1 samples, one per window of X.
# Convolution flips the taps: filtered[i] = sum_k sig[i + k] * b[numtaps-1-k],
# so a Dense kernel w fitted to X[i] @ w ~ filtered[i] converges to b reversed.
# NOTE(review): the plots compare the learned kernel directly with `b`, which
# only matches if the taps are symmetric (expected for these firwin designs);
# confirm, or compare against b[::-1], if you switch to an asymmetric filter.
filtered = signal.convolve(sig, b, mode='valid')
|
|
|
|
|
def rolling_window(a, window):
    """
    Return chunks of signal `a` of size `window`, incremented by 1 each time.

    Windows are taken along the last axis, so the result has shape
    ``a.shape[:-1] + (a.shape[-1] - window + 1, window)``. For example, a 1-D
    signal of length N gives an ``(N - window + 1, window)`` array whose row
    ``i`` is ``a[i:i + window]``.

    The result is a read-only strided view onto the data of `a`; no copy is
    made for ndarray input. A writable ``as_strided`` view would let a write
    to one window silently change its overlapping neighbours and the source
    signal, so writes are disallowed here.

    Parameters
    ----------
    a : array_like
        Input signal; windows are taken along its last axis.
    window : int
        Number of samples in each chunk.

    Returns
    -------
    numpy.ndarray
        Read-only view of all overlapping chunks.

    Raises
    ------
    ValueError
        If `window` is negative or longer than the last axis of `a`.

    Based on https://gist.github.com/codehacken/708f19ae746784cef6e68b037af65788
    """
    return np.lib.stride_tricks.sliding_window_view(a, window, axis=-1)
|
|
|
|
|
# One training example per window: the input is numtaps consecutive samples
# of the signal and the target is the single filtered sample computed from
# that same window.
X, Y = rolling_window(sig, numtaps), filtered

# plt.plot(X[0])
# plt.plot(Y[0])
|
|
|
|
|
""" |
|
Create model |
|
|
|
Initializer matters because signal might have missing areas of spectrum, and |
|
model will not learn there. So if the initial guess is all zeros, those areas |
|
of the spectrum will stay silenced, while the passband is "built up" for |
|
frequencies that are present. |
|
""" |
|
# With the default linear activation and no bias, one Dense unit computes the
# dot product of an input window with its numtaps-long kernel -- an FIR filter.
#
# Kernel initializer choices:
#   'random_normal' -- typical usage
#   'ones'          -- boxcar window = running average
#   'zeros'         -- nothing (good for non-white input)
model = Sequential()
model.add(Dense(1,
                input_dim=numtaps,
                use_bias=False,
                kernel_initializer='zeros'))

model.summary()

# Snapshot the starting kernel so it can be compared with the learned one.
initial = model.get_weights()

print('Initial weights:')
print(initial)
|
|
|
|
|
""" |
|
Make block diagram of network (not from tutorial) |
|
""" |
|
from tensorflow.keras.utils import plot_model

# plot_model relies on the optional pydot and graphviz packages and, outside
# IPython, raises ImportError when they are missing. The diagram is only a
# side product, so report the problem and carry on with training instead of
# aborting the whole script.
try:
    plot_model(model, to_file='model.png', show_shapes=True)
except ImportError as err:
    print('Skipping model.png block diagram:', err)
|
|
|
|
|
""" |
|
Make graph diagram of network (not from tutorial) |
|
|
|
Viewable with `tensorboard --logdir="logs"` |
|
""" |
|
import tensorflow as tf

# The old tf.compat.v1.Session + FileWriter recipe raises RuntimeError under
# TF2's default eager execution. Instead, trace one forward pass through a
# tf.function and export that graph with the TF2 summary API.


@tf.function
def _traced_forward(inputs):
    """Run the model inside a tf.function so its graph can be traced."""
    return model(inputs)


graph_writer = tf.summary.create_file_writer('logs')
tf.summary.trace_on(graph=True)
# A single all-zeros window is enough input to trace the graph.
_traced_forward(tf.zeros((1, numtaps)))
with graph_writer.as_default():
    tf.summary.trace_export(name='model_graph', step=0)
graph_writer.close()
|
|
|
|
|
""" |
|
Node-level graph |
|
""" |
|
|
|
# Working version: https://github.com/endolith/ann-visualizer
# from ann_visualizer.visualize import ann_viz
# ann_viz(model, title="Learned FIR filter")

# https://github.com/Dicksonchin93/keras-architecture-visualizer/
# from keras_architecture_visualizer import KerasArchitectureVisualizer
# vis = KerasArchitectureVisualizer()
# vis.visualize(model)

# Compile model: least-squares fit of the filter output, optimized with Adam.
model.compile(optimizer='adam', loss='mean_squared_error')
|
|
|
|
|
class LossHistory(Callback):
    """Record the training loss reported at the end of every batch.

    After ``model.fit(..., callbacks=[history])``, ``history.losses`` holds
    one value per training batch across all epochs, in order.

    NOTE(review): in TF2 Keras the ``'loss'`` passed to batch-end hooks is
    believed to be the running mean over the current epoch rather than that
    batch's loss alone -- confirm against the installed version before
    reading fine structure into the curve.
    """

    def on_train_begin(self, logs=None):
        # Start a fresh list for each fit() call.
        self.losses = []
        # Could plot the convergence here

    def on_batch_end(self, batch, logs=None):
        # `logs=None` instead of a shared mutable `{}` default.
        self.losses.append((logs or {}).get('loss'))

    def on_epoch_end(self, epoch, logs=None):
        # Could plot the convergence here
        pass
|
|
|
|
|
history = LossHistory()

# Fit the model
print("Fitting...")
model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history])

# Evaluate the model on the training windows. The model was compiled with no
# extra metrics, so evaluate() returns only the loss: the mean squared error
# between the network output and the true filter output. That is not a
# percentage, so report it unscaled.
print("Evaluating...")
scores = model.evaluate(X, Y)
print('Mean squared error:', scores)
|
|
|
final = model.get_weights()

# Kernel taps: the starting guess on the top axes; the learned taps drawn over
# the target filter coefficients on the bottom axes.
fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, num='kernel', sharex=True)
ax1.plot(b, '.-', label='Filter', alpha=0.5, c='gray')
ax0.plot(initial[0], '.-', label='Initial')
ax1.plot(final[0], '.-', label='Learned')
for ax in (ax0, ax1):
    ax.grid(True, color='0.7', linestyle='-', which='major')
    ax.grid(True, color='0.9', linestyle='-', which='minor')
ax0.set_title('Kernel')
ax0.legend()
ax1.legend()
|
|
|
def _plot_freq_response(coeffs, **plot_kwargs):
    """Plot the magnitude response (dB) of FIR taps `coeffs` against Hz.

    `coeffs` is flattened first. The Keras kernel comes out of get_weights()
    as a (numtaps, 1) column, and freqz treats a 2-D `b` as a stack of
    filters. In that case newer SciPy can return `h` as a (1, N) array, which
    no longer lines up with the 1-D frequency vector given to plt.semilogx.
    """
    w, h = signal.freqz(np.ravel(coeffs), [1.0])
    # The 'zeros' initializer gives an all-zero kernel whose log-magnitude is
    # -inf; that curve simply isn't drawn, so suppress the divide warning.
    with np.errstate(divide='ignore'):
        response_db = 20*np.log10(abs(h))
    plt.semilogx(w*fs/(2*np.pi), response_db, **plot_kwargs)


plt.figure('frequency response')
_plot_freq_response(b, label='Filter', alpha=0.5, c='gray')
_plot_freq_response(initial[0], label='Initial')
_plot_freq_response(final[0], label='Learned', alpha=0.5)
plt.grid(True, color='0.7', linestyle='-', which='major')
plt.grid(True, color='0.9', linestyle='-', which='minor')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Response [dB]')
plt.xlim(None, fs/2)
plt.title('Frequency response')
plt.legend()
|
|
|
plt.figure('loss')
plt.semilogy(history.losses)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.grid(True, which="both")
plt.title('Loss')

# Without this, running the file as a plain script (`python file.py`) exits
# without ever displaying the figures. In interactive sessions (IPython,
# Spyder) the figures typically appear anyway, and this call is harmless.
plt.show()
# Yes, "non-white" means signals that don't cover the entire spectrum.
# If you have white noise in the training data, then any initializer can be
# used (e.g. 'random_normal'), since every frequency is present and gets
# learned.
# (But note that there is no good reason to learn filters this way. It is
# just an experiment for learning about neural networks.)