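"""Convert an ESPnet-trained Whisper checkpoint to HuggingFace Transformers format.

Loads the ESPnet state dict, renames its parameters to the scheme used by
`WhisperForConditionalGeneration`, loads them into a model built from an
existing Whisper config, and saves the model together with the matching
tokenizer and processor.
"""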
import argparse
from collections import OrderedDict

import torch
from torch import nn
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer


def remove_ignore_keys_(state_dict):
    # Drop top-level keys that have no counterpart in the HF state dict.
    ignore_keys = ["layers", "blocks"]
    for k in ignore_keys:
        state_dict.pop(k, None)


# Substring replacements mapping ESPnet/OpenAI-style parameter names to the
# names expected by WhisperForConditionalGeneration.
WHISPER_MAPPING = OrderedDict(
    [
        ("decoder.decoders", "decoder"),
        ("encoder.encoders", "encoder"),
        ("blocks", "layers"),
        ("mlp.0", "fc1"),
        ("mlp.2", "fc2"),
        ("mlp_ln", "final_layer_norm"),
        (".attn.query", ".self_attn.q_proj"),
        (".attn.key", ".self_attn.k_proj"),
        (".attn.value", ".self_attn.v_proj"),
        (".attn_ln", ".self_attn_layer_norm"),
        (".attn.out", ".self_attn.out_proj"),
        (".cross_attn.query", ".encoder_attn.q_proj"),
        (".cross_attn.key", ".encoder_attn.k_proj"),
        (".cross_attn.value", ".encoder_attn.v_proj"),
        (".cross_attn_ln", ".encoder_attn_layer_norm"),
        (".cross_attn.out", ".encoder_attn.out_proj"),
        ("decoder.ln.", "decoder.layer_norm."),
        ("encoder.ln.", "encoder.layer_norm."),
        ("token_embedding", "embed_tokens"),
        ("encoder.positional_embedding", "encoder.embed_positions.weight"),
        ("decoder.positional_embedding", "decoder.embed_positions.weight"),
        ("ln_post", "layer_norm"),
    ]
)
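
# For illustration, rename_keys (below) turns an ESPnet key such as
#   "decoder.decoders.blocks.0.mlp.0.weight"
# into the Transformers key
#   "decoder.layers.0.fc1.weight"
# ("decoder.decoders" -> "decoder", "blocks" -> "layers", "mlp.0" -> "fc1").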


def rename_keys(s_dict):
    # Rewrite every key in place by applying each mapping entry as a substring replacement.
    keys = list(s_dict.keys())
    for key in keys:
        new_key = key
        for k, v in WHISPER_MAPPING.items():
            if k in new_key:
                new_key = new_key.replace(k, v)
        print(f"{key} -> {new_key}")
        s_dict[new_key] = s_dict.pop(key)
    return s_dict


def make_linear_from_emb(emb):
    # Build an output projection whose weights are tied to the token embedding.
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer
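
# Note: nn.Linear(vocab_size, emb_size) allocates a weight of shape
# (emb_size, vocab_size), but it is immediately overwritten with the embedding
# matrix of shape (vocab_size, emb_size), so the tied layer effectively maps
# emb_size-dimensional hidden states to vocab_size logits.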


def convert_espnet_whisper_to_tfms(espnet_checkpoint, pytorch_dump_folder_path, whisper_config_id):
    state_dict = torch.load(espnet_checkpoint, map_location="cpu")
    # Keep the token embedding around in case the output projection is not tied.
    proj_out_weights = state_dict["decoder.decoders.token_embedding.weight"]
    remove_ignore_keys_(state_dict)
    rename_keys(state_dict)
    tie_embeds = True

    config = WhisperConfig.from_pretrained(whisper_config_id)
    model = WhisperForConditionalGeneration(config)
    missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
    # Only the positional embeddings may be absent from the checkpoint.
    if len(missing) > 0 and not set(missing) <= {
        "encoder.embed_positions.weight",
        "decoder.embed_positions.weight",
    }:
        raise ValueError(
            "Only `encoder.embed_positions.weight` and `decoder.embed_positions.weight` are allowed to be"
            f" missing, but the following weights are missing: {missing}"
        )

    if tie_embeds:
        # Tie the output projection to the decoder token embeddings.
        model.proj_out = make_linear_from_emb(model.model.decoder.embed_tokens)
    else:
        model.proj_out.weight.data = proj_out_weights

    model.save_pretrained(pytorch_dump_folder_path)
    tokenizer = WhisperTokenizer.from_pretrained(whisper_config_id)
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    processor = WhisperProcessor.from_pretrained(whisper_config_id)
    processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--whisper_config_id", required=True, type=str, help="Whisper config ID, e.g. openai/whisper-medium")
    parser.add_argument("--espnet_checkpoint", required=True, type=str, help="Path to the ESPnet model checkpoint")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model in HuggingFace format")
    args = parser.parse_args()
    convert_espnet_whisper_to_tfms(args.espnet_checkpoint, args.pytorch_dump_folder_path, args.whisper_config_id)
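
# Example invocation (a sketch; the script name, checkpoint path, and output
# directory below are illustrative, not part of this gist):
#
#   python convert_espnet_whisper_to_tfms.py \
#       --whisper_config_id openai/whisper-medium \
#       --espnet_checkpoint exp/asr_whisper/valid.acc.ave.pth \
#       --pytorch_dump_folder_path ./whisper-medium-hf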