Convert PyTorch weights to safetensors
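A small script that converts a Hugging Face-style PyTorch checkpoint, either a single pytorch_model.bin or a sharded checkpoint described by pytorch_model.bin.index.json, to the safetensors format. Shared tensors are deduplicated, weights are made contiguous and cast to fp16, the converted file is verified against the original, and for sharded models a matching model.safetensors.index.json is written.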
import argparse
import json
import os
from collections import defaultdict

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def shared_pointers(tensors):
    # Group tensor names by underlying storage pointer; any group with
    # more than one name is a set of shared (aliased) weights.
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for ptr, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing
def check_file_size(sf_filename: str, pt_filename: str):
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(
            f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """
        )
def convert_file(
    pt_filename: str,
    sf_filename: str,
):
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Keep only one name per set of shared tensors; safetensors
    # does not allow duplicated storage.
    shared = shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Make tensors contiguous and cast them to fp16.
    loaded = {k: v.contiguous().half() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})
    check_file_size(sf_filename, pt_filename)

    # Reload the converted file and verify tensors match exactly.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
def rename(pt_filename: str) -> str:
    filename, _ = os.path.splitext(pt_filename)
    local = f"{filename}.safetensors"
    local = local.replace("pytorch_model", "model")
    return local
def convert_multi(folder: str, delprv: bool):
    index_path = os.path.join(folder, "pytorch_model.bin.index.json")
    with open(index_path, "r") as f:
        data = json.load(f)

    filenames = set(data["weight_map"].values())
    local_filenames = []
    for filename in tqdm(filenames):
        pt_filename = os.path.join(folder, filename)
        # rename() is applied to the bare shard name so that the folder
        # is not joined into the output path twice.
        sf_filename = os.path.join(folder, rename(filename))
        convert_file(pt_filename, sf_filename)
        local_filenames.append(sf_filename)
        if delprv:
            os.remove(pt_filename)

    index = os.path.join(folder, "model.safetensors.index.json")
    with open(index, "w") as f:
        newdata = dict(data)
        newdata["weight_map"] = {k: rename(v) for k, v in data["weight_map"].items()}
        json.dump(newdata, f, indent=4)
    local_filenames.append(index)
    if delprv:
        os.remove(index_path)
def convert_single(folder: str, delprv: bool):
    pt_filename = os.path.join(folder, "pytorch_model.bin")
    sf_filename = os.path.join(folder, "model.safetensors")
    convert_file(pt_filename, sf_filename)
    if delprv:
        os.remove(pt_filename)
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", required=True, type=str, help="Path to the model dir")
# argparse's type=bool treats any non-empty string (including "False") as True,
# so a store_true flag is used instead.
parser.add_argument("-d", "--delete", action="store_true", help="Delete PyTorch files after conversion")
args = parser.parse_args()
if os.path.exists(os.path.join(args.model, "pytorch_model.bin")):
    convert_single(args.model, args.delete)
else:
    convert_multi(args.model, args.delete)
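A minimal usage sketch, assuming the script is saved as convert_to_safetensors.py (the filename is an assumption, not part of the gist):

# Convert in place, keeping the original .bin files
python convert_to_safetensors.py --model /path/to/model-dir

# Convert and delete the original PyTorch files afterwards
python convert_to_safetensors.py --model /path/to/model-dir --delete

Note that the weights are cast to fp16 during conversion, so the output will be roughly half the size of an fp32 checkpoint.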