import argparse
import torch
from diffusers import AutoencoderTiny
"""
Example - From the diffusers root directory:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
--input_path model.pth \
--dump_path taesd-diffusers
```
"""
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--input_path",
        default=None,
        type=str,
        required=True,
        help="Path to the input ckpt.",
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    input_state_dict = torch.load(args.input_path, map_location="cpu")

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}
    for k in input_state_dict:
        # Keys in the original checkpoint look like "<encoder|decoder>.<layer_id>.<rest>";
        # diffusers expects "<encoder|decoder>.layers.<layer_id>.<rest>".
        endec, layer_id, rest = k.split(".", 2)
        # Decoder layer indices are shifted down by one to line up with the
        # diffusers module numbering.
        layer_id = int(layer_id) - (1 if endec == "decoder" else 0)
        new_k = f"{endec}.layers.{layer_id}.{rest}"
        new_state_dict[new_k] = input_state_dict[k]

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)

    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=True)
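
    # Optional round-trip sanity check (a sketch, not part of the original script):
    # reload the serialized model from dump_path and verify that every tensor
    # matches what was just saved. Assumes the directory written by save_pretrained
    # above can be read back with AutoencoderTiny.from_pretrained.
    reloaded = AutoencoderTiny.from_pretrained(args.dump_path)
    reloaded_state_dict = reloaded.state_dict()
    for name, tensor in tiny_autoencoder.state_dict().items():
        assert torch.equal(tensor, reloaded_state_dict[name]), f"Mismatch in {name}"
    print("Round-trip check passed.")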