"""
Train a neural network to learn an FIR filter.

Created on Fri Aug 3 15:00:40 2018
"""
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from soundfile import read

"""
Generate or load a signal to use as input data
"""
# Only learns at the frequencies present in the signal
# https://en.wikipedia.org/wiki/File:Short-beaked_Echidna.ogg
sig, fs = read('echidna.wav')

# Learns at all frequencies with white noise
# sig, fs = np.random.randn(10000), 10000
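
# Added sketch (assumption: the rest of the script expects a 1-D mono signal).
# soundfile.read() returns shape (frames, channels) for multichannel files,
# so collapse to mono if needed:
if sig.ndim > 1:
    sig = sig.mean(axis=1)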

"""
Create the FIR filter for the ANN to copy
"""
numtaps = 51
# b = signal.firwin(numtaps, 1, fs=fs)
# b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
#                   pass_zero=False)
b = signal.firwin(numtaps, cutoff=[6000, 11000], fs=fs,
                  window='blackmanharris', pass_zero=False)

# TODO: Use an IIR filter and have ANN approximate it as best it can
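
# Hedged sketch for the TODO above (not part of the original script): design
# an IIR target instead, e.g. a 4th-order Butterworth bandpass, and filter
# with lfilter.  A single Dense layer is an FIR, so it can only approximate
# the IIR's infinite impulse response over a window of `numtaps` samples.
# (Note lfilter returns len(sig) samples, so the 'valid' alignment below
# would need adjusting.)
# b_iir, a_iir = signal.butter(4, [6000, 11000], btype='bandpass', fs=fs)
# filtered = signal.lfilter(b_iir, a_iir, sig)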

"""
Training data is chunks of input and output of FIR filter
"""
# filtered = signal.lfilter(b, 1.0, sig)
filtered = signal.convolve(sig, b, mode='valid')


def rolling_window(a, window):
    """
    Return chunks of signal `a` of size `window`, incremented by 1 each time.

    https://gist.github.com/codehacken/708f19ae746784cef6e68b037af65788
    """
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
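
# Illustrative example (added): rolling_window(np.arange(5), 3) produces the
# overlapping windows
#     [[0, 1, 2],
#      [1, 2, 3],
#      [2, 3, 4]]
# as a strided view of the original array (no copy), so don't write to it.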

X = rolling_window(sig, numtaps)
Y = filtered  # Filter outputs 1 sample for each chunk of input samples
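
# Sanity check (added): 'valid' convolution yields len(sig) - numtaps + 1
# output samples, one per window, so X and Y should line up exactly
# (assuming a 1-D mono signal).
assert X.shape == (len(sig) - numtaps + 1, numtaps)
assert len(Y) == len(X)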

# plt.plot(X[0])
# plt.plot(Y[0])


"""
Create model

Initializer matters because the signal might have missing areas of spectrum,
and the model will not learn there.  So if the initial guess is all zeros,
those areas of the spectrum will stay silenced, while the passband is
"built up" for frequencies that are present.
"""
model = Sequential([
    Dense(1, input_dim=numtaps, use_bias=False,
          # kernel_initializer='random_normal',  # typical usage
          # kernel_initializer='ones',  # boxcar window = running average
          kernel_initializer='zeros',  # nothing (good for non-white input)
          )
])

model.summary()

initial = model.get_weights()

print('Initial weights:')
print(initial)

""" |
|
Make block diagram of network (not from tutorial) |
|
""" |
|
from tensorflow.keras.utils import plot_model |
|
plot_model(model, to_file='model.png', show_shapes=True) |
|
|
|
|
|
""" |
|
Make graph diagram of network (not from tutorial) |
|
|
|
Viewable with `tensorboard --logdir="logs"` |
|
""" |
|
import tensorflow as tf |
|
|
|
# TODO: This isn't working like it used to. Replace with TF2.0 conventions. |
|
with tf.compat.v1.Session() as sess: |
|
writer = tf.compat.v1.summary.FileWriter('logs', sess.graph) |
|
writer.close() |
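
# Possible TF2-style alternative (added note, not verified here): skip the v1
# Session and let Keras write the graph during training by adding
# tf.keras.callbacks.TensorBoard(log_dir='logs', write_graph=True) to the
# callbacks list passed to model.fit() below.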

"""
Node-level graph
"""
# Working version: https://github.com/endolith/ann-visualizer
# from ann_visualizer.visualize import ann_viz
# ann_viz(model, title="Learned FIR filter")

# https://github.com/Dicksonchin93/keras-architecture-visualizer/
# from keras_architecture_visualizer import KerasArchitectureVisualizer
# vis = KerasArchitectureVisualizer()
# vis.visualize(model)

# Compile model
model.compile(loss='mean_squared_error',
              optimizer='adam',
              )

class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
        # Could plot the convergence here

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

    def on_epoch_end(self, epoch, logs={}):
        pass
        # Could plot the convergence here


history = LossHistory()


# Fit the model
print("Fitting...")
model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history])
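
# Optional variation (added, commented out): hold out some data and stop
# training when validation loss stops improving, instead of a fixed 35 epochs.
# from tensorflow.keras.callbacks import EarlyStopping
# model.fit(X, Y, epochs=100, batch_size=100, validation_split=0.1,
#           callbacks=[history, EarlyStopping(patience=3)])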

# evaluate the model
print("Evaluating...")
scores = model.evaluate(X, Y)
print('MSE loss:', scores)  # Just the loss; no other metrics were compiled

final = model.get_weights()
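
# Note (added): the Dense layer takes a plain dot product with each window,
# i.e. correlation rather than convolution, so it learns the time-reversed
# kernel b[::-1].  firwin designs are symmetric (linear phase), so here
# b[::-1] == b and the comparison below still lines up sample-for-sample.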

fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, num='kernel', sharex=True)
ax1.plot(b, '.-', label='Filter', alpha=0.5, c='gray')
ax0.plot(initial[0], '.-', label='Initial')
ax1.plot(final[0], '.-', label='Learned')
ax0.grid(True, color='0.7', linestyle='-', which='major')
ax0.grid(True, color='0.9', linestyle='-', which='minor')
ax1.grid(True, color='0.7', linestyle='-', which='major')
ax1.grid(True, color='0.9', linestyle='-', which='minor')
ax0.set_title('Kernel')
ax0.legend()
ax1.legend()

plt.figure('frequency response')
w, h = signal.freqz(b, [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Filter',
             alpha=0.5, c='gray')
w, h = signal.freqz(initial[0], [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Initial')
w, h = signal.freqz(final[0], [1.0])
plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Learned', alpha=0.5)
plt.grid(True, color='0.7', linestyle='-', which='major')
plt.grid(True, color='0.9', linestyle='-', which='minor')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Response [dB]')
plt.xlim(None, fs/2)
plt.title('Frequency response')
plt.legend()

plt.figure('loss')
plt.semilogy(history.losses)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.grid(True, which="both")
plt.title('Loss')
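
# Added: show the figures explicitly when running as a plain script with a
# non-interactive backend.
plt.show()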