Last active
November 30, 2020 21:44
-
-
Save gokart23/4318c8e1eeb71a5f1d206f335cbffef5 to your computer and use it in GitHub Desktop.
Resave ELECTRA checkpoint with optimizer variables reinitialized and added back in
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library
import re

# TensorFlow 1.x
import tensorflow as tf

# ELECTRA project modules
import configure_pretraining
from model import modeling
from model import optimization
from pretrain import pretrain_data
from pretrain import pretrain_helpers
from run_pretraining import PretrainingModel
from util import training_utils
from util import utils
def _get_variable_name(param_name): | |
"""Get the variable name from the tensor name.""" | |
m = re.match("^(.*):\\d+$", param_name) | |
if m is not None: | |
param_name = m.group(1) | |
return param_name | |
# Rebuild the pretraining graph the same way run_pretraining does: config,
# placeholder features (batch 128, seq len 128), model, then the optimizer —
# creating the optimizer creates the Adam slot variables we want to
# reinitialize and fold back into the checkpoint.
config = configure_pretraining.PretrainingConfig(
    "electra_base", "data/", num_train_steps=4001000, model_size="base")

batch_shape = (128, 128)
input_ids = tf.placeholder(tf.int32, shape=batch_shape)
input_mask = tf.placeholder(tf.int32, shape=batch_shape)
segment_ids = tf.placeholder(tf.int32, shape=batch_shape)
features = {
    'input_ids': input_ids,
    'input_mask': input_mask,
    'segment_ids': segment_ids,
}

g_step = tf.train.get_or_create_global_step()
model = PretrainingModel(config, features, True)
print("model creation done")

# Model weights only — the optimizer slot variables are deliberately excluded.
modelvars = list(tf.trainable_variables())

train_op = optimization.create_optimizer(
    model.total_loss, config.learning_rate, config.num_train_steps,
    weight_decay_rate=config.weight_decay_rate,
    use_tpu=config.use_tpu,
    warmup_steps=config.num_warmup_steps,
    lr_decay_power=config.lr_decay_power)
print("opt creation done")

init_ckpt = tf.train.latest_checkpoint("data/models/electra_base/")
print("init ckpt is", init_ckpt)

# `saver` restores just the model weights from the pretrained checkpoint;
# `saver_all` additionally covers the fresh optimizer variables when saving.
saver = tf.train.Saver(modelvars)
saver_all = tf.train.Saver()
# Fix: removed the leftover `import ipdb; ipdb.set_trace()` debugger
# breakpoint — it halts the script on every run and drags in a non-stdlib
# dependency (ipdb) that isn't part of the project's requirements.
with tf.Session() as sess:
    # Sanity probe: track one bias before/after restore to confirm the
    # checkpoint restore actually overwrote the random initialization.
    myvar = [x for x in tf.trainable_variables()
             if 'discriminator_predictions/dense/bias:0' in x.name.lower()][0]
    # Initialize EVERYTHING (model weights + optimizer slots), then restore
    # only the model weights — the optimizer variables keep their fresh
    # initialization, which is the whole point of the resave.
    sess.run(tf.global_variables_initializer())
    init_val = sess.run(myvar)
    saver.restore(sess, init_ckpt)
    post_val = sess.run(myvar)
    # Save the full variable set: pretrained weights + reset optimizer state.
    saver_all.save(sess, "resaved/bsz_128")
    print("done saving")
# Round-trip check: reload the just-written checkpoint in a fresh session and
# read back the probed discriminator bias variables.
with tf.Session() as sess:
    myvars = [v for v in tf.global_variables()
              if 'discriminator_predictions/dense/bias' in v.name.lower()]
    saver_all.restore(sess, "resaved/bsz_128")
    post_val = sess.run(myvars)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment