Created
February 2, 2015 11:23
-
-
Save ShigekiKarita/e1902eb13d64cf649d47 to your computer and use it in GitHub Desktop.
Learning auto encoder with RNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# built in | |
import random | |
# previous code | |
from sdr import to_sdr, to_real_number | |
# third-party library: https://github.com/pybrain/pybrain
from pybrain.structure import LinearLayer, SigmoidLayer, BiasUnit, FullConnection, RecurrentNetwork | |
from pybrain.datasets import SupervisedDataSet | |
from pybrain.supervised.trainers import BackpropTrainer | |
def construct_rnn(input_nodes, hidden_nodes, output_nodes):
    """Build a single-hidden-layer recurrent network.

    The network has a linear input layer ("i"), a sigmoid hidden layer ("h"),
    a linear output layer ("o") and a bias unit ("b"). The bias feeds both the
    hidden and the output layer, and the hidden layer feeds back into itself
    through a recurrent connection.

    Args:
        input_nodes: width of the input layer.
        hidden_nodes: width of the hidden layer.
        output_nodes: width of the output layer.

    Returns:
        The sorted and reset ``RecurrentNetwork``.
    """
    net = RecurrentNetwork()

    in_layer = LinearLayer(input_nodes, name="i")
    net.addInputModule(in_layer)
    bias = BiasUnit("b")
    net.addModule(bias)
    hidden_layer = SigmoidLayer(hidden_nodes, name="h")
    net.addModule(hidden_layer)
    out_layer = LinearLayer(output_nodes, name="o")
    net.addOutputModule(out_layer)

    # Feed-forward wiring, in the same order the connections are registered.
    forward_links = (
        (in_layer, hidden_layer),
        (bias, hidden_layer),
        (bias, out_layer),
        (hidden_layer, out_layer),
    )
    for source, target in forward_links:
        net.addConnection(FullConnection(source, target))

    # Hidden state from the previous time step flows back into the hidden layer.
    net.addRecurrentConnection(FullConnection(hidden_layer, hidden_layer))

    net.sortModules()
    net.reset()
    return net
def construct_data(total_data_size, train_data_size, length=100):
    """Generate encoded (x, 3 * x) pairs and split them into train/eval sets.

    ``total_data_size`` values are drawn uniformly from [0, 1/3]; each value
    and its triple are encoded with ``to_sdr`` using ``length`` digits.

    Args:
        total_data_size: number of samples to generate.
        train_data_size: number of leading samples that go to the training set.
        length: digit count of every encoded input and target.

    Returns:
        A pair ``(train, eval)`` of ``zip`` results over (input, target)
        digit lists; ``eval`` holds the samples after the first
        ``train_data_size``.
    """
    # Keep the three passes separate so random numbers are consumed in the
    # same order: all uniforms first, then input encodings, then targets.
    reals = [random.uniform(0, 1.0/3) for _ in range(total_data_size)]
    tripled = [3 * value for value in reals]

    encoded_inputs = [to_sdr(value, length) for value in reals]
    encoded_targets = [to_sdr(value, length) for value in tripled]

    train_pairs = zip(encoded_inputs[:train_data_size],
                      encoded_targets[:train_data_size])
    eval_pairs = zip(encoded_inputs[train_data_size:],
                     encoded_targets[train_data_size:])
    return train_pairs, eval_pairs
def learn_multiplication(length_list):
    """Train one RNN over a sweep of input-window lengths and score each.

    For every window length a fresh train/eval split is generated with
    ``construct_data``. Each training sample keeps only its first
    ``input_window`` digits (the remainder is zero-filled) and is added to a
    dataset shared by all windows, so the dataset keeps growing across the
    sweep. After the samples are added, ``trainer.train()`` is called once.
    Evaluation feeds the *untruncated* inputs to the network and accumulates
    the absolute difference between the decoded target and decoded prediction.

    Args:
        length_list: iterable of window lengths (leading digits kept).

    Returns:
        A list with one entry per window: the summed absolute error divided
        by the evaluation-set size. The value printed per window is the
        undivided sum.
    """
    digit_length = 100
    rnn = construct_rnn(digit_length, 10, digit_length)
    data = SupervisedDataSet(digit_length, digit_length)
    trainer = BackpropTrainer(rnn, data)
    total_data_size = 100
    train_data_size = 90
    eval_data_size = total_data_size - train_data_size
    error = []
    for input_window in length_list:
        sdr_train_list, sdr_eval_list = construct_data(
            total_data_size, train_data_size, digit_length)

        # train
        # The zero padding depends only on the window, so build it once per
        # window instead of once per sample.
        padding = [0] * (digit_length - input_window)
        for sample_in, sample_out in sdr_train_list:
            # Named so the builtin ``input`` is not shadowed.
            windowed_input = sample_in[:input_window] + padding
            windowed_output = sample_out[:input_window] + padding
            data.addSample(windowed_input, windowed_output)
        trainer.train()

        # eval score
        rnn.reset()
        err = 0.0
        for sample_in, sample_out in sdr_eval_list:
            predicted = rnn.activate(sample_in)
            err += abs(to_real_number(sample_out) - to_real_number(predicted))
        error.append(err / eval_data_size)
        print("length:", input_window, ",\t error:", err)
    return error
def plot_error(x, y):
    """Plot error against input-window length, save it as a PDF and show it.

    The figure is written to ``error_<unix timestamp>.pdf`` in the current
    working directory and then displayed.

    Args:
        x: window lengths (horizontal axis).
        y: error values (vertical axis), one per entry of ``x``.
    """
    import pylab
    import time

    pylab.plot(x, y)
    args = {'fontsize': '20'}
    pylab.xlabel(r'Input Window Length', args)
    pylab.ylabel(r'Error', args)
    # Save before show(): with interactive backends the figure is released
    # once the window is closed, so saving afterwards writes an empty page.
    pylab.savefig("error_" + str(time.time()) + ".pdf")
    pylab.show()
if __name__ == "__main__":
    # Sweep window lengths 1, 11, ..., 81, report the errors and plot them.
    windows = range(1, 90, 10)
    errors = learn_multiplication(windows)
    print(errors)
    plot_error(windows, errors)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# standard library
import random | |
def to_real_number(xs, size=None):
    """Decode a digit sequence into the real number it represents.

    Digit ``xs[i]`` carries the weight ``2 ** -(i + 1)``, i.e. the sequence is
    read as the fractional digits of a base-2 expansion.

    Args:
        xs: indexable sequence of numeric digits.
        size: number of leading digits to decode; defaults to ``len(xs)``.

    Returns:
        The weighted sum of the first ``size`` digits (0.0 when ``size`` is 0).
    """
    size = len(xs) if size is None else size
    total = 0.0
    # Plain left-to-right accumulation is kept deliberately so the rounding
    # matches a digit-by-digit evaluation exactly.
    for position in range(1, size + 1):
        total += xs[position - 1] * 2 ** -position
    return total
def is_consistent(x, xs, size):
    """Tell whether the first ``size`` digits of ``xs`` approximate ``x``.

    Args:
        x: target real number.
        xs: digit sequence accepted by ``to_real_number``.
        size: number of leading digits to take into account.

    Returns:
        True when the value decoded from the first ``size`` digits differs
        from ``x`` by strictly less than ``2 ** -size``.
    """
    tolerance = 2 ** -size
    return abs(x - to_real_number(xs, size)) < tolerance
def to_sdr(x, length=100):
    """Encode ``x`` as a random signed-digit sequence of ``length`` digits.

    Digits are drawn from {-1, 0, 1}. At every position the digit is picked
    uniformly at random among those for which the prefix built so far still
    satisfies ``is_consistent``, so the same ``x`` can yield different
    encodings on different calls.

    Args:
        x: real number to encode.
        length: number of digits to produce.

    Returns:
        A list of ``length`` digits, each -1, 0 or 1.

    Raises:
        ValueError: if at some position no digit keeps the prefix consistent
            with ``x``. Retrying random digits in that situation would never
            terminate, so the failure is reported instead.
    """
    sdr = [0] * length
    for i in range(length):
        # Collect every digit that keeps the prefix consistent, then choose
        # among them; this gives the same distribution as redrawing until a
        # digit fits, but is guaranteed to finish.
        candidates = []
        for digit in (-1, 0, 1):
            sdr[i] = digit
            if is_consistent(x, sdr, i + 1):
                candidates.append(digit)
        if not candidates:
            raise ValueError(
                "no signed digit at position %d is consistent with %r" % (i, x))
        sdr[i] = random.choice(candidates)
    return sdr
# Round-trip check: real number -> signed-digit representation -> real number
if __name__ == "__main__":
    sample = random.uniform(0, 1./3.)
    sample_digits = to_sdr(sample)
    print("raw:\t", sample)
    print("result:\t", to_real_number(sample_digits))
    print("digits:\t", "".join(str(a) for a in sample_digits))
    # Every randomly drawn value must survive the encode/decode round trip.
    for _ in range(999):
        expected = random.uniform(0, 1./3.)
        decoded = to_real_number(to_sdr(expected))
        assert expected == decoded
    print("test passed")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment