Setting up a BERT learning environment
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm
from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar
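# The google-research/bert repository is assumed to have been cloned into
# ./bert in an earlier cell (not shown in this gist), e.g.:
#   !git clone https://github.com/google-research/bert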
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder
auth.authenticate_user()
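# Note: `drive` is imported above but never mounted in this snippet; if later
# cells read data or write checkpoints to Google Drive, it would presumably be
# mounted first (hypothetical usage, not part of the original gist):
#   drive.mount('/content/drive')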
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)
# create a formatter and attach a single stream handler, replacing the
# default handlers so log messages are not duplicated in the Colab output
formatter = logging.Formatter('%(asctime)s : %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]
if 'COLAB_TPU_ADDR' in os.environ:
    log.info("Using TPU runtime")
    USE_TPU = True
    TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

    with tf.Session(TPU_ADDRESS) as session:
        log.info('TPU address is ' + TPU_ADDRESS)
        # Upload credentials to the TPU so it can read from and write to GCS.
        # adc.json is created by auth.authenticate_user() above.
        with open('/content/adc.json', 'r') as f:
            auth_info = json.load(f)
        tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
else:
    log.warning('Not connected to TPU runtime')
    USE_TPU = False
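
# A hedged sketch (an assumption, not part of the original gist): how USE_TPU
# and TPU_ADDRESS typically plug into the TF 1.x TPUEstimator driven by
# bert.run_pretraining. The bucket path, config, and hyperparameters below
# are placeholder values chosen for illustration only.
BUCKET_PATH = 'gs://my-bert-bucket/model'  # hypothetical bucket; TPUs require GCS paths
bert_config = modeling.BertConfig(vocab_size=32000)  # placeholder config

run_config = tf.contrib.tpu.RunConfig(
    master=TPU_ADDRESS if USE_TPU else None,
    model_dir=BUCKET_PATH,
    save_checkpoints_steps=2500,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,
        num_shards=8))

model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=None,
    learning_rate=2e-5,
    num_train_steps=100000,
    num_warmup_steps=10000,
    use_tpu=USE_TPU,
    use_one_hot_embeddings=USE_TPU)

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=USE_TPU,
    train_batch_size=128,
    eval_batch_size=64)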