Skip to content

Instantly share code, notes, and snippets.

@shyamupa
Created February 21, 2016 18:43
Show Gist options
  • Select an option

  • Save shyamupa/47a5a9301be6ad0eca5b to your computer and use it in GitHub Desktop.

Select an option

Save shyamupa/47a5a9301be6ad0eca5b to your computer and use it in GitHub Desktop.
def build_model(options, plot_file='model.png'):
    """Build (but do not compile) a two-input LSTM + attention Graph model.

    Two integer token sequences, ``inputx`` and ``inputy``, are embedded,
    encoded by a forward and a backward LSTM, and passed through the
    attention-style nodes ``h_n`` / ``Y`` / ``M`` / ``alpha`` / ``r``.
    ``model.compile`` is not called here; compiling is left to the caller.

    Args:
        options: configuration object; the attributes read here are
            ``maxlen``, ``max_features``, ``w_emb`` and ``lstm_units``.
        plot_file: path the model diagram is written to via ``plot``.
            Defaults to ``'model.png'`` (the previously hard-coded path);
            pass ``None`` to skip plotting, e.g. when graphviz is unavailable.

    Returns:
        The uncompiled ``Graph`` model.
    """
    # Feature width of the forward/backward pair merged in 'dropout'.
    k = 2 * options.lstm_units
    # Number of timesteps declared for Y and used to repeat h_n. Computed once
    # so 'Y' and 'H_n_cross_e' cannot drift apart.
    # NOTE(review): half of maxlen is assumed here; confirm it matches get_Y.
    L = int(0.5 * options.maxlen)

    model = Graph()

    # Integer token-id inputs of length maxlen, one per sequence.
    model.add_input(name='inputx', input_shape=(options.maxlen,), dtype='int')
    model.add_input(name='inputy', input_shape=(options.maxlen,), dtype='int')
    model.add_node(Embedding(options.max_features, options.w_emb,
                             input_length=options.maxlen),
                   name='x_emb', input='inputx')
    model.add_node(Embedding(options.max_features, options.w_emb,
                             input_length=options.maxlen),
                   name='y_emb', input='inputy')

    # Forward and backward encoders over both embeddings.
    # NOTE(review): with Graph's default merge_mode (presumably 'concat' on
    # the last axis) x_emb and y_emb are joined feature-wise per timestep,
    # not one sequence after the other. go_backwards=True may also emit the
    # backward outputs in reversed time order before they are merged with
    # 'forward'. Confirm both are intended.
    model.add_node(LSTM(options.lstm_units, return_sequences=True),
                   name='forward', inputs=['x_emb', 'y_emb'])
    model.add_node(LSTM(options.lstm_units, return_sequences=True,
                        go_backwards=True),
                   name='backward', inputs=['x_emb', 'y_emb'])
    model.add_node(Dropout(0.5), name='dropout',
                   inputs=['forward', 'backward'])

    # Split the encoder output with the helper lambdas. Per the declared
    # output shapes, h_n is a k-dim vector and Y is an (L, k) sequence.
    model.add_node(Lambda(get_H_n, output_shape=(k,)), name='h_n',
                   input='dropout')
    model.add_node(Lambda(get_Y, output_shape=(L, k)), name='Y',
                   input='dropout')
    # NOTE(review): 'Y_t' is not consumed by any later node in this builder.
    model.add_node(Permute((2, 1)), name='Y_t', input='Y')

    # Attention: project Y and the L-times-repeated h_n, then combine them.
    model.add_node(RepeatVector(L), name='H_n_cross_e', input='h_n')
    model.add_node(TimeDistributedDense(options.lstm_units), name='WY',
                   input='Y')
    model.add_node(TimeDistributedDense(options.lstm_units), name='Whn',
                   input='H_n_cross_e')
    # NOTE(review): the two inputs are merged with the default merge_mode
    # before a tanh dense layer. An additive form, tanh(W^y Y + W^h h_n),
    # would need merge_mode='sum' and a plain tanh activation. Confirm intent.
    model.add_node(TimeDistributedDense(options.lstm_units, activation='tanh'),
                   name='M', inputs=['Whn', 'WY'])
    model.add_node(Permute((2, 1)), name='M_perm', input='M')
    model.add_node(TimeDistributedMerge(), name='alpha', input='M_perm')
    model.add_node(LambdaMerge(get_R), name='r', inputs=['Y', 'alpha'])
    # NOTE(review): no add_output is defined yet, so 'r' is not exposed as a
    # model output. Add one before compiling.

    # Keras' summary() prints the table itself and returns None; the former
    # print(model.summary()) emitted a stray "None" line.
    model.summary()
    if plot_file is not None:
        plot(model, plot_file)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment