Last active
May 9, 2019 17:55
-
-
Save gaphex/2e8b77e9ebed98b803f94685f2b507a8 to your computer and use it in GitHub Desktop.
Setting up a BERT learning environment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library.
import os
import sys
import json
import random
import logging
from glob import glob

# Third-party packages (expected to be available on the Colab runtime).
import nltk
import tensorflow as tf
import sentencepiece as spm
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

# Assumes the BERT reference code is checked out at ./bert — TODO confirm.
# The directory itself is appended to the import path, presumably so that the
# BERT scripts' own top-level sibling imports resolve; this must run before
# the `bert` imports below.
sys.path.append("bert")
from bert import modeling, optimization, tokenization  # noqa: E402
from bert.run_pretraining import input_fn_builder, model_fn_builder  # noqa: E402

# Authenticate the Colab user with Google Cloud. The TPU setup further down
# reads '/content/adc.json', which is presumably the credentials file produced
# by this call — verify on your runtime.
auth.authenticate_user()
# Configure TensorFlow's logger: INFO level, single stream handler,
# timestamped "<time> : <message>" lines.
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# Create the formatter and attach it to a stream handler.
formatter = logging.Formatter('%(asctime)s : %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)

# Assign (rather than addHandler) so any handlers already attached to the
# 'tensorflow' logger are replaced, leaving exactly this one.
log.handlers = [sh]
# Detect whether this runtime has a TPU attached. The COLAB_TPU_ADDR
# environment variable holds the TPU worker's host:port address; when present,
# connect to it over gRPC and hand it the user's Google Cloud credentials.
if 'COLAB_TPU_ADDR' in os.environ:
    log.info("Using TPU runtime")
    USE_TPU = True
    TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

    with tf.Session(TPU_ADDRESS) as session:
        log.info('TPU address is %s', TPU_ADDRESS)
        # Upload credentials to TPU. The path is presumably the credentials
        # file written by auth.authenticate_user() — confirm on your runtime.
        with open('/content/adc.json', 'r', encoding='utf-8') as f:
            auth_info = json.load(f)
        # Configure GCS access for this session using those credentials.
        tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
else:
    log.warning('Not connected to TPU runtime')
    USE_TPU = False
    # Define the same names in both branches so downstream code can reference
    # TPU_ADDRESS without a NameError when no TPU is attached.
    TPU_ADDRESS = None
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.