Skip to content

Instantly share code, notes, and snippets.

View gaphex's full-sized avatar

Denis gaphex

  • Moscow
View GitHub Profile
# SentencePiece training hyper-parameters (Colab "#@param" form fields).
# File prefix for the trained model: produces "tokenizer.model" / "tokenizer.vocab".
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
# Target size of the learned subword vocabulary.
VOC_SIZE = 32000 #@param {type:"integer"}
# Number of corpus sentences SentencePiece samples for training.
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
# NOTE(review): presumably reserves spare vocab slots for later additions —
# its use is not visible in this excerpt; confirm against the full notebook.
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}
# Assemble the sentencepiece trainer CLI flag string.
# bos/eos ids of -1 disable BOS/EOS tokens (BERT adds its own [CLS]/[SEP]).
# NOTE(review): the .format() argument list is cut off in this paste — the
# template references vocab_size and input_sentence_size but only two
# arguments are visible; confirm against the original notebook.
SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true '
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX,
def read_sentencepiece_vocab(filepath):
    """Read a SentencePiece ``.vocab`` file and return its pieces.

    Each line of the file is ``<piece>\\t<score>``; only the piece is kept.
    The first entry (the ``<unk>`` token) is dropped because the BERT vocab
    supplies its own control symbols (``[UNK]`` etc.) separately.

    Args:
        filepath: Path to the ``.vocab`` file produced by SentencePiece.

    Returns:
        List of vocabulary pieces in file order, without the leading token.
    """
    with open(filepath, encoding='utf-8') as fi:
        # Comprehension instead of a manual append loop; rstrip is not
        # needed since split("\t")[0] never includes the trailing newline
        # (every vocab line contains a tab before the score).
        voc = [line.split("\t")[0] for line in fi]
    # skip the first <unk> token
    return voc[1:]
# Load the pieces emitted by SentencePiece training (e.g. "tokenizer.vocab").
snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
def parse_sentencepiece_token(token):
    """Convert one SentencePiece piece to BERT WordPiece notation.

    SentencePiece marks word-*initial* pieces with a leading "\u2581",
    whereas WordPiece marks word-*internal* pieces with a "##" prefix.
    A word-initial piece therefore loses its marker; every other piece
    gains the "##" continuation prefix.
    """
    starts_word = token.startswith("▁")
    return token[1:] if starts_word else "##" + token
# Re-express every SentencePiece piece in BERT WordPiece notation, then
# put the standard BERT control symbols at the front of the vocabulary
# (so [PAD] gets id 0, as BERT expects).
bert_vocab = [parse_sentencepiece_token(piece) for piece in snt_vocab]
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab
# Output filename for the BERT-format vocabulary (one token per line).
VOC_FNAME = "vocab.txt" #@param {type:"string"}
# Write with explicit UTF-8: the vocabulary contains non-ASCII subword
# pieces, and the platform-default encoding (e.g. cp1252 on Windows) can
# fail on them. This also mirrors the utf-8 used when reading the
# SentencePiece vocab earlier in the file.
with open(VOC_FNAME, "w", encoding="utf-8") as fo:
    for token in bert_vocab:
        fo.write(token + "\n")
# Shard the pre-processed corpus into 256 000-line files named
# ./shards/shard_0000, shard_0001, ... for parallel pretraining-data
# generation. "-p" makes mkdir idempotent (no error if the directory
# exists); the path variable is quoted to survive spaces. The original
# paste repeated both commands — a copy artifact, deduplicated here.
mkdir -p ./shards
split -a 4 -l 256000 -d "$PRC_DATA_FPATH" ./shards/shard_
# Parameters for bert/create_pretraining_data.py (fed into XARGS_CMD below).
# Maximum token length of each training sequence (--max_seq_length).
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
# Fraction of tokens masked for the masked-LM objective (--masked_lm_prob).
MASKED_LM_PROB = 0.15 #@param
# Maximum masked-LM predictions per sequence (--max_predictions_per_seq).
MAX_PREDICTIONS = 20 #@param {type:"integer"}
# Lower-case the input text (--do_lower_case).
DO_LOWER_CASE = True #@param {type:"boolean"}
# Output directory for the generated .tfrecord shards.
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}
# Shell pipeline template: run create_pretraining_data.py over every shard,
# PROCESSES at a time via xargs -P; each shard becomes one .tfrecord file.
# NOTE(review): this expression is cut off in this paste — the closing
# parenthesis and the .format(...) argument list (PROCESSES, PRETRAINING_DIR,
# VOC_FNAME, etc.) are missing; confirm against the original notebook.
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
# GCS bucket used to stage training resources, and local model directory.
# NOTE(review): "bert_resourses" is spelled as in the original; it is a
# runtime value (a live bucket name), so it must not be "corrected".
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
# TF1-era filesystem API; creates the local model directory.
tf.gfile.MkDir(MODEL_DIR)

# Without a bucket the data never reaches Cloud Storage, so TPU/cloud
# training cannot run — warn early instead of failing later.
if not BUCKET_NAME:
    log.warning("WARNING: BUCKET_NAME is not set. "
                "You will not be able to train the model.")
# use this for BERT-base
# Standard BERT-base hyper-parameters (12-layer, 768-hidden architecture).
# NOTE(review): the dict is cut off in this paste — the closing brace and
# the remaining keys (num_attention_heads, num_hidden_layers, vocab_size,
# etc.) are missing; confirm against the original notebook before use.
bert_base_config = {
    "attention_probs_dropout_prob": 0.1,
    "directionality": "bidi",
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,