# GitHub gist stas-sl/40c6a06b795ca734638e, created August 26, 2015 13:31.
# (Save it to your computer and use it in GitHub Desktop.)
# GitHub notice: this file may contain bidirectional Unicode text that can be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import numpy as np
import theano.tensor as T
from blocks.algorithms import GradientDescent, Scale
from blocks.bricks import Linear, Logistic
from blocks.bricks.cost import BinaryCrossEntropy
from blocks.bricks.recurrent import LSTM
from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing
from blocks.extensions.monitoring import (DataStreamMonitoring,
                                         TrainingDataMonitoring)
from blocks.graph import ComputationGraph
from blocks.initialization import Constant, IsotropicGaussian
from blocks.main_loop import MainLoop
from blocks.monitoring import aggregation
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
# Length of every generated 0/1 sequence.
seq_length = 20
# Number of sequences per minibatch; MyDataset pre-batches with this size.
batch_size = 10
# Dimension of the LSTM state (passed as LSTM's ``dim``). The name is a
# misspelling of "lstm_dim", kept because the rest of the script uses it.
ltsm_dim = 10
class MyDataset(IterableDataset):
    """Random 0/1 sequences labelled with their parity, stored pre-batched.

    The dataset is built from a dict of two numpy arrays whose first axis
    indexes minibatches. Iterating a numpy array walks that first axis, so
    each item the stream produces is one whole minibatch:

    * ``x``: float32, shape ``(seq_length, batch_size, 1)``, the random bits
      with time as the leading axis;
    * ``y``: float32, shape ``(batch_size, 1)``, the parity (sum mod 2) of
      each sequence.

    Args:
        nb_examples: total number of sequences to generate. Only whole
            batches are kept, so ``nb_examples % batch_size`` trailing
            sequences are dropped.

    Raises:
        ValueError: if ``nb_examples`` is smaller than one batch.
    """

    def __init__(self, nb_examples):
        super(MyDataset, self).__init__(self.generate_data(nb_examples))

    def generate_data(self, nb_examples):
        """Return ``{'x': bits, 'y': parity}`` arrays split into batches."""
        # Floor division: randint needs an integer size. Under Python 3,
        # ``/`` produces a float and randint raises TypeError.
        nb_batches = nb_examples // batch_size
        if nb_batches == 0:
            raise ValueError(
                'nb_examples (%d) must be at least batch_size (%d)'
                % (nb_examples, batch_size))
        x = np.random.randint(0, 2, (nb_batches, seq_length, batch_size, 1))
        # Sum over the time axis (axis 1), then mod 2. This gives the parity
        # label with shape (nb_batches, batch_size, 1).
        y = x.sum(axis=1) % 2
        return {'x': x.astype('float32'), 'y': y.astype('float32')}
# --- Data ------------------------------------------------------------------
train_dataset = MyDataset(10000)
test_dataset = MyDataset(100)
stream = DataStream(dataset=train_dataset)
stream_test = DataStream(dataset=test_dataset)

# --- Model -----------------------------------------------------------------
# Symbolic inputs. The variable names must match the dataset keys 'x'/'y'.
x = T.tensor3('x')  # bits, (time, batch, 1) as produced by MyDataset
y = T.matrix('y')   # parity targets, (batch, 1)

# Project each 1-dimensional input step to 4 * ltsm_dim features. Blocks'
# LSTM brick takes its input already projected to four times its dimension.
x_to_h = Linear(name='x_to_h', input_dim=1,
                output_dim=ltsm_dim * 4, weights_init=IsotropicGaussian(0.01),
                biases_init=Constant(0.0))
x_transform = x_to_h.apply(x)
lstm = LSTM(ltsm_dim, weights_init=IsotropicGaussian(0.01),
            biases_init=Constant(0.0))
h, c = lstm.apply(x_transform)

# Predict from the hidden state at the last time step only (index -1 of the
# leading, time axis).
h_to_o = Linear(name='h_to_o', input_dim=ltsm_dim, output_dim=1,
                weights_init=IsotropicGaussian(0.01), biases_init=Constant(0.0))
y_hat = h_to_o.apply(h[-1])
y_hat2 = Logistic().apply(y_hat)  # probability that the parity is 1

# --- Cost ------------------------------------------------------------------
cost = BinaryCrossEntropy().apply(y, y_hat2)
# Name the cost so the monitoring extensions below can report it.
cost.name = 'cost'
cg = ComputationGraph(cost)

# Initialize every brick's parameters from its weights_init/biases_init.
lstm.initialize()
x_to_h.initialize()
h_to_o.initialize()

# --- Training --------------------------------------------------------------
algorithm = GradientDescent(cost=cost, parameters=cg.parameters,
                            step_rule=Scale(0.001))
test_monitor = DataStreamMonitoring(variables=[cost], data_stream=stream_test,
                                    prefix="test")
train_monitor = TrainingDataMonitoring(
    variables=[cost,
               aggregation.mean(algorithm.total_gradient_norm),
               aggregation.mean(algorithm.total_step_norm)],
    prefix="train", after_epoch=True)
main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
                     extensions=[Timing(), test_monitor, train_monitor,
                                 FinishAfter(after_n_epochs=10000), Printing(),
                                 ProgressBar()])
main_loop.run()
# GitHub: sign up for free to join this conversation.
# Already have an account? Sign in to comment.