import tensorflow as tf
from tensorflow.keras.layers import (BatchNormalization, Dense, Dropout,
                                     GlobalAveragePooling1D, GlobalMaxPooling1D,
                                     concatenate)


class LanguageClassifier(tf.train.Checkpoint):
    """Classifier head stacked on top of a pretrained language-model encoder."""

    def __init__(self, language_module, num_labels, dense_units=(128, 128), dropouts=(0.1, 0.1)):
        super(LanguageClassifier, self).__init__()
        self.language_module = language_module
        self.model_encoder = language_module.model
        # Classifier head layers.
        self.dense_layers = [Dense(units, activation="relu") for units in dense_units]
        self.dropout_layers = [Dropout(p) for p in dropouts]
        self.max_pool_layer = GlobalMaxPooling1D()
        self.average_pool_layer = GlobalAveragePooling1D()
        self.batchnorm_layer = BatchNormalization()
        self.n_layers = len(self.dense_layers)
        self.final_layer = Dense(num_labels, activation="sigmoid")

    # @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
    def __call__(self, sentences):
        # tokens, lookup_ids = self.language_module._tokens_to_lookup_ids(sentences)
        enc_out = self.language_module.get_encoder_output(sentences)
        # Concat pooling as in ULMFiT: the last hidden state concatenated with
        # max- and average-pooled encoder outputs.
        last_h = enc_out[:, -1, :]
        max_pool_output = self.max_pool_layer(enc_out)
        average_pool_output = self.average_pool_layer(enc_out)
        output = concatenate([last_h, max_pool_output, average_pool_output])
        for i in range(self.n_layers):
            output = self.dense_layers[i](output)
            output = self.dropout_layers[i](output)
        output = self.batchnorm_layer(output)
        final_output = self.final_layer(output)
        return final_output
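The classifier works directly from raw strings: it asks the language module for encoder outputs and never touches tokenization itself. Below is a minimal usage sketch, not part of the gist; the vocabulary, sentences, labels, and hyperparameters are invented, and it assumes the ULMFiTModule defined further down has been constructed (and ideally pretrained with its train method). Since the final layer uses a sigmoid, binary cross-entropy is a natural loss choice here.

# Illustrative only: tiny made-up vocab, two sentences, binary labels.
lm_module = ULMFiTModule(vocab=["the", "cat", "sat", "<E>"], emb_dim=64,
                         buckets=10, state_size=128, n_layers=2)
classifier = LanguageClassifier(lm_module, num_labels=1)

sentences = tf.constant(["the cat sat", "the cat"])
labels = tf.constant([[1.0], [0.0]])

bce = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()
with tf.GradientTape() as tape:
    probs = classifier(sentences)      # [batch, num_labels], sigmoid outputs
    loss = bce(labels, probs)
# Mirror the gist's train() pattern: grab the watched variables, then apply gradients.
watched = tape.watched_variables()
gradients = tape.gradient(loss, watched)
optimizer.apply_gradients(zip(gradients, watched))

The LanguageModelEncoder and ULMFiTModule that the classifier wraps follow below.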
import os
import tempfile

import tensorflow as tf
from tensorflow.keras.layers import LSTM
from tensorflow.python.ops import lookup_ops
from tensorflow.python.training.tracking import tracking


class LanguageModelEncoder(tf.train.Checkpoint):
    """Stacked LSTM encoder mapping token embeddings to per-token hidden states."""

    def __init__(self, vocab_size, emb_dim, state_size, n_layers):
        super(LanguageModelEncoder, self).__init__()
        self._state_size = state_size
        self._lstm_layers = [LSTM(self._state_size, return_sequences=True) for _ in range(n_layers)]

    # @tf.function(input_signature=[tf.TensorSpec([None, None, None], tf.dtypes.float32)])
    def __call__(self, sentence_embeddings):
        lstm_output = sentence_embeddings  # initialize to the input
        for lstm_layer in self._lstm_layers:
            lstm_output = lstm_layer(lstm_output)
        return lstm_output


def write_vocabulary_file(vocabulary):
    """Write temporary vocab file for module construction."""
    tmpdir = tempfile.mkdtemp()
    vocabulary_file = os.path.join(tmpdir, "tokens.txt")
    with tf.io.gfile.GFile(vocabulary_file, "w") as f:
        for entry in vocabulary:
            f.write(entry + "\n")
    return vocabulary_file


class ULMFiTModule(tf.train.Checkpoint):
    """Trains a language model on the given sentences."""

    def __init__(self, vocab, emb_dim, buckets, state_size, n_layers):
        super(ULMFiTModule, self).__init__()
        self._buckets = buckets
        self._vocab_size = len(vocab)
        # One embedding row per vocabulary entry plus one per OOV hash bucket.
        self.emb_row_size = self._vocab_size + self._buckets
        self._embeddings = tf.Variable(
            tf.random.uniform(shape=[self.emb_row_size, emb_dim]))
        self._state_size = state_size
        self.model = LanguageModelEncoder(self.emb_row_size, emb_dim, state_size, n_layers)
        # Track the vocabulary file as an asset so it is exported with the module.
        self._vocabulary_file = tracking.TrackableAsset(write_vocabulary_file(vocab))
        self.w2i_table = lookup_ops.index_table_from_file(
            vocabulary_file=self._vocabulary_file,
            num_oov_buckets=self._buckets,
            hasher_spec=lookup_ops.FastHashSpec)
        self.i2w_table = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=self._vocabulary_file,
            default_value="UNKNOWN")
        self._logit_layer = tf.keras.layers.Dense(self.emb_row_size)
        self.optimizer = tf.keras.optimizers.Adam()

    def _tokenize(self, sentences):
        # Perform a minimalistic text preprocessing by removing punctuation and
        # splitting on spaces.
        normalized_sentences = tf.strings.regex_replace(
            input=sentences, pattern=r"\pP", rewrite="")
        sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
        # Deal with a corner case: there is one empty sentence.
        sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
        # Deal with a corner case: all sentences are empty.
        sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
        return (sparse_tokens.indices, sparse_tokens.values,
                sparse_tokens.dense_shape)

    def _indices_to_words(self, indices):
        # return tf.gather(self._vocab_tensor, indices)
        return self.i2w_table.lookup(indices)

    def _words_to_indices(self, words):
        # return tf.strings.to_hash_bucket(words, self._buckets)
        return self.w2i_table.lookup(words)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
    def _tokens_to_lookup_ids(self, sentences):
        # Tokenize, then map each token string to its integer id; padding
        # positions in the dense tensors become "" and id 0 respectively.
        token_ids, token_values, token_dense_shape = self._tokenize(sentences)
        tokens_sparse = tf.sparse.SparseTensor(
            indices=token_ids, values=token_values, dense_shape=token_dense_shape)
        tokens = tf.sparse.to_dense(tokens_sparse, default_value="")
        sparse_lookup_ids = tf.sparse.SparseTensor(
            indices=tokens_sparse.indices,
            values=self._words_to_indices(tokens_sparse.values),
            dense_shape=tokens_sparse.dense_shape)
        lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)
        return tokens, lookup_ids

    @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
    def train(self, sentences):
        tokens, lookup_ids = self._tokens_to_lookup_ids(sentences)
        # Targets are the next word for each word of the sentence.
        tokens_ids_seq = lookup_ids[:, 0:-1]
        tokens_ids_target = lookup_ids[:, 1:]
        tokens_prefix = tokens[:, 0:-1]
        # Mask determining which positions we care about for a loss: all positions
        # that have a valid non-terminal token.
        mask = tf.logical_and(
            tf.logical_not(tf.equal(tokens_prefix, "")),
            tf.logical_not(tf.equal(tokens_prefix, "<E>")))
        input_mask = tf.cast(mask, tf.int32)
        with tf.GradientTape() as t:
            sentence_embeddings = tf.nn.embedding_lookup(self._embeddings, tokens_ids_seq)
            lstm_output = self.model(sentence_embeddings)
            lstm_output = tf.reshape(lstm_output, [-1, self._state_size])
            logits = self._logit_layer(lstm_output)
            targets = tf.reshape(tokens_ids_target, [-1])
            weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=targets, logits=logits)
            # Final loss is the mean loss over all masked-in token positions.
            final_loss = tf.math.divide(
                tf.reduce_sum(tf.multiply(losses, weights)),
                tf.reduce_sum(weights),
                name="final_loss")
        # Compute gradients with respect to every variable the tape saw and
        # let the optimizer apply them.
        watched = t.watched_variables()
        gradients = t.gradient(final_loss, watched)
        self.optimizer.apply_gradients(zip(gradients, watched))
        return final_loss

    @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
    def get_encoder_output(self, sentences):
        # Embed the lookup ids first; the encoder consumes embedding vectors,
        # not raw integer ids.
        tokens, lookup_ids = self._tokens_to_lookup_ids(sentences)
        sentence_embeddings = tf.nn.embedding_lookup(self._embeddings, lookup_ids)
        lstm_output = self.model(sentence_embeddings)
        return lstm_output

    @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
    def validate(self, sentences):
        tokens, lookup_ids = self._tokens_to_lookup_ids(sentences)
        # Targets are the next word for each word of the sentence.
        tokens_ids_seq = lookup_ids[:, 0:-1]
        tokens_ids_target = lookup_ids[:, 1:]
        tokens_prefix = tokens[:, 0:-1]
        # Mask determining which positions we care about for a loss: all positions
        # that have a valid non-terminal token.
        mask = tf.logical_and(
            tf.logical_not(tf.equal(tokens_prefix, "")),
            tf.logical_not(tf.equal(tokens_prefix, "<E>")))
        input_mask = tf.cast(mask, tf.int32)
        # Same forward pass as train(), but with no gradient tape or update.
        sentence_embeddings = tf.nn.embedding_lookup(self._embeddings, tokens_ids_seq)
        lstm_output = self.model(sentence_embeddings)
        lstm_output = tf.reshape(lstm_output, [-1, self._state_size])
        logits = self._logit_layer(lstm_output)
        targets = tf.reshape(tokens_ids_target, [-1])
        weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=targets, logits=logits)
        # Final loss is the mean loss over all masked-in token positions.
        final_loss = tf.math.divide(
            tf.reduce_sum(tf.multiply(losses, weights)),
            tf.reduce_sum(weights),
            name="final_validation_loss")
        return final_loss

    @tf.function
    def decode_greedy(self, sequence_length, first_word):
        sequence = [first_word]
        current_word = first_word
        current_id = tf.expand_dims(self._words_to_indices(current_word), 0)
        for _ in range(sequence_length):
            # Embed the current token id before feeding it to the encoder.
            embedded = tf.nn.embedding_lookup(
                self._embeddings, tf.expand_dims(current_id, 0))
            lstm_output = self.model(embedded)
            lstm_output = tf.reshape(lstm_output, [-1, self._state_size])
            logits = self._logit_layer(lstm_output)
            softmax = tf.nn.softmax(logits)
            # Greedy decoding: take the single most likely next token.
            next_ids = tf.math.argmax(softmax, axis=1)
            next_words = self._indices_to_words(next_ids)[0]
            current_id = next_ids
            current_word = next_words
            sequence.append(current_word)
        return sequence
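For completeness, here is a sketch of how the language-model side might be driven end to end. It is not part of the gist; the corpus, vocabulary, and hyperparameters are invented purely for illustration. train() and validate() each take a batch of raw sentences, and decode_greedy() generates from a seed word.

# Illustrative pretraining loop with a made-up corpus and settings.
corpus = tf.constant(["the cat sat on the mat <E>", "the dog sat <E>"])
vocab = ["the", "cat", "dog", "sat", "on", "mat", "<E>"]

lm = ULMFiTModule(vocab=vocab, emb_dim=64, buckets=10, state_size=128, n_layers=2)
for epoch in range(10):
    train_loss = lm.train(corpus)
    val_loss = lm.validate(corpus)   # same sentences here, just for brevity
    print("epoch", epoch, "train", float(train_loss), "val", float(val_loss))

# Greedy decoding from a seed word; returns a list of scalar string tensors.
generated = lm.decode_greedy(sequence_length=5, first_word=tf.constant("the"))
print([w.numpy().decode() for w in generated])

The same pretrained module can then be handed to the LanguageClassifier above for fine-tuning on a labeled task.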