Created
January 1, 2024 07:09
-
-
Save ehartford/5d8452c1f2e8395398e86106388660df to your computer and use it in GitHub Desktop.
Convert yayi2-30b to Llama format. All credit goes to Charles Goddard and Weyaxi.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import os | |
import safetensors.torch | |
import glob | |
import json | |
def transform_st(path: str, out_dir: str):
    """Rewrite one safetensors shard with renamed layer-norm tensor keys.

    Every tensor name containing ``.ln1.`` has that segment replaced by
    ``.input_layernorm.``, and every name containing ``.ln2.`` has it
    replaced by ``.post_attention_layernorm.``. All other tensors are kept
    under their original names.

    Args:
        path: Path of the safetensors file to read.
        out_dir: Directory in which the converted file is written, under the
            same base name as ``path``.
    """
    tensors = safetensors.torch.load_file(path)

    # Iterate over a snapshot of the names: the mapping is mutated below.
    for original_name in list(tensors):
        renamed = original_name.replace(".ln1.", ".input_layernorm.").replace(
            ".ln2.", ".post_attention_layernorm."
        )
        if renamed == original_name:
            continue
        # Move the tensor to its new name and drop the old entry.
        tensors[renamed] = tensors.pop(original_name)

    destination = os.path.join(out_dir, os.path.basename(path))
    safetensors.torch.save_file(tensors, destination, metadata={"format": "pt"})
def process_model(path: str, out_path: str):
    """Convert every weight shard and the shard index of a checkpoint.

    Each ``model-*.safetensors`` file found directly in ``path`` is handed to
    ``transform_st``, which writes the converted shard into ``out_path``.
    The ``model.safetensors.index.json`` index is then rewritten so that its
    ``weight_map`` keys have ``.ln1.`` replaced by ``.input_layernorm.`` and
    ``.ln2.`` replaced by ``.post_attention_layernorm.``; the shard file
    names the keys point to are left unchanged.

    Args:
        path: Directory holding the source shards and the index file.
        out_path: Directory that receives the converted shards and index.
            It is created if it does not exist yet.

    Raises:
        FileNotFoundError: If ``path`` has no ``model.safetensors.index.json``.
        KeyError: If the index lacks a ``metadata`` or ``weight_map`` entry.
    """
    index_name = "model.safetensors.index.json"

    # Both the shards and the index are written here, so ensure it exists
    # before any output is produced.
    os.makedirs(out_path, exist_ok=True)

    for shard_path in glob.glob(os.path.join(path, "model-*.safetensors")):
        transform_st(shard_path, out_path)

    # The file mode and encoding are arguments of open(), not os.path.join();
    # passing them to join() would build a bogus path or raise TypeError.
    with open(os.path.join(path, index_name), "r", encoding="utf-8") as fd:
        index_data = json.load(fd)

    new_index = {
        "metadata": copy.copy(index_data["metadata"]),
        "weight_map": {
            key.replace(".ln1.", ".input_layernorm.").replace(
                ".ln2.", ".post_attention_layernorm."
            ): shard_name
            for key, shard_name in index_data["weight_map"].items()
        },
    }

    with open(os.path.join(out_path, index_name), "w", encoding="utf-8") as fd:
        json.dump(new_index, fd)
if __name__ == "__main__":
    # Run the conversion only when executed as a script, so that importing
    # this module does not touch the filesystem. Paths are hard-coded for
    # the author's workspace layout.
    process_model("/workspace/models/yayi2/", "/workspace/yayi2-30b-llama/")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment