# Base image: TensorFlow 1.12 with Python 3
FROM tensorflow/tensorflow:1.12.0-py3

ENV LANG=C.UTF-8

# Copy the repository into the image and install its Python dependencies
RUN mkdir /gpt-2
WORKDIR /gpt-2
ADD . /gpt-2
RUN pip3 install -r requirements.txt
def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model on top of the pre-trained BERT module."""
  # Load the pre-trained BERT module from TF Hub and make it fine-tunable.
  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  # Pack the three inputs into the dictionary format the module expects.
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
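The gist is cut off after the inputs are assembled. As a rough sketch, the function would typically continue by running the module and attaching a classification head; the "tokens" signature and "pooled_output" key follow the TF Hub BERT v1 interface, while the output layer below is illustrative rather than the article's exact code.

  # Sketch (assumption): run BERT and attach a simple classification head.
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # "pooled_output" is the fixed-size embedding of the whole sequence ([CLS]).
  output_layer = bert_outputs["pooled_output"]
  hidden_size = output_layer.shape[-1].value

  # Illustrative output layer mapping the pooled embedding to num_labels logits.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())
  logits = tf.matmul(output_layer, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)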
# This is a path to an uncased (all lowercase) version of BERT
BERT_MODEL_HUB = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

def create_tokenizer_from_hub_module():
  """Get the vocab file and casing info from the Hub module."""
  with tf.Graph().as_default():
    bert_module = hub.Module(BERT_MODEL_HUB)
    tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
    with tf.Session() as sess:
      vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                            tokenization_info["do_lower_case"]])

  return bert.tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
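A quick way to sanity-check the tokenizer this helper returns (the sample sentence is only an illustration):

tokenizer = create_tokenizer_from_hub_module()
print(tokenizer.tokenize("This is an example sentence for the BERT tokenizer"))
# Expect lowercased WordPiece tokens such as ['this', 'is', 'an', ...]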
# Wrap each row of the training and validation DataFrames in a BERT InputExample.
# guid is unused here, and text_b is None because each example is a single text.
train_InputExamples = train.apply(lambda x: bert.run_classifier.InputExample(guid=None,
                                                                             text_a=x[DATA_COLUMN],
                                                                             text_b=None,
                                                                             label=x[LABEL_COLUMN]),
                                  axis=1)

val_InputExamples = val.apply(lambda x: bert.run_classifier.InputExample(guid=None,
                                                                         text_a=x[DATA_COLUMN],
                                                                         text_b=None,
                                                                         label=x[LABEL_COLUMN]),
                              axis=1)
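Before training, these InputExamples still have to be converted into the fixed-length features BERT consumes. A minimal sketch of that step, assuming the tokenizer from create_tokenizer_from_hub_module() above; the MAX_SEQ_LENGTH value and label_list construction are assumptions, not the article's exact choices.

MAX_SEQ_LENGTH = 128  # illustrative maximum sequence length
label_list = list(train[LABEL_COLUMN].unique())

train_features = bert.run_classifier.convert_examples_to_features(
    train_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)
val_features = bert.run_classifier.convert_examples_to_features(
    val_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)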
# Train the model for each combination of the hyperparameters.
x_train = X_train
x_test, y_test = X_val, y_val

# A unique number for each training session
session_num = 0

# Nested for-loops train with every possible combination of hyperparameters
for num_units in HP_NUM_UNITS.domain.values:
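The gist is cut off at the outer loop. A sketch of how the nested sweep typically continues, using the run() helper shown in the next snippet; the outer loop header is repeated for completeness, and HP_DROPOUT and HP_OPTIMIZER are assumed additional hp.HParam objects that do not appear in the original.

for num_units in HP_NUM_UNITS.domain.values:
    for dropout_rate in (HP_DROPOUT.domain.min_value, HP_DROPOUT.domain.max_value):
        for optimizer in HP_OPTIMIZER.domain.values:
            # Collect this trial's hyperparameter values.
            hparams = {
                HP_NUM_UNITS: num_units,
                HP_DROPOUT: dropout_rate,
                HP_OPTIMIZER: optimizer,
            }
            run_name = "run-%d" % session_num
            print("--- Starting trial:", run_name)
            print({h.name: hparams[h] for h in hparams})
            # Train, evaluate, and log this trial under its own run directory.
            run("logs/hparam_tuning/" + run_name, hparams)
            session_num += 1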
# A function to log each training run to TensorBoard
def run(run_dir, hparams):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)  # record the hyperparameter values used in this trial
        rmse = train_test_model(hparams)
        tf.summary.scalar(METRIC_RMSE, rmse, step=10)
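For completeness, a sketch of how HP_NUM_UNITS and METRIC_RMSE might be declared with the TensorBoard HParams API; the concrete domain values and log directory are assumptions.

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Assumed hyperparameter domain and metric name; adjust to the actual experiment.
HP_NUM_UNITS = hp.HParam("num_units", hp.Discrete([16, 32, 64]))
METRIC_RMSE = "rmse"

# Register the hyperparameters and metrics so TensorBoard's HParams dashboard
# knows what to display for the runs logged by run().
with tf.summary.create_file_writer("logs/hparam_tuning").as_default():
    hp.hparams_config(
        hparams=[HP_NUM_UNITS],
        metrics=[hp.Metric(METRIC_RMSE, display_name="RMSE")],
    )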