Created
August 3, 2023 16:54
-
-
Save wfjsw/0e0317755dc126ec31614222b05bac5d to your computer and use it in GitHub Desktop.
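A proof-of-concept safetensors converter for RVC models: the first script converts .safetensors to .pth, and the second converts .pth to .safetensors.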
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pathlib | |
import torch | |
from safetensors import safe_open | |
import argparse | |
import json | |
parser = argparse.ArgumentParser() | |
parser.add_argument("input", metavar="in", type=str, default="model.safetensors", help="Input .safetensors model path") | |
parser.add_argument("-m", "--metadata", required=False, metavar="meta", type=str, default="model.json", help="Metadata .json path") | |
parser.add_argument("output", metavar="out", nargs="?", default=None, type=str, help="Output .pth model path") | |
cmd_opts = parser.parse_args() | |
safetensors_path = cmd_opts.input | |
metadata_path = cmd_opts.metadata if cmd_opts.metadata is not None else safetensors_path.replace('.safetensors', '.json') | |
pth_path = cmd_opts.output if cmd_opts.output is not None else safetensors_path.replace('.safetensors', '.pth') | |
state_dict = {} | |
metadata = {} | |
with safe_open(safetensors_path, 'rb') as f: | |
for k in f.keys(): | |
state_dict[k] = f.get_tensor(k) | |
if not pathlib.Path(metadata_path).exists(): | |
with open(metadata_path, 'r') as f: | |
metadata = json.load(f) | |
else: | |
meta = f.metadata() | |
metadata['config'] = json.loads(meta.get('config')) | |
metadata['sr'] = meta.get('sr') | |
metadata['f0'] = 1 if meta.get('f0') == 'true' else 0 | |
metadata['version'] = meta.get('version', 'v1') | |
metadata['info'] = meta.get('info') | |
metadata['weight'] = state_dict | |
torch.save(metadata, pth_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from safetensors.torch import save_file | |
import argparse | |
import json | |
parser = argparse.ArgumentParser() | |
parser.add_argument("input", metavar="in", type=str, default="model.pth", help="Input .pth model path") | |
parser.add_argument("output", metavar="out", nargs="?", default=None, type=str, help="Output .safetensors model path") | |
parser.add_argument("-m", "--metadata", metavar="meta", required=False, default=None, type=str, help="Output .json metadata path") | |
cmd_opts = parser.parse_args() | |
pth_path = cmd_opts.input | |
safetensors_path = cmd_opts.output if cmd_opts.output is not None else pth_path.replace('.pth', '.safetensors') | |
metadata_path = cmd_opts.metadata if cmd_opts.metadata is not None else safetensors_path.replace('.safetensors', '.json') | |
model = torch.load(pth_path) | |
state_dict = model.get('weight') | |
config = model.get('config') | |
metadata_safetensors = { | |
'config': json.dumps(config), | |
'sr': model.get('sr'), | |
'f0': 'true' if model.get('f0') else 'false', | |
'version': model.get('version', 'v1'), | |
'info': model.get('info') | |
} | |
metadata = {k:v for k,v in model.items() if k != 'weight'} | |
save_file(state_dict, safetensors_path, metadata_safetensors) | |
with open(metadata_path, 'w') as f: | |
json.dump(metadata, f) |
Can you write the instructions from scratch? I'm not sure because I'm a beginner
Try
python convert_to_pth.py -m metadata.json input_model.safetensors output_model.pth
or simply
python convert_to_pth.py -m metadata.json input_model.safetensors
Can you explain how to use this code? I'm a complete (100%) beginner.
I just want to convert a .safetensors file to a .pth file.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
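Put the .safetensors path after the `-m metadata.json` option, e.g. `python convert_to_pth.py -m metadata.json input_model.safetensors`.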