Skip to content

Instantly share code, notes, and snippets.

@gaphex
Created May 9, 2019 16:40
Show Gist options
  • Select an option

  • Save gaphex/aa786b7c7323b7b1530db04fea7e4c5a to your computer and use it in GitHub Desktop.

Select an option

Save gaphex/aa786b7c7323b7b1530db04fea7e4c5a to your computer and use it in GitHub Desktop.
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}
# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8
if BUCKET_NAME:
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
BUCKET_PATH = "."
BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)
VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")
INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))
log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment