Created
June 13, 2023 01:10
-
-
Save UniDyne/acd70e52e91472753cbd7a23611f44d9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
# Stable Diffusion Embedding Converter

This is a simple script that converts a Textual Inversion embedding file from
`.pt` format to `.safetensors` format. Nothing more, nothing less.

## Usage

```
$ python convert_embedding.py embeddings/myembed.pt embeddings/myembed.safetensors
Trained on v1-5-pruned-emaonly.
Trained for 6808 steps.
Dimensions of embedding: torch.Size([12, 768])
```
""" | |
import os | |
import argparse | |
from pathlib import Path | |
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union | |
import torch | |
from safetensors.torch import safe_open | |
from safetensors.torch import save_file as safe_save | |
def convert(path, outpath, overwrite=False):
    """Convert a Textual Inversion ``.pt`` embedding to ``.safetensors``.

    The embedding tensor is read from ``model["string_to_param"]["*"]`` and
    written to ``outpath`` under the key ``"emb_params"``. The optional
    training metadata (``sd_checkpoint_name`` and ``step``) and the tensor
    shape are printed along the way.

    Args:
        path: Path of the ``.pt`` embedding file to read.
        outpath: Path of the ``.safetensors`` file to write.
        overwrite: When False (the default), refuse to replace an existing
            ``outpath``.

    Raises:
        ValueError: If ``outpath`` already exists and ``overwrite`` is False,
            or if the loaded file has no ``string_to_param`` entry keyed ``"*"``.
    """
    # Refuse up front so no loading work is done for an output we won't write.
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    # Load onto the CPU so embeddings saved from a GPU session still load on
    # machines without CUDA.
    # SECURITY: torch.load unpickles the file (fully so on older torch
    # versions); only convert embedding files from trusted sources.
    model = torch.load(path, map_location="cpu")

    # Validate the expected layout instead of failing later with an opaque
    # AttributeError on None.
    params = model.get("string_to_param") if isinstance(model, dict) else None
    if params is None or "*" not in params:
        raise ValueError(
            f"{path} does not look like a Textual Inversion embedding: "
            "missing 'string_to_param' entry '*'"
        )

    # detach() drops any autograd tracking the stored tensor may carry;
    # contiguous() is needed because safetensors refuses non-contiguous tensors.
    model_tensors = params["*"].detach().contiguous()
    s_model = {"emb_params": model_tensors}

    # Print the checkpoint name, if recorded.
    checkpoint_name = model.get("sd_checkpoint_name")
    if checkpoint_name is not None:
        print(f"Trained on {checkpoint_name}.")
    else:
        print("Checkpoint name not found in the model.")

    # Print the number of training steps, if recorded.
    step = model.get("step")
    if step is not None:
        print(f"Trained for {step} steps.")
    else:
        print("Step not found in the model.")

    # Display the tensor shape.
    print(f"Dimensions of embedding: {model_tensors.shape}")
    print()

    safe_save(s_model, outpath)
def main(args_in: Optional[List[str]] = None) -> None:
    """Command-line entry point: convert one ``.pt`` embedding to ``.safetensors``.

    Args:
        args_in: Argument list to parse in place of the real command line;
            ``None`` (the default) makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Convert embedding to safetensor.")
    parser.add_argument("model", type=Path, help="Embedding .pt file input")
    parser.add_argument("output", type=Path, help="Embedding .safetensors file output")
    # convert() refuses to replace an existing output unless overwrite=True;
    # expose that choice on the command line.
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Replace the output file if it already exists",
    )
    args = parser.parse_args(args_in)
    try:
        convert(args.model, args.output, overwrite=args.overwrite)
    except ValueError as err:
        # Report expected failures (existing output, unrecognized input file)
        # as a usage error instead of a raw traceback.
        parser.error(str(err))
if __name__ == "__main__":
    # Runs only when executed as a script, not on import; passing no
    # argument list makes main() parse the real command line.
    main(args_in=None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment