import collections
from configure_pretraining import PretrainingConfig
from run_pretraining import PretrainingModel
from pretrain.pretrain_data import get_input_fn, Inputs
import tensorflow as tf
import torch
from model import modeling
from transformers.modeling_electra import ElectraModel, ElectraGenerator, ElectraDiscriminator, load_tf_weights_in_electra
from transformers import BertConfig
import numpy as np
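# Run the input pipeline eagerly first so that a concrete batch of features can be
# pulled out of the ELECTRA pretraining dataset.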
tf.enable_eager_execution()
tf.compat.v1.set_random_seed(0)
training = False
data_directory = "/home/jik/python/ELECTRA/electra_small"
config = PretrainingConfig("small", data_directory)
train_input_fn = get_input_fn(config, training)
features = tf.data.make_one_shot_iterator(train_input_fn({"batch_size": 1})).get_next()
input_ids = list(features["input_ids"].__iter__())[0].numpy()
input_mask = list(features["input_mask"].__iter__())[0].numpy()
segment_ids = list(features["segment_ids"].__iter__())[0].numpy()
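# Hard-coded "fake" batch: corrupted input_ids together with is_fake_tokens labels
# marking which positions were replaced, used to exercise the discriminator on a
# fixed, reproducible example.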
FakedData = collections.namedtuple("FakedData", ["inputs", "is_fake_tokens"])  # , "sampled_tokens"])
fake_data = FakedData(
    Inputs(
        tf.constant([[ 101, 2151, 11385, 2052, 2031, 20464, 4235, 15484, 2011, 2796, 4153, 14731, 1999, 4952, 2733, 2144, 1996, 4946, 5419, 1012, 1523, 8045, 2031, 2116, 7367, 3001, 4082, 1999, 1996, 22366, 1010, 2021, 2498, 2001, 3856, 2039, 1010, 1422, 4373, 5902, 19219, 11961, 17357, 16374, 1010, 2708, 1997, 3095, 1997, 2634, 1521, 21904, 1998, 23093, 2015, 1998, 19332, 8237, 3094, 1010, 2409, 26665, 1012, 1523, 2009, 2003, 2825, 2008, 1996, 15778, 7217, 2015, 12767, 7237, 2125, 2004, 2057, 5452, 2006, 2019, 2004, 1011, 3223, 3978, 1012, 2061, 3383, 3905, 4852, 2015, 2020, 4082, 1010, 20229, 2089, 2025, 2031, 1996, 3223, 2846, 2000, 11487, 1037, 3462, 2012, 2019, 21157, 1997, 431, 1010, 2199, 2519, 1012, 1524, 102, 16353, 7069, 4153, 2038, 2019, 2779, 5995, 1997, 2062, 2084, 2260, 1010, 102]], dtype=tf.int32),
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=tf.int32),
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=tf.int32),
        tf.constant([[ 37, 69, 5, 106, 13, 115, 51, 43, 72, 116, 21, 21, 88, 24, 13, 29, 93, 69, 108]], dtype=tf.int32),
        tf.constant([[1524, 2510, 2042, 7998, 1996, 1996, 1055, 4886, 2020, 2796, 2057, 2057, 7217, 7217, 1996, 2181, 2029, 2510, 3486]], dtype=tf.int32),
        tf.constant([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
    ),
    tf.constant([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.int32)
)
fake_input_ids = list(fake_data[0].input_ids.numpy())
fake_input_mask = list(fake_data[0].input_mask.numpy())
fake_segment_ids = list(fake_data[0].segment_ids.numpy())
fake_labels = list(fake_data[1].numpy())
# ELECTRA can only run with eager execution disabled.
tf.disable_eager_execution()
# features = tf.data.make_one_shot_iterator(train_input_fn({"batch_size": 1})).get_next()
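# Build the TF ELECTRA pretraining model in graph mode and restore the weights
# from the original checkpoint.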
model = PretrainingModel(config, features, training)
init_checkpoint = data_directory + "/models/electra_small"
tvars = tf.trainable_variables()
assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
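# Run the graph once to materialize every intermediate TF activation that the
# PyTorch port will be compared against.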
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    masked_inputs = sess.run(model.masked_inputs)
    token_embeddings = sess.run(model.token_embeddings)
    generator_embedding_output = sess.run(model.generator_embedding_output)
    generator_all_encoder_layers = sess.run(model.generator_all_encoder_layers)
    generator_sequence_output = sess.run(model.sequence_output)
    generator_pooled_output = sess.run(model.pooled_output)
    generator_mlm_output = sess.run(model.mlm_output)
    generator_relevant_hidden = sess.run(model.relevant_hidden)
    discriminator_embedding_output = sess.run(model.discriminator_embedding_output)
    discriminator_all_encoder_layers = sess.run(model.discriminator_all_encoder_layers)
    discriminator_output = sess.run(model.discriminator_output)
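# Load the Hugging Face PyTorch port; "here" is presumably a local directory
# containing the converted weights and config.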
hf_model = ElectraModel.from_pretrained("here")
hf_generator = ElectraGenerator.from_pretrained("here")
hf_discriminator = ElectraDiscriminator.from_pretrained("here")
hf_model.eval()
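# Convert both the real masked batch and the hard-coded fake batch to torch tensors.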
hf_input_ids = torch.tensor(masked_inputs.input_ids, dtype=torch.long)
hf_attention_mask = torch.tensor(masked_inputs.input_mask, dtype=torch.long)
hf_token_type_ids = torch.tensor(masked_inputs.segment_ids, dtype=torch.long)
hf_masked_lm_positions = torch.tensor(masked_inputs.masked_lm_positions, dtype=torch.long)
hf_masked_lm_ids = torch.tensor(masked_inputs.masked_lm_ids, dtype=torch.long)
hf_masked_lm_weights = torch.tensor(masked_inputs.masked_lm_weights, dtype=torch.long)
hf_fake_input_ids = torch.tensor(fake_input_ids, dtype=torch.long)
hf_fake_attention_mask = torch.tensor(fake_input_mask, dtype=torch.long)
hf_fake_token_type_ids = torch.tensor(fake_segment_ids, dtype=torch.long)
hf_fake_labels = torch.tensor(fake_labels, dtype=torch.long)  # .nonzero().split(1, dim=1)[1].view(-1)
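# Forward passes: the full ELECTRA model in generator mode (masked LM inputs) and in
# discriminator mode (fake-token labels), plus the standalone generator and
# discriminator, all on the same inputs.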
outputs = hf_model(
    hf_input_ids,
    hf_attention_mask,
    hf_token_type_ids,
    masked_lm_positions=hf_masked_lm_positions,
    masked_lm_ids=hf_masked_lm_ids,
)
discrim_outputs = hf_model(
    hf_fake_input_ids,
    hf_fake_attention_mask,
    hf_fake_token_type_ids,
    fake_token_labels=hf_fake_labels
)
gen_outputs_from_generator = hf_generator(
    hf_input_ids,
    hf_attention_mask,
    hf_token_type_ids,
    masked_lm_positions=hf_masked_lm_positions,
    masked_lm_ids=hf_masked_lm_ids,
)
discrim_outputs_from_discrim = hf_discriminator(
    hf_fake_input_ids,
    hf_fake_attention_mask,
    hf_fake_token_type_ids,
    fake_token_labels=hf_fake_labels
)
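# Unpack the output tuples so each TF tensor can be matched with its PyTorch counterpart.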
hf_generator_sequence_output, hf_generator_pooled_output, _, hf_gen_logits, hf_gen_probs, hf_gen_preds, hf_gen_loss = outputs
hf_generator_sequence_output_2, hf_generator_pooled_output_2, hf_gen_logits_2, hf_gen_probs_2, hf_gen_preds_2, hf_gen_loss_2 = gen_outputs_from_generator
_, _, hf_discriminator_sequence_output, hf_discrim_probs, hf_discrim_preds, hf_discrim_loss = discrim_outputs
hf_discriminator_sequence_output_2, hf_discrim_probs_2, hf_discrim_preds_2, hf_discrim_loss_2 = discrim_outputs_from_discrim
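# Helper: largest absolute element-wise difference between a TF array and a PyTorch tensor.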
def difference_between_tensors(tf_tensor, pt_tensor):
    tf_np = np.array(tf_tensor)
    pt_np = np.array(pt_tensor.detach())
    return np.max(np.abs(tf_np - pt_np))
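# If the conversion is correct, every difference printed below should be numerically
# close to zero.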
print("\n\n--- GENERATOR ---") | |
print("DIFFERENCE BETWEEN ALL ENCODER LAYERS: ", difference_between_tensors(generator_all_encoder_layers[-1], hf_generator_sequence_output)) | |
print("DIFFERENCE BETWEEN SEQUENCE OUTPUTS: ", difference_between_tensors(generator_sequence_output, hf_generator_sequence_output)) | |
print("DIFFERENCE BETWEEN POOLED OUTPUT ", difference_between_tensors(generator_pooled_output, hf_generator_pooled_output)) | |
print("DIFFERENCE BETWEEN MLM LOGITS ", difference_between_tensors(generator_mlm_output.logits, hf_gen_logits)) | |
print("DIFFERENCE BETWEEN MLM PROBS ", difference_between_tensors(generator_mlm_output.probs, hf_gen_probs)) | |
print("DIFFERENCE BETWEEN MLM LOSS ", difference_between_tensors(generator_mlm_output.loss, hf_gen_loss)) | |
print("DIFFERENCE BETWEEN MLM PREDS ", difference_between_tensors(generator_mlm_output.preds, hf_gen_preds)) | |
print("\n--- GENERATOR ONLY ---") | |
print("DIFFERENCE BETWEEN ALL ENCODER LAYERS: ", difference_between_tensors(generator_all_encoder_layers[-1], hf_generator_sequence_output_2)) | |
print("DIFFERENCE BETWEEN SEQUENCE OUTPUTS: ", difference_between_tensors(generator_sequence_output, hf_generator_sequence_output_2)) | |
print("DIFFERENCE BETWEEN POOLED OUTPUT ", difference_between_tensors(generator_pooled_output, hf_generator_pooled_output_2)) | |
print("DIFFERENCE BETWEEN MLM LOGITS ", difference_between_tensors(generator_mlm_output.logits, hf_gen_logits_2)) | |
print("DIFFERENCE BETWEEN MLM PROBS ", difference_between_tensors(generator_mlm_output.probs, hf_gen_probs_2)) | |
print("DIFFERENCE BETWEEN MLM LOSS ", difference_between_tensors(generator_mlm_output.loss, hf_gen_loss_2)) | |
print("DIFFERENCE BETWEEN MLM PREDS ", difference_between_tensors(generator_mlm_output.preds, hf_gen_preds_2)) | |
print("\n--- DISCRIMINATOR ---") | |
print("DIFFERENCE BETWEEN ALL ENCODER LAYERS ", difference_between_tensors(discriminator_all_encoder_layers[-1], hf_discriminator_sequence_output)) | |
print("DIFFERENCE BETWEEN SEQUENCE OUTPUT ", difference_between_tensors(discriminator_all_encoder_layers[-1], hf_discriminator_sequence_output)) | |
print("DIFFERENCE BETWEEN PROBS ", difference_between_tensors(discriminator_output.probs, hf_discrim_probs)) | |
print("DIFFERENCE BETWEEN LOSS ", difference_between_tensors(discriminator_output.loss, hf_discrim_loss)) | |
print("DIFFERENCE BETWEEN PREDS ", difference_between_tensors(discriminator_output.preds, hf_discrim_preds)) | |
print("\n--- DISCRIMINATOR ONLY ---") | |
print("DIFFERENCE BETWEEN ALL ENCODER LAYERS ", difference_between_tensors(discriminator_all_encoder_layers[-1], hf_discriminator_sequence_output_2)) | |
print("DIFFERENCE BETWEEN SEQUENCE OUTPUT ", difference_between_tensors(discriminator_all_encoder_layers[-1], hf_discriminator_sequence_output_2)) | |
print("DIFFERENCE BETWEEN PROBS ", difference_between_tensors(discriminator_output.probs, hf_discrim_probs_2)) | |
print("DIFFERENCE BETWEEN LOSS ", difference_between_tensors(discriminator_output.loss, hf_discrim_loss_2)) | |
print("DIFFERENCE BETWEEN PREDS ", difference_between_tensors(discriminator_output.preds, hf_discrim_preds_2)) |
Hi,
Up to line 53, the code ran successfully for me. But when I ran lines 55-70, I found that the PretrainingModel object has no attributes 'token_embeddings', 'masked_inputs', 'generator_embedding_output', etc. I guess I need to edit run_pretraining.py so that the PretrainingModel object exposes all of these attributes. Could you please suggest these edits?
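To make sure I understand the pattern, here is a toy example of what I think is needed (just a sketch with made-up names, not the actual ELECTRA code): any tensor assigned to self inside __init__ can later be fetched with sess.run(model.some_attribute), which is how the script above reads model.masked_inputs, model.token_embeddings, and so on.

import tensorflow as tf

class ToyModel(object):
    """Toy illustration only, not the real PretrainingModel."""
    def __init__(self, inputs):
        hidden = tf.layers.dense(inputs, 4, name="hidden")
        logits = tf.layers.dense(hidden, 2, name="logits")
        # Storing the intermediates on self is the whole trick:
        self.hidden = hidden
        self.logits = logits

inputs = tf.placeholder(tf.float32, [None, 3])
model = ToyModel(inputs)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    hidden_value = sess.run(model.hidden, feed_dict={inputs: [[1., 2., 3.]]})

If that is right, I assume the fix is to assign the corresponding intermediate tensors (masked_inputs, token_embeddings, generator_embedding_output, the encoder layers, mlm_output, discriminator_output, etc.) to self inside PretrainingModel.__init__ in run_pretraining.py, but I do not know the exact local variable names used there.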