@wfjsw
Created August 3, 2023 16:54
A conceptual safetensors converter for RVC
convert_to_pth.py (.safetensors → .pth):

import argparse
import json
import pathlib

import torch
from safetensors import safe_open

parser = argparse.ArgumentParser()
parser.add_argument("input", metavar="in", type=str, help="Input .safetensors model path")
parser.add_argument("-m", "--metadata", required=False, metavar="meta", type=str, default=None, help="Metadata .json path (defaults to the input path with .json)")
parser.add_argument("output", metavar="out", nargs="?", default=None, type=str, help="Output .pth model path")
cmd_opts = parser.parse_args()

safetensors_path = cmd_opts.input
metadata_path = cmd_opts.metadata if cmd_opts.metadata is not None else safetensors_path.replace('.safetensors', '.json')
pth_path = cmd_opts.output if cmd_opts.output is not None else safetensors_path.replace('.safetensors', '.pth')

state_dict = {}
metadata = {}
# safe_open takes a framework name ("pt" for PyTorch), not a file mode
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
    for k in f.keys():
        state_dict[k] = f.get_tensor(k)
    if pathlib.Path(metadata_path).exists():
        # Prefer the sidecar .json written by the .pth -> .safetensors script
        with open(metadata_path, 'r') as meta_file:
            metadata = json.load(meta_file)
    else:
        # Otherwise rebuild the fields from the string metadata in the safetensors header
        meta = f.metadata() or {}
        metadata['config'] = json.loads(meta['config'])
        metadata['sr'] = meta.get('sr')
        metadata['f0'] = 1 if meta.get('f0') == 'true' else 0
        metadata['version'] = meta.get('version', 'v1')
        metadata['info'] = meta.get('info')

metadata['weight'] = state_dict
torch.save(metadata, pth_path)
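The .safetensors → .pth script can read the model's fields from the safetensors header via `f.metadata()`. In the safetensors format, a file starts with an 8-byte little-endian unsigned integer giving the header length, followed by a UTF-8 JSON header whose optional `"__metadata__"` entry is a string-to-string map. The stdlib-only sketch below reads that entry without importing torch. `read_safetensors_metadata` and `write_tiny_safetensors` are hypothetical helpers for illustration, not part of the gist or the safetensors package, and the demo metadata values are made up.

```python
import json
import struct
import tempfile
from pathlib import Path


def read_safetensors_metadata(path):
    """Return the "__metadata__" dict from a safetensors header (hypothetical helper)."""
    with open(path, "rb") as fh:
        # First 8 bytes: header size as a little-endian unsigned 64-bit integer
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len).decode("utf-8"))
    return header.get("__metadata__", {})


def write_tiny_safetensors(path, metadata):
    """Write a minimal file holding one 2-element float32 tensor, for demonstration only."""
    data = struct.pack("<2f", 1.0, 2.0)
    header = {
        "__metadata__": metadata,
        "x": {"dtype": "F32", "shape": [2], "data_offsets": [0, len(data)]},
    }
    header_bytes = json.dumps(header).encode("utf-8")
    # Pad the header with spaces to an 8-byte boundary; JSON parsing ignores trailing spaces
    header_bytes += b" " * (-len(header_bytes) % 8)
    with open(path, "wb") as fh:
        fh.write(struct.pack("<Q", len(header_bytes)))
        fh.write(header_bytes)
        fh.write(data)


if __name__ == "__main__":
    demo = Path(tempfile.mkdtemp()) / "demo.safetensors"
    # Made-up metadata values
    write_tiny_safetensors(demo, {"sr": "40k", "f0": "true", "version": "v2"})
    print(read_safetensors_metadata(demo))
```

Reading the header directly like this can be a quick way to check whether a given .safetensors file actually carries the config, sr, f0, version and info fields before running the full conversion.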
.pth → .safetensors converter:

import argparse
import json

import torch
from safetensors.torch import save_file

parser = argparse.ArgumentParser()
parser.add_argument("input", metavar="in", type=str, help="Input .pth model path")
parser.add_argument("output", metavar="out", nargs="?", default=None, type=str, help="Output .safetensors model path")
parser.add_argument("-m", "--metadata", metavar="meta", required=False, default=None, type=str, help="Output .json metadata path")
cmd_opts = parser.parse_args()

pth_path = cmd_opts.input
safetensors_path = cmd_opts.output if cmd_opts.output is not None else pth_path.replace('.pth', '.safetensors')
metadata_path = cmd_opts.metadata if cmd_opts.metadata is not None else safetensors_path.replace('.safetensors', '.json')

# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model = torch.load(pth_path, map_location="cpu")
state_dict = model.get('weight')
config = model.get('config')

# safetensors header metadata must map str -> str, so every value is stringified
metadata_safetensors = {
    'config': json.dumps(config),
    'sr': str(model.get('sr')),
    'f0': 'true' if model.get('f0') else 'false',
    'version': str(model.get('version', 'v1')),
    'info': str(model.get('info') or ''),
}
# Everything except the weights is also written to a sidecar .json
metadata = {k: v for k, v in model.items() if k != 'weight'}

save_file(state_dict, safetensors_path, metadata_safetensors)
with open(metadata_path, 'w') as f:
    json.dump(metadata, f)
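The two scripts agree on an encoding for the non-weight fields: `config` is JSON-serialized, `f0` becomes the string "true" or "false", and the remaining fields are stored as strings because safetensors header metadata only holds strings. The stdlib-only sketch below expresses that mapping as a pair of pure functions so the round trip can be checked without torch. The names `to_header_metadata` and `from_header_metadata` are hypothetical, and the sample checkpoint fields are made up for illustration, not taken from a real RVC model.

```python
import json


def to_header_metadata(ckpt):
    """Encode RVC checkpoint fields (everything but 'weight') as str -> str metadata."""
    return {
        "config": json.dumps(ckpt.get("config")),
        "sr": str(ckpt.get("sr")),
        "f0": "true" if ckpt.get("f0") else "false",
        "version": str(ckpt.get("version", "v1")),
        "info": str(ckpt.get("info") or ""),
    }


def from_header_metadata(meta):
    """Decode header metadata back into the fields the .pth file is rebuilt from."""
    return {
        "config": json.loads(meta["config"]),
        "sr": meta.get("sr"),
        "f0": 1 if meta.get("f0") == "true" else 0,
        "version": meta.get("version", "v1"),
        "info": meta.get("info"),
    }


if __name__ == "__main__":
    # Made-up example values, not from a real model
    ckpt = {"config": [1025, 32, 192], "sr": "40k", "f0": 1, "version": "v2", "info": "demo"}
    header = to_header_metadata(ckpt)
    assert all(isinstance(v, str) for v in header.values())
    print(from_header_metadata(header) == ckpt)  # True
```

Keeping the encoding in one place like this would make it harder for the two scripts to drift apart, for example if one side started writing `f0` as "1" while the other still compared against "true".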
@HelenaWhite1128

python convert_to_pth.py input_model.safetensors -m metadata.json output_model.pth

Is this the right way to use it?
Even when I give the paths to the .safetensors and .json files, I get the following error:
usage: convert_to_pth.py [-h] [-m meta] in [out]
convert_to_pth.py: error: unrecognized arguments: output_model.pth
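This error comes from how argparse matches positional arguments: it fills `in` and `[out]` from the run of plain arguments before `-m`, so `out` is matched as empty and the trailing `output_model.pth` is left over as unrecognized (some newer Python releases may handle this case differently). Putting `-m metadata.json` before the two paths avoids the problem. Another option is `parse_intermixed_args()` (Python 3.7+), which parses the options first and then assigns the leftover arguments to the positionals in order. The sketch below rebuilds the script's parser in a hypothetical `build_parser()` function to show both approaches; it is not a change the gist itself makes.

```python
import argparse


def build_parser():
    # Same arguments as convert_to_pth.py
    parser = argparse.ArgumentParser(prog="convert_to_pth.py")
    parser.add_argument("input", metavar="in", type=str, help="Input .safetensors model path")
    parser.add_argument("-m", "--metadata", metavar="meta", type=str, default=None, help="Metadata .json path")
    parser.add_argument("output", metavar="out", nargs="?", default=None, type=str, help="Output .pth model path")
    return parser


if __name__ == "__main__":
    parser = build_parser()
    # The order from the error report: an option between the two positionals
    argv = ["input_model.safetensors", "-m", "metadata.json", "output_model.pth"]
    # parse_intermixed_args collects options first, then fills positionals in order
    opts = parser.parse_intermixed_args(argv)
    print(opts.input, opts.metadata, opts.output)
    # With the option first, plain parse_args handles the same arguments
    opts = parser.parse_args(["-m", "metadata.json", "input_model.safetensors", "output_model.pth"])
    print(opts.input, opts.metadata, opts.output)
```

Switching the script to `parse_intermixed_args()` would keep the option-first order working while also accepting the order the user first tried.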

@wfjsw
Author

wfjsw commented Dec 23, 2023

Put the .safetensors path after -m and the .json path.

@HelenaWhite1128

Could you write out the instructions from the start? I'm a beginner, so I'm not sure how.

@wfjsw
Author

wfjsw commented Dec 23, 2023

Try

python convert_to_pth.py -m metadata.json input_model.safetensors output_model.pth

or simply

python convert_to_pth.py -m metadata.json input_model.safetensors
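In the shorter form, `out` is omitted and the script derives it from the input path with `str.replace('.safetensors', '.pth')`. That replaces every occurrence of the substring, including one inside a directory name. The sketch below shows the same defaulting with pathlib's `with_suffix`, which only swaps the final suffix. `default_output_path` is a hypothetical helper, not part of the gist; it uses `PurePosixPath` so the printed paths are the same on every platform, whereas a real script would more naturally use `pathlib.Path`.

```python
from pathlib import PurePosixPath


def default_output_path(input_path, new_suffix=".pth"):
    """Swap only the last suffix, e.g. model.safetensors -> model.pth."""
    return str(PurePosixPath(input_path).with_suffix(new_suffix))


if __name__ == "__main__":
    tricky = "models.safetensors/voice.safetensors"
    print(tricky.replace(".safetensors", ".pth"))  # models.pth/voice.pth
    print(default_output_path(tricky))             # models.safetensors/voice.pth
```

The same helper could also derive the default .json metadata path by passing ".json" as `new_suffix`.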

@kaiyang0914

Can I ask how to use this code? I'm a complete beginner.

@kaiyang0914

I just want to convert a .safetensors file to a .pth file.
