Convert a SalesForce CodeGen model's weights to plain GPT-J
#!/usr/bin/env python
import argparse

import torch
from transformers import GPTJForCausalLM, GPTJConfig
# Note: these need the git version of Transformers as of 7/22/2022
from transformers import CodeGenTokenizer, CodeGenForCausalLM
from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST
parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
parser.add_argument('--code_model',
                    choices=CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
                    default='Salesforce/codegen-350M-multi',
                    help='which SalesForce model to convert')
parser.add_argument('output_dir', help='where to store the converted model')
args = parser.parse_args()
print("Creating tokenizer") | |
tokenizer = CodeGenTokenizer.from_pretrained('Salesforce/codegen-350M-multi') | |
print('Loading CodeGen model') | |
cg_model = CodeGenForCausalLM.from_pretrained(args.code_model) | |
cg_config = cg_model.config | |
# Create empty GPTJ model
print('Creating empty GPTJ model')
config = GPTJConfig(
    vocab_size=cg_config.vocab_size,
    n_positions=cg_config.n_positions,
    n_embd=cg_config.n_embd,
    n_layer=cg_config.n_layer,
    n_head=cg_config.n_head,
    rotary_dim=cg_config.rotary_dim,
    n_inner=cg_config.n_inner,
    activation_function=cg_config.activation_function,
    resid_pdrop=cg_config.resid_pdrop,
    embd_pdrop=cg_config.embd_pdrop,
    attn_pdrop=cg_config.attn_pdrop,
    layer_norm_epsilon=cg_config.layer_norm_epsilon,
    initializer_range=cg_config.initializer_range,
    scale_attn_weights=cg_config.scale_attn_weights,
    use_cache=cg_config.use_cache,
    bos_token_id=cg_config.bos_token_id,
    eos_token_id=cg_config.eos_token_id,
)
# Fix the tokenizer class so the saved config points at the CodeGen tokenizer
config.tokenizer_class = 'CodeGenTokenizer'
gptj_model = GPTJForCausalLM(config)
embed_dim = config.n_embd

# Sample input for validating that the conversion went OK
inputs = tokenizer.encode('#!/usr/bin/env python', return_tensors='pt')
def replace(model, weights, name):
    # Copy the given weights into the named parameter tensor, in place
    model.state_dict()[name].copy_(weights.detach())

def replace_by_name(dest_model, src_model, old_name, new_name):
    assert old_name in src_model.state_dict()
    assert new_name in dest_model.state_dict()
    replace(dest_model, src_model.state_dict()[old_name], new_name)
print('Converting...')
# Copy weights from CodeGen model
with torch.no_grad():
    cg_model.eval()
    gptj_model.eval()

    for name, param in cg_model.named_parameters():
        # print(f'Converting {name}')
        # Handle the qkv weights separately because we need to split them
        if 'qkv_proj' in name:
            qkv_proj = param.detach().clone()
            mp_num = 4  # number of cores on their TPU, I guess?
            local_dim = embed_dim // mp_num
            # GPT-J and CodeGen slice up the qkv projection slightly differently.
            # After a great deal of pain, I figured out that this permutation on
            # the weights of the qkv_proj fixes it.
            base_permutation = [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
            permutation = torch.cat([torch.arange(i * local_dim, (i + 1) * local_dim)
                                     for i in base_permutation])
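            # Worked example (my reading of the layout, derived from the split
            # below): with mp_num = 4, CodeGen packs these rows as four
            # (presumably model-parallel) shards, each holding a local_dim-sized
            # slice of Q, V, and K in turn:
            #   [q0 v0 k0 q1 v1 k1 q2 v2 k2 q3 v3 k3]
            # The permutation regroups those chunks into the contiguous layout
            #   [q0 q1 q2 q3 v0 v1 v2 v3 k0 k1 k2 k3]
            # which the torch.split below can slice into whole Q, V, K matrices.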
            # NB: we permute the *rows* here because the computation is xA.T
            new_qkv_proj = qkv_proj[permutation, :]
            # NB: the name QKV is misleading here; they are actually stored in
            # the order QVK
            query, value, key = torch.split(new_qkv_proj, embed_dim, dim=0)
            replace(gptj_model, query, name.replace('qkv_proj', 'q_proj'))
            replace(gptj_model, key, name.replace('qkv_proj', 'k_proj'))
            replace(gptj_model, value, name.replace('qkv_proj', 'v_proj'))
        else:
            replace_by_name(gptj_model, cg_model, name, name)

print('Conversion complete, running inference')
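# Sanity check: greedy decoding (do_sample=False) is deterministic, so the two
# models must produce token-for-token identical output if the weights match.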
cg_out = cg_model.generate(inputs, min_length=32, max_length=32, do_sample=False, pad_token_id=50256)
gptj_out = gptj_model.generate(inputs, min_length=32, max_length=32, do_sample=False, pad_token_id=50256)
print(cg_out[0])
print(gptj_out[0])
cg_dec, gptj_dec = tokenizer.batch_decode(torch.stack([cg_out, gptj_out]).squeeze())
print("====== CodeGen ======")
print(cg_dec)
print("====== GPT-J ======")
print(gptj_dec)
assert cg_dec == gptj_dec

print(f"Saving model to {args.output_dir}...")
gptj_model.save_pretrained(args.output_dir)
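
Once the script has run, the saved directory loads like any other GPT-J checkpoint. Below is a minimal usage sketch (an illustration, not part of the original gist): the path ./codegen-350M-multi-gptj is a hypothetical output_dir, and the tokenizer is pulled from the Hub because the script above does not save it alongside the model.

#!/usr/bin/env python
import torch
from transformers import GPTJForCausalLM, CodeGenTokenizer

# Hypothetical output_dir used when running the conversion script above
model = GPTJForCausalLM.from_pretrained('./codegen-350M-multi-gptj')
model.eval()
# The conversion script does not call tokenizer.save_pretrained(), so load
# the CodeGen tokenizer from the Hub instead
tokenizer = CodeGenTokenizer.from_pretrained('Salesforce/codegen-350M-multi')

inputs = tokenizer.encode('def fibonacci(n):', return_tensors='pt')
with torch.no_grad():
    out = model.generate(inputs, max_length=64, do_sample=False, pad_token_id=50256)
print(tokenizer.decode(out[0]))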