Skip to content

Instantly share code, notes, and snippets.

@wfjsw
Created August 3, 2023 16:54
Show Gist options
  • Save wfjsw/0e0317755dc126ec31614222b05bac5d to your computer and use it in GitHub Desktop.
A conceptual converter between RVC .pth checkpoints and .safetensors files (both directions)
# Convert an RVC .safetensors model back into an RVC .pth checkpoint.
import pathlib
import torch
from safetensors import safe_open
import argparse
import json


def _metadata_from_header(meta):
    """Rebuild the non-weight RVC checkpoint fields from safetensors header metadata.

    Args:
        meta: The ``str -> str`` mapping returned by ``safe_open(...).metadata()``,
            or ``None`` when the file carries no header metadata.

    Returns:
        A dict with ``config``, ``sr``, ``f0``, ``version`` and ``info`` keys,
        decoded the way the pth -> safetensors converter encodes them
        (``config`` as JSON text, ``f0`` as ``'true'``/``'false'``).
        Missing entries become ``None`` (``version`` defaults to ``'v1'``).
    """
    meta = meta or {}
    config = meta.get('config')
    return {
        'config': json.loads(config) if config is not None else None,
        'sr': meta.get('sr'),
        'f0': 1 if meta.get('f0') == 'true' else 0,
        'version': meta.get('version', 'v1'),
        'info': meta.get('info'),
    }


def main():
    """Parse the command line, read the .safetensors model and write the .pth checkpoint."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input", metavar="in", type=str,
                        help="Input .safetensors model path")
    parser.add_argument("-m", "--metadata", required=False, metavar="meta", type=str, default=None,
                        help="Metadata .json path (default: input path with a .json suffix)")
    parser.add_argument("output", metavar="out", nargs="?", default=None, type=str,
                        help="Output .pth model path (default: input path with a .pth suffix)")
    cmd_opts = parser.parse_args()

    safetensors_path = pathlib.Path(cmd_opts.input)
    # with_suffix swaps only the final suffix. str.replace would also rewrite
    # directory names, and would return the input path unchanged (so the
    # input gets overwritten) when it does not contain '.safetensors'.
    metadata_path = (pathlib.Path(cmd_opts.metadata) if cmd_opts.metadata is not None
                     else safetensors_path.with_suffix('.json'))
    pth_path = (pathlib.Path(cmd_opts.output) if cmd_opts.output is not None
                else safetensors_path.with_suffix('.pth'))
    resolved = [p.resolve() for p in (safetensors_path, metadata_path, pth_path)]
    if len(set(resolved)) != len(resolved):
        parser.error("input, metadata and output paths must all be different")

    state_dict = {}
    # The second argument of safe_open is the tensor framework, not a file mode.
    with safe_open(str(safetensors_path), framework='pt') as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
        header_meta = f.metadata()

    # Prefer the .json sidecar written by the pth -> safetensors converter.
    # JSON keeps non-string values such as the integer f0 flag. Fall back to
    # the header metadata, which can only hold strings.
    if metadata_path.exists():
        with open(metadata_path, 'r', encoding='utf-8') as meta_file:
            metadata = json.load(meta_file)
    else:
        metadata = _metadata_from_header(header_meta)

    metadata['weight'] = state_dict
    torch.save(metadata, pth_path)


if __name__ == "__main__":
    main()
# Convert an RVC .pth checkpoint into a .safetensors model plus a .json sidecar.
import pathlib
import torch
from safetensors.torch import save_file
import argparse
import json


def _header_metadata(model):
    """Encode the non-weight RVC checkpoint fields as safetensors header metadata.

    safetensors only accepts a ``str -> str`` mapping. ``config`` is stored as
    JSON text and ``f0`` as ``'true'``/``'false'``. Other present values are
    stringified, and missing (``None``) values are left out rather than
    passed through, because ``save_file`` rejects them.

    Args:
        model: The loaded checkpoint dict.

    Returns:
        The ``str -> str`` metadata mapping to pass to ``save_file``.
    """
    fields = {
        'config': json.dumps(model.get('config')),
        'sr': model.get('sr'),
        'f0': 'true' if model.get('f0') else 'false',
        'version': model.get('version', 'v1'),
        'info': model.get('info'),
    }
    return {k: str(v) for k, v in fields.items() if v is not None}


def main():
    """Parse the command line, read the .pth checkpoint and write .safetensors + .json."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input", metavar="in", type=str,
                        help="Input .pth model path")
    parser.add_argument("output", metavar="out", nargs="?", default=None, type=str,
                        help="Output .safetensors model path (default: input path with a .safetensors suffix)")
    parser.add_argument("-m", "--metadata", metavar="meta", required=False, default=None, type=str,
                        help="Output .json metadata path (default: output path with a .json suffix)")
    cmd_opts = parser.parse_args()

    pth_path = pathlib.Path(cmd_opts.input)
    # with_suffix swaps only the final suffix. str.replace would return the
    # input path unchanged when it does not contain '.pth', and the input
    # would then be overwritten.
    safetensors_path = (pathlib.Path(cmd_opts.output) if cmd_opts.output is not None
                        else pth_path.with_suffix('.safetensors'))
    metadata_path = (pathlib.Path(cmd_opts.metadata) if cmd_opts.metadata is not None
                     else safetensors_path.with_suffix('.json'))
    resolved = [p.resolve() for p in (pth_path, safetensors_path, metadata_path)]
    if len(set(resolved)) != len(resolved):
        parser.error("input, output and metadata paths must all be different")

    # NOTE(security): torch.load unpickles the file, so only convert checkpoints
    # from trusted sources. map_location='cpu' lets GPU-saved checkpoints load
    # on CPU-only machines.
    model = torch.load(pth_path, map_location='cpu')
    state_dict = model.get('weight')
    if state_dict is None:
        parser.error(f"{pth_path} has no 'weight' entry")

    save_file(state_dict, str(safetensors_path), _header_metadata(model))

    # The sidecar keeps every non-weight field with its original JSON type.
    metadata = {k: v for k, v in model.items() if k != 'weight'}
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f)


if __name__ == "__main__":
    main()
@kaiyang0914
Copy link

I just want to convert a .safetensors file to a .pth file.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment