SOS Keras model 2.1
import tensorflow as tf


def get_learned_scores(**kwargs):
    """
    Scores each sentence, then multiplies each sentence encoding by its score
    before the next sequence layer.

    :Keyword Arguments:
        * sent_len (int): sentence length (words per sentence)
        * embedding_size (int): word embedding length
        * num_sent (int): number of sentences per observation
        * vocab_size (int): vocabulary size (required when pre_embedded is False)
        * pre_embedded (bool): True if the input is already vectors of word
          embeddings, False if tokens still need to be embedded
        * model_type (str): 'attention', 'sos', or 'combined'
        * lstm_units_1 (int): units of the sentence-level BiLSTM
        * lstm_cells (int): units of the document-level BiLSTM
    """
    sent_len = kwargs.get('sent_len')
    embed_size = kwargs.get('embedding_size')
    sent_per_obs = kwargs.get('num_sent')
    pre_embedded = kwargs.get("pre_embedded", False)
    model_type = kwargs.get("model_type", 'attention')
    lstm_units_1 = kwargs.get('lstm_units_1', 16)
    lstm_units_2 = kwargs.get('lstm_cells', 16)

    if pre_embedded:
        # Input is a flat vector of pre-computed embeddings, reshaped to
        # (batch, seq_len, embedding_size); the dimensions are hard-coded here.
        inputs = tf.keras.layers.Input(shape=(None,), name="input")
        embedded = tf.reshape(inputs, (-1, 1200, 768))
    else:
        # Input is a flat sequence of token ids, embedded on the fly.
        inputs = tf.keras.layers.Input(shape=(None,), name="input")
        embedded = tf.keras.layers.Embedding(kwargs.get('vocab_size'), embed_size)(inputs)

    # Split the sequence into individual sentences: (batch * num_sent, sent_len, embed_size)
    reshaped = tf.reshape(embedded, (-1, sent_len, embed_size))

    # Sentence-level BiLSTM encodes each sentence into a fixed-size vector.
    lstm_level1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(lstm_units_1))(reshaped)

    # One logit per sentence, normalised with a softmax over the sentences of an observation.
    x = tf.keras.layers.Dense(1, activation=None)(lstm_level1)
    logits = tf.reshape(x, (-1, sent_per_obs))
    score = tf.keras.layers.Softmax(name="score")(logits)

    # Weight each sentence encoding by its score before the document-level layers.
    weighted = tf.multiply(lstm_level1, tf.reshape(score, (-1, 1)))
    reshaped_level2 = tf.reshape(weighted, (-1, sent_per_obs, lstm_units_1 * 2))

    w_average = tf.keras.layers.GlobalAveragePooling1D()(reshaped_level2)
    lstm_level2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(lstm_units_2))(reshaped_level2)

    if model_type == 'attention':
        outputs = tf.keras.layers.Dense(1, name="output")(w_average)
    elif model_type == 'sos':
        outputs = tf.keras.layers.Dense(1, name="output")(lstm_level2)
    elif model_type == "combined":
        classifier = tf.keras.layers.Dense(1)(lstm_level2)
        classifier2 = tf.keras.layers.Dense(1)(w_average)
        outputs = tf.keras.layers.concatenate([classifier, classifier2], name="output")
    else:
        raise ValueError(f"unexpected value for model_type: {model_type}")

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model
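Below is a minimal usage sketch, not part of the original gist: the hyperparameter values, batch size, loss choice, and random data are all illustrative assumptions, chosen only so the tensor shapes line up for the token-input ('attention') path.

import numpy as np
import tensorflow as tf

# Assumed hyperparameters for illustration only.
num_sent, sent_len, embed_size, vocab_size = 10, 20, 32, 5000

model = get_learned_scores(
    sent_len=sent_len,
    embedding_size=embed_size,
    num_sent=num_sent,
    vocab_size=vocab_size,
    model_type='attention',
)
# The output Dense layer has no activation, so train on logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

# One observation = num_sent * sent_len token ids flattened into a single sequence.
tokens = np.random.randint(0, vocab_size, size=(8, num_sent * sent_len))
labels = np.random.randint(0, 2, size=(8, 1))
model.fit(tokens, labels, epochs=1)

The learned sentence scores are exposed through the layer named "score", so a second tf.keras.Model built on the same inputs with that layer's output can be used to inspect which sentences the model weights most heavily.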