Created
June 2, 2022 18:36
from keras.layers import Bidirectional, CuDNNLSTM, Dense, Input
from keras.models import Model

# Attention and matthews_correlation are custom helpers defined elsewhere (not shown in this snippet)


def model_lstm(input_shape):
    # The shape was explained above and must have this order: (time steps, features) per sample
    inp = Input(shape=(input_shape[1], input_shape[2],))
    # This is the LSTM layer
    # Bidirectional means that the 160 chunks are processed in both directions, 0 to 159 and 159 to 0.
    # Although it appears that only the 0-to-159 direction matters, I tested with and without, and the latter worked best.
    # 128 and 64 are the numbers of cells used; too many can overfit and too few can underfit
    x = Bidirectional(CuDNNLSTM(128, return_sequences=True))(inp)
    # x = Activation('relu')(x)
    # x = Dropout(0.25)(x)
    # x = BatchNormalization()(x)
    # The second LSTM can give the model more firepower, but it can also make it overfit
    x = Bidirectional(CuDNNLSTM(64, return_sequences=True))(x)
    # x = Activation('relu')(x)
    # x = Dropout(0.25)(x)
    # x = BatchNormalization()(x)
    # Attention is a technique that can be applied to a recurrent NN to give more weight to a signal found in the middle
    # of the data; it helps most with long sequences. A plain RNN leaves all the responsibility for detecting the signal
    # to the last cell. Search for "RNN attention" for more information :)
    # x = Dropout(0.2)(x)
    # x = BatchNormalization()(x)
    x = Attention(input_shape[1])(x)
    # An intermediate fully connected (Dense) layer can help to deal with nonlinear outputs
    x = Dense(64, activation="relu")(x)
    # A binary classification like this one must end with an output of shape (1,)
    x = Dense(1, activation="sigmoid")(x)
    model = Model(inputs=inp, outputs=x)
    # Note the matthews_correlation metric added at compile time; it is a key success factor
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[matthews_correlation])
    return model
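
The `matthews_correlation` metric passed to `compile` is not defined in this snippet. As a rough illustration of what such a metric measures, here is a plain-Python sketch of the Matthews correlation coefficient for binary labels. The name `mcc_binary` and the 0.5 threshold are assumptions made for this example; the metric actually used with the model presumably works on Keras tensors rather than Python lists.

```python
import math


def mcc_binary(y_true, y_prob, threshold=0.5):
    """Matthews correlation coefficient for 0/1 labels (hypothetical helper).

    y_true: iterable of 0/1 ground-truth labels
    y_prob: iterable of predicted probabilities (e.g. sigmoid outputs)
    threshold: assumed cut-off for turning probabilities into 0/1 predictions
    """
    y_pred = [1 if p >= threshold else 0 for p in y_prob]

    # Count the four cells of the confusion matrix
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        # Common convention: return 0 when any row or column of the matrix is empty
        return 0.0
    return (tp * tn - fp * fn) / denom
```

The score ranges from -1 (every prediction wrong) through 0 (no better than chance) to 1 (every prediction right). Unlike accuracy, it stays informative when one class is much rarer than the other, which is a plausible reason for tracking it during training.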