@gokart23
Last active November 30, 2020 21:44
Resave ELECTRA checkpoint with optimizer variables reinitialized and added back in
# The repository imports below come from google-research/electra, so this script
# is assumed to be run from the root of that repo with TensorFlow 1.x.
import tensorflow as tf
import re

from model import modeling
from model import optimization
from pretrain import pretrain_data
from pretrain import pretrain_helpers
from util import training_utils
from util import utils
import configure_pretraining
from run_pretraining import PretrainingModel
def _get_variable_name(param_name):
    """Get the variable name from the tensor name."""
    m = re.match(r"^(.*):\d+$", param_name)
    if m is not None:
        param_name = m.group(1)
    return param_name
# Pretraining config for the base-size model, with num_train_steps set to 4,001,000.
config = configure_pretraining.PretrainingConfig("electra_base", "data/", num_train_steps=4001000, model_size="base")

# Placeholder features for a batch of 128 sequences of length 128 (hence the "bsz_128" save path).
input_ids = tf.placeholder(tf.int32, shape=(128, 128))
input_mask = tf.placeholder(tf.int32, shape=(128, 128))
segment_ids = tf.placeholder(tf.int32, shape=(128, 128))
features = {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids}

g_step = tf.train.get_or_create_global_step()
model = PretrainingModel(config, features, True)  # is_training=True
print("model creation done")

# Model weights only -- the optimizer slot variables do not exist yet.
# varslist = [(_get_variable_name(x.name), x) for x in tf.trainable_variables()]
modelvars = tf.trainable_variables()
# Building the train op instantiates the optimizer's slot variables (the Adam moment
# accumulators), freshly initialized rather than loaded from the checkpoint.
train_op = optimization.create_optimizer(
    model.total_loss, config.learning_rate, config.num_train_steps,
    weight_decay_rate=config.weight_decay_rate,
    use_tpu=config.use_tpu,
    warmup_steps=config.num_warmup_steps,
    lr_decay_power=config.lr_decay_power
)
print("opt creation done")
init_ckpt = tf.train.latest_checkpoint("data/models/electra_base/")
print("init ckpt is", init_ckpt)

# `saver` restores only the model weights from the existing checkpoint;
# `saver_all` saves every variable, including the reinitialized optimizer slots.
saver = tf.train.Saver(modelvars)
saver_all = tf.train.Saver()
with tf.Session() as sess:
    # import ipdb; ipdb.set_trace()  # optional breakpoint for interactive inspection

    # Track one variable to confirm the restore actually overwrites the random init.
    myvar = [x for x in tf.trainable_variables() if 'discriminator_predictions/dense/bias:0' in x.name.lower()][0]
    sess.run(tf.global_variables_initializer())
    init_val = sess.run(myvar)
    saver.restore(sess, init_ckpt)
    post_val = sess.run(myvar)  # should differ from init_val after the restore
    saver_all.save(sess, "resaved/bsz_128")
    print("done saving")
# Reload the resaved checkpoint to verify it restores cleanly.
with tf.Session() as sess:
    myvars = [x for x in tf.global_variables() if 'discriminator_predictions/dense/bias' in x.name.lower()]
    saver_all.restore(sess, "resaved/bsz_128")
    post_val = sess.run(myvars)
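
# Not part of the original gist: a small sanity check, assuming TF 1.x and that
# ELECTRA's AdamWeightDecayOptimizer names its slot variables ".../adam_m" and
# ".../adam_v". Listing the resaved checkpoint should now show those optimizer
# variables alongside the restored model weights.
for name, shape in tf.train.list_variables("resaved/bsz_128"):
    if name.endswith("/adam_m") or name.endswith("/adam_v") or name == "global_step":
        print(name, shape)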