This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_H_n(X):
    """Return the slice of ``X`` at the last position of the time dimension.

    Axis 1 is treated as the time dimension, as the original comment states.
    The leading (batch) axis and trailing axis are kept whole.
    """
    return X[:, -1, :]  # -1 on axis 1 = last element of the time dim
| def build_model(options, verbose=False): | |
| model = Graph() | |
| k = 2 * options.lstm_units | |
| L = options.xmaxlen | |
| N = options.xmaxlen + options.ymaxlen + 1 # for delim | |
| print("x len", L, "total len", N) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_H_n(X):
    """Pick out the final time step from ``X``.

    Axis 1 is the time dimension (per the original comment). Batch and
    feature axes pass through unchanged.
    """
    final_step = X[:, -1, :]  # last element along the time axis
    return final_step
def get_Y(X, xmaxlen=110):
    """Return the first ``xmaxlen`` elements of ``X`` along the time dimension.

    The original hard-coded the cut-off as 110, even though it was documented
    as "first xmaxlen elem" and the surrounding model code uses
    ``options.xmaxlen``. The length is now a parameter. The default of 110
    keeps the previous behaviour for callers that pass only ``X``.

    NOTE(review): callers that need a different premise length can bind it,
    e.g. ``functools.partial(get_Y, xmaxlen=options.xmaxlen)``. Confirm
    against the call site.

    Args:
        X: array/tensor indexed as ``X[:, t, :]``, with axis 1 as time.
        xmaxlen: number of leading time steps to keep.

    Returns:
        ``X[:, :xmaxlen, :]``.
    """
    return X[:, :xmaxlen, :]
| def get_R(X): | |
| Y, alpha = X.values() # Y should be (L,k) and alpha should be (L,) and ans should be (k,) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| model = Graph() | |
| k = 2 * options.lstm_units | |
| L = options.xmaxlen; | |
| model.add_input(name='inputx', input_shape=(options.xmaxlen,), dtype=int) | |
| model.add_input(name='inputy', input_shape=(options.ymaxlen,), dtype=int) | |
| model.add_node(Embedding(options.max_features, options.wx_emb, input_length=options.xmaxlen), name='x_emb', | |
| input='inputx') | |
| model.add_node(Embedding(options.max_features, options.wy_emb, input_length=options.ymaxlen), name='y_emb', | |
| input='inputy') | |
| model.add_node(LSTM(options.lstm_units, return_sequences=True), name='forward', inputs=['x_emb', 'y_emb'], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def build_model(options): | |
| model = Graph() | |
| k=2*options.lstm_units | |
| L=int(0.5*options.maxlen); | |
| model.add_input(name='inputx', input_shape=(options.maxlen,), dtype=int) | |
| model.add_input(name='inputy', input_shape=(options.maxlen,), dtype=int) | |
| model.add_node(Embedding(options.max_features, options.w_emb, input_length=options.maxlen), name='x_emb', input='inputx') | |
| model.add_node(Embedding(options.max_features, options.w_emb, input_length=options.maxlen), name='y_emb', input='inputy') | |
| model.add_node(LSTM(options.lstm_units, return_sequences=True), name='forward', inputs=['x_emb','y_emb']) | |
| model.add_node(LSTM(options.lstm_units, return_sequences=True, go_backwards=True), name='backward', inputs=['x_emb','y_emb']) |