Skip to content

Instantly share code, notes, and snippets.

@shyamupa
Created February 24, 2016 16:32
Show Gist options
  • Select an option

  • Save shyamupa/90a9050a6f35323c10df to your computer and use it in GitHub Desktop.

Select an option

Save shyamupa/90a9050a6f35323c10df to your computer and use it in GitHub Desktop.
model = Graph()
k = 2 * options.lstm_units
L = options.xmaxlen;
model.add_input(name='inputx', input_shape=(options.xmaxlen,), dtype=int)
model.add_input(name='inputy', input_shape=(options.ymaxlen,), dtype=int)
model.add_node(Embedding(options.max_features, options.wx_emb, input_length=options.xmaxlen), name='x_emb',
input='inputx')
model.add_node(Embedding(options.max_features, options.wy_emb, input_length=options.ymaxlen), name='y_emb',
input='inputy')
model.add_node(LSTM(options.lstm_units, return_sequences=True), name='forward', inputs=['x_emb', 'y_emb'],
concat_axis=1)
model.add_node(LSTM(options.lstm_units, return_sequences=True, go_backwards=True), name='backward',
inputs=['x_emb', 'y_emb'], concat_axis=1)
model.add_node(Dropout(0.5), name='dropout', inputs=['forward', 'backward'])
model.add_node(Lambda(get_H_n, output_shape=(k,)), name='h_n', input='dropout')
model.add_node(Lambda(get_Y, output_shape=(L, k)), name='Y', input='dropout')
model.add_node(RepeatVector(L), name='H_n_cross_e', input='h_n')
model.add_node(TimeDistributedDense(k), name='WY', input='Y')
model.add_node(TimeDistributedDense(k), name='Whn', input='H_n_cross_e')
model.add_node(TimeDistributedDense(k, activation='tanh'), name='M', inputs=['Whn', 'WY'])
model.add_node(TimeDistributedDense(1), name='alpha', input='M')
model.add_node(Lambda(get_R, output_shape=(k,)), name='r', inputs=['Y','alpha'], merge_mode='join')
model.add_node(Dense(k), name='Wr', input='r');
model.add_node(Dense(k), name='Wh', input='h_n');
model.add_node(Dense(k, activation='tanh'), name='h_star', inputs=['Wr', 'Wh'], merge_mode='sum')
model.add_node(Dense(1, activation='softmax'), name='out', input='h_star')
print(model.summary())
model.add_output(name='output', input='out')
plot(model, 'model.png')
model.compile(loss={'output':'categorical_crossentropy'}, optimizer='rmsprop')
return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment