BERT create model
import tensorflow as tf
import tensorflow_hub as hub

# BERT_MODEL_HUB is assumed to be defined elsewhere as the TF Hub URL of the
# BERT module being fine-tuned.

def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_output" for token-level output.
  output_layer = bert_outputs["pooled_output"]

  hidden_size = output_layer.shape[-1].value

  # Create our own layer to tune for politeness data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

    # Dropout helps prevent overfitting.
    output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Convert labels into one-hot encoding.
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(
        tf.argmax(log_probs, axis=-1, output_type=tf.int32))

    # If we're predicting, we want the predicted labels and the probabilities.
    if is_predicting:
      return (predicted_labels, log_probs)

    # If we're training/evaluating, compute loss between predicted and actual label.
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs)
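
For context, here is a minimal sketch of how create_model might be wired into a TF 1.x Estimator model_fn. The feature keys, the num_labels parameter, the plain AdamOptimizer, and the learning rate are illustrative assumptions for this sketch only; the gist itself defines only create_model.

def model_fn(features, labels, mode, params):
  """Sketch of an Estimator model_fn that calls create_model (assumed setup)."""
  is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

  if is_predicting:
    predicted_labels, log_probs = create_model(
        is_predicting, features["input_ids"], features["input_mask"],
        features["segment_ids"], labels, params["num_labels"])
    predictions = {"labels": predicted_labels, "log_probs": log_probs}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  loss, predicted_labels, log_probs = create_model(
      is_predicting, features["input_ids"], features["input_mask"],
      features["segment_ids"], labels, params["num_labels"])

  # Illustrative optimizer choice; a real fine-tuning setup may use a
  # warmup/decay schedule instead.
  train_op = tf.train.AdamOptimizer(learning_rate=2e-5).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)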