Creating all dummy models with weights
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import copy
import re
import importlib
import string
from collections import OrderedDict

import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    MODEL_FOR_OBJECT_DETECTION_MAPPING,
    MODEL_FOR_PRETRAINING_MAPPING,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    MODEL_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_MASKED_LM_MAPPING,
    TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    TF_MODEL_FOR_PRETRAINING_MAPPING,
    TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    TF_MODEL_MAPPING,
    TF_MODEL_WITH_LM_HEAD_MAPPING,
    logging,
)

logging.set_verbosity_error()

HOME = os.getenv("HOME")
weights_path = f"{HOME}/data/weights"
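# Each tiny checkpoint is written under {weights_path}/{model_type}/{snake-cased architecture name},
# e.g. "~/data/weights/bert/bert_for_masked_lm".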


def to_snake_case(name):
    "https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case"
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("__([A-Z])", r"_\1", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name)
    return name.lower()
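# e.g. to_snake_case("BertForMaskedLM") -> "bert_for_masked_lm"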


def flattened(somelist):
    output = []
    for item in somelist:
        if isinstance(item, (tuple, list)):
            output.extend(list(item))
        else:
            output.append(item)
    return output
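# e.g. flattened([m1, (m2, m3)]) -> [m1, m2, m3]: mapping values may be single classes or tuples of classes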


# UTILITY METHODS
def get_tiny_config_from_class(configuration_class):
    """
    Retrieve a tiny configuration from the configuration class. It uses each class's `ModelTester`.

    Args:
        configuration_class: Subclass of `PretrainedConfig`.

    Returns:
        An instance of the configuration passed, with very small hyper-parameters.
    """
    model_type = configuration_class.model_type
    camel_case_model_name = configuration_class.__name__.split("Config")[0]

    try:
        module = importlib.import_module(f".test_modeling_{model_type.replace('-', '_')}", package="tests")
        model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
    except ModuleNotFoundError:
        print(f"Will not build {model_type}: no model tester or cannot find the testing module from the model name.")
        return

    if model_tester_class is None:
        return

    model_tester = model_tester_class(parent=None)

    if hasattr(model_tester, "get_pipeline_config"):
        return model_tester.get_pipeline_config()
    elif hasattr(model_tester, "get_config"):
        return model_tester.get_config()


def eventual_create_tokenizer(dirname, architecture, config):
    """Create and save a tiny tokenizer in `dirname` if one is not already there and a reference checkpoint exists."""
    try:
        _ = AutoTokenizer.from_pretrained(dirname, local_files_only=True)
        # Already created
        return
    except Exception:
        pass

    checkpoint = get_checkpoint_from_architecture(architecture)
    if checkpoint is None:
        return

    tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
    if tokenizer is None:
        return

    if hasattr(config, "max_position_embeddings"):
        tokenizer.model_max_length = config.max_position_embeddings
    assert tokenizer.vocab_size <= config.vocab_size

    if checkpoint is not None and tokenizer is not None:
        try:
            tokenizer.save_pretrained(dirname)
        except Exception:
            pass
        try:
            tokenizer._tokenizer.save(f"{dirname}/tokenizer.json")
        except Exception:
            return
    # Make sure the saved tokenizer can be loaded back
    _ = AutoTokenizer.from_pretrained(dirname, local_files_only=True)
    # print(f"SUCCESS {dirname}")


def build_pt_architecture(architecture, config):
    dirname = os.path.join(weights_path, config.model_type, to_snake_case(architecture.__name__))
    try:
        model = architecture.from_pretrained(dirname, local_files_only=True)
        # Already created
        return
    except Exception:
        pass

    state_dict = {}

    if "DPRQuestionEncoder" in architecture.__name__:
        # Not supported
        return

    if "ReformerModelWithLMHead" in architecture.__name__:
        config.is_decoder = True
    if "ReformerForMaskedLM" in architecture.__name__:
        config.is_decoder = False

    os.makedirs(dirname, exist_ok=True)
    config.save_pretrained(dirname)
    eventual_create_tokenizer(dirname, architecture, config)

    # Instantiate the model from the tiny config; the empty state_dict leaves all weights at their random initialization
    model = architecture.from_pretrained(None, config=config, state_dict=state_dict, local_files_only=True)

    # Strip the model-type prefix (e.g. "bert.") from the parameter names before saving
    state_dict = {
        **{k.split(f"{config.model_type}.")[-1]: v for k, v in model.state_dict().items()},
    }
    # Iterate over a copy of the keys: deleting from the dict while iterating over it raises a RuntimeError
    for key in list(state_dict.keys()):
        if key.startswith(f"{config.model_type}."):
            del state_dict[key]

    torch.save(OrderedDict(state_dict), os.path.join(dirname, "pytorch_model.bin"))

    # Make sure we can load what we just saved
    model = architecture.from_pretrained(dirname, local_files_only=True)


def build_pytorch_weights_from_multiple_architectures(pytorch_architectures):
    # Create the PyTorch tiny models
    for config, architectures in tqdm(pytorch_architectures.items(), desc="Building PyTorch weights"):
        base_tiny_config = get_tiny_config_from_class(config)

        if base_tiny_config is None:
            continue

        flat_architectures = flattened(architectures)
        for architecture in flat_architectures:
            build_pt_architecture(architecture, copy.deepcopy(base_tiny_config))


def build_tf_architecture(architecture, config):
    # [2:] removes the TF prefix of the architecture name
    dirname = os.path.join(weights_path, config.model_type, to_snake_case(architecture.__name__[2:]))
    try:
        model = architecture.from_pretrained(dirname, local_files_only=True)
        # Already created
        return
    except Exception:
        pass

    if "DPRQuestionEncoder" in architecture.__name__:
        # Not supported
        return

    if "ReformerModelWithLMHead" in architecture.__name__:
        config.is_decoder = True
    if "ReformerForMaskedLM" in architecture.__name__:
        config.is_decoder = False

    config.num_labels = 2

    os.makedirs(dirname, exist_ok=True)
    config.save_pretrained(dirname)
    eventual_create_tokenizer(dirname, architecture, config)

    try:
        # Load the PyTorch weights saved by build_pt_architecture into the TF architecture
        model = architecture.from_pretrained(dirname, config=config, from_pt=True, local_files_only=True)
    except Exception as e:
        raise ValueError(f"Couldn't load {architecture.__name__}.") from e

    model.save_pretrained(dirname)
    # Make sure we can load what we just saved
    model = architecture.from_pretrained(dirname, local_files_only=True)


def build_tensorflow_weights_from_multiple_architectures(tensorflow_architectures):
    # Create the TensorFlow tiny models
    for config, architectures in tqdm(tensorflow_architectures.items(), desc="Building TensorFlow weights"):
        base_tiny_config = get_tiny_config_from_class(config)

        if base_tiny_config is None:
            continue

        flat_architectures = flattened(architectures)
        for architecture in flat_architectures:
            build_tf_architecture(architecture, copy.deepcopy(base_tiny_config))


def get_tiny_tokenizer_from_checkpoint(checkpoint):
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
    except Exception:
        return

    # logger.warning("Training new from iterator ...")
    vocabulary = string.ascii_letters + string.digits + " "

    if not tokenizer.__class__.__name__.endswith("Fast"):
        return

    try:
        tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
    except:  # noqa: E722
        return
    # logger.warning("Trained.")
    return tokenizer


def get_checkpoint_from_architecture(architecture):
    try:
        module = importlib.import_module(architecture.__module__)
    except Exception:
        # logger.error(f"Ignoring architecture {architecture}")
        return

    if hasattr(module, "_CHECKPOINT_FOR_DOC"):
        return module._CHECKPOINT_FOR_DOC
    else:
        # logger.warning(f"Can't retrieve checkpoint from {architecture.__name__}")
        pass


def pt_architectures():
    pytorch_mappings = [
        MODEL_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_PRETRAINING_MAPPING,
        MODEL_FOR_CAUSAL_LM_MAPPING,
        MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
        MODEL_FOR_OBJECT_DETECTION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        MODEL_WITH_LM_HEAD_MAPPING,
        MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    ]

    pt_architectures = {
        config: [pytorch_mapping[config] for pytorch_mapping in pytorch_mappings if config in pytorch_mapping]
        for config in CONFIG_MAPPING.values()
    }

    build_pytorch_weights_from_multiple_architectures(pt_architectures)
    print("Built PyTorch weights")

    for config, architectures in tqdm(pt_architectures.items(), desc="Checking PyTorch weights validity"):
        base_tiny_config = get_tiny_config_from_class(config)

        if base_tiny_config is None:
            continue

        flat_architectures = flattened(architectures)
        for architecture in flat_architectures:
            if "DPRQuestionEncoder" in architecture.__name__:
                continue

            dirname = os.path.join(weights_path, config.model_type, to_snake_case(architecture.__name__))
            model, loading_info = architecture.from_pretrained(
                dirname,
                output_loading_info=True,
                local_files_only=True,
            )
            if len(loading_info["missing_keys"]) > 0:
                raise ValueError(f"Missing weights when loading PyTorch checkpoints: {loading_info['missing_keys']}")

    print("Checked PyTorch weights")


def tf_architectures():
    tensorflow_mappings = [
        TF_MODEL_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    ]

    tf_architectures = {
        config: [
            tensorflow_mapping[config] for tensorflow_mapping in tensorflow_mappings if config in tensorflow_mapping
        ]
        for config in CONFIG_MAPPING.values()
    }

    build_tensorflow_weights_from_multiple_architectures(tf_architectures)
    print("Built TensorFlow weights")

    for config, architectures in tqdm(tf_architectures.items(), desc="Checking TensorFlow weights validity"):
        base_tiny_config = get_tiny_config_from_class(config)

        if base_tiny_config is None:
            continue

        flat_architectures = flattened(architectures)
        for architecture in flat_architectures:
            if "DPRQuestionEncoder" in architecture.__name__:
                # Not supported, skip to the next architecture
                continue

            # [2:] to remove TF prefix
            dirname = os.path.join(weights_path, config.model_type, to_snake_case(architecture.__name__[2:]))
            try:
                model, loading_info = architecture.from_pretrained(
                    dirname, output_loading_info=True, local_files_only=True
                )
            except Exception as e:
                raise ValueError(f"Couldn't load {architecture.__name__}") from e

            if len(loading_info["missing_keys"]) != 0:
                required_weights_missing = []
                for missing_key in loading_info["missing_keys"]:
                    if "dropout" not in missing_key:
                        required_weights_missing.append(missing_key)

                if len(required_weights_missing) > 0:
                    raise ValueError(f"Found missing weights in {architecture}: {required_weights_missing}")

    print("Checked TensorFlow weights")


def main():
    # Build, then validate, the PyTorch and TensorFlow tiny weights
    pt_architectures()
    tf_architectures()


if __name__ == "__main__":
    main()
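

# Usage sketch: once the script has run, a generated tiny checkpoint can be loaded
# straight from the weights directory, e.g. (assuming the BERT architectures were built):
#
#     from transformers import BertModel
#     tiny_bert = BertModel.from_pretrained(f"{HOME}/data/weights/bert/bert_model", local_files_only=True)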