Created
November 18, 2016 09:18
-
-
Save sunsided/f9cda9cfc926436704bab28473ad182c to your computer and use it in GitHub Desktop.
Caffe LSTM trouble
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# References:
# http://christopher5106.github.io/deep/learning/2016/06/07/recurrent-neural-net-with-Caffe.html
# https://github.com/BVLC/caffe/pull/3948
# https://github.com/junhyukoh/caffe-lstm/blob/master/examples/lstm_sequence/lstm_sequence.cpp
# https://github.com/BVLC/caffe/issues/4547
import caffe
import numpy as np
import matplotlib.pyplot as plt
# noinspection PyUnresolvedReferences
import seaborn as sns  # unused directly; presumably kept for its plot-styling side effect
# Generate a toy 1-D training signal: one dominant sine plus three small
# higher-frequency components, rescaled to a peak magnitude of 1 and
# zero-centered so the network regresses a bounded, offset-free target.
a = np.arange(0, 32, 0.01)
d = (0.5 * np.sin(2 * a) - 0.05 * np.cos(17 * a + 0.8)
     + 0.05 * np.sin(25 * a + 10) - 0.02 * np.cos(45 * a + 0.3))
d /= np.abs(d).max()  # same as max(np.max(d), -np.min(d)), but clearer
d -= d.mean()         # remove the DC offset
caffe.set_mode_gpu()  # NOTE(review): solver.prototxt says solver_mode: CPU — confirm which one is intended
solver = caffe.SGDSolver('solver.prototxt')

# train the network
print('Training network ...')
SEQ_LEN = 320  # time steps per sequence; must match T in lstm.prototxt
# niter = 5000
niter = 500
train_loss = np.zeros(niter)

# Set the bias of the forget gate to 5.0 as explained in the Clockwork RNN
# paper; with num_output = 15, indices 15:30 of the LSTM bias blob are the
# forget-gate biases.
solver.net.params['lstm1'][2].data[15:30] = 5
solver.net.blobs['clip'].data[...] = 1
for i in range(niter):
    # BUG FIX: use integer division. In Python 3, '/' produces a float, so
    # seq_idx would be a float and the slice below would raise TypeError.
    seq_idx = i % (len(d) // SEQ_LEN)
    # clip = 0 only on the first chunk, resetting the recurrent state there
    solver.net.blobs['clip'].data[0] = seq_idx > 0
    solver.net.blobs['label'].data[:, 0] = d[seq_idx * SEQ_LEN: (seq_idx + 1) * SEQ_LEN]
    solver.step(1)
    train_loss[i] = solver.net.blobs['loss'].data
print('Done training network.')

# TODO: Losses are bad
# plot the training loss
plt.plot(np.arange(niter), train_loss)
plt.show()
# TODO: It will fail below this line
# test the network
print('Testing network ...')
# Reshape the test net to feed values step by step instead of whole sequences.
# NOTE(review): reshaping to (2, 1) while reading a single scalar prediction
# per forward pass looks suspicious — confirm the intended test-time T x N.
solver.test_nets[0].blobs['data'].reshape(2, 1)
solver.test_nets[0].blobs['clip'].reshape(2, 1)
solver.test_nets[0].reshape()
solver.test_nets[0].blobs['clip'].data[...] = 1

preds = np.zeros(len(d))
for t in range(len(d)):
    # keep the hidden state across steps; reset it only at the first sample
    solver.test_nets[0].blobs['clip'].data[0] = t > 0
    preds[t] = solver.test_nets[0].forward()['ip1'][0][0]
print('Done testing network.')

# plot the predictions against the ground-truth signal
plt.plot(np.arange(len(d)), preds)
plt.plot(np.arange(len(d)), d)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "LSTM"
# T = 320 time_steps, N = 1 streams, 1-D data
input: "data"
input_shape { dim: 320 dim: 1 dim: 1 }  # T x N x input_dim
input: "clip"
input_shape { dim: 320 dim: 1 }         # T x N; a 0 marks a sequence start
input: "label"
input_shape { dim: 320 dim: 1 }

# The label blob is unconsumed at test time; silence it so the test net
# does not complain about a dangling input.
layer {
  name: "Silence"
  type: "Silence"
  bottom: "label"
  include: { phase: TEST }
}

layer {
  name: "lstm1"
  type: "LSTM"
  bottom: "data"
  bottom: "clip"
  top: "lstm1"
  recurrent_param {
    num_output: 15
    weight_filler {
      type: "uniform"
      min: -0.01
      max: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

# Project the 15-dim LSTM state down to one scalar prediction per time step.
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "lstm1"
  top: "ip1"
  inner_product_param {
    num_output: 1
    weight_filler {
      type: "gaussian"
      std: 0.1
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "loss"
  type: "EuclideanLoss"
  bottom: "ip1"
  bottom: "label"
  top: "loss"
  include: { phase: TRAIN }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
net: "lstm.prototxt"
test_iter: 10
# Solver-driven testing is effectively disabled (huge interval); the Python
# script drives the test net manually instead.
test_interval: 2000000
base_lr: 0.0001
momentum: 0.95
clip_gradients: 0.1
lr_policy: "fixed"
display: 200
max_iter: 100000
# NOTE(review): the training script calls caffe.set_mode_gpu(); confirm
# whether CPU here or GPU in the script is the intended mode.
solver_mode: CPU
average_loss: 200
debug_info: false
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment