Skip to content

Instantly share code, notes, and snippets.

@saftle
Last active March 10, 2025 20:25
Show Gist options
  • Save saftle/0e527ac2842900deb190768929d5a558 to your computer and use it in GitHub Desktop.
SD-Mecha State Dict Config Creator
import torch
import yaml
import os
# State-dict key prefix -> component name, checked in order. Keys matching no
# prefix are grouped under "other". The order also fixes the order in which
# components are reported in the summary printout.
_COMPONENT_PREFIXES = (
    ("cond_stage_model", "clip_l"),
    ("first_stage_model", "vae"),
    ("model.diffusion_model", "unet"),
    ("model_ema", "ema"),
)


def _component_for_key(key):
    """Return the component name for a state-dict key, based on its prefix."""
    for prefix, component in _COMPONENT_PREFIXES:
        if key.startswith(prefix):
            return component
    return "other"


def _load_state_dict(checkpoint_path, weights_only):
    """Load a checkpoint on CPU and unwrap a nested ``"state_dict"`` entry if present."""
    # NOTE(security): with weights_only=False, torch.load unpickles arbitrary
    # Python objects and can execute code embedded in the file. Only use it on
    # checkpoints from trusted sources; pass weights_only=True for untrusted files
    # (some older checkpoints that pickle extra objects may then fail to load).
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=weights_only)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint


def create_complete_config(
    checkpoint_path,
    output_yaml_path,
    *,
    identifier="sd1-ldm_complete",
    weights_only=False,
):
    """Create a complete SD Mecha configuration that preserves all tensor keys.

    Loads the checkpoint and groups its keys into components by prefix:
    ``cond_stage_model`` -> clip_l, ``first_stage_model`` -> vae,
    ``model.diffusion_model`` -> unet, ``model_ema`` -> ema, everything else ->
    other. It then writes a YAML file with an ``identifier`` and a
    ``components`` mapping. Each component maps every key to its ``shape`` and
    ``dtype``. Empty components are omitted.

    Args:
        checkpoint_path: Path to a checkpoint loadable by ``torch.load``. If the
            loaded dict has a ``"state_dict"`` entry, that entry is used.
        output_yaml_path: Destination YAML path. Missing parent directories are
            created.
        identifier: Value written to the config's ``identifier`` field.
        weights_only: Forwarded to ``torch.load``. Defaults to False (the
            historical behaviour), which is unsafe for untrusted files.

    Returns:
        The configuration dict that was written to ``output_yaml_path``.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    state_dict = _load_state_dict(checkpoint_path, weights_only)
    print(f"Found {len(state_dict)} keys in checkpoint")

    # Insertion order: clip_l, vae, unet, ema, other.
    components = {component: {} for _, component in _COMPONENT_PREFIXES}
    components["other"] = {}

    skipped = []
    for key, value in state_dict.items():
        # Non-tensor entries have no shape/dtype; report them instead of
        # crashing with an AttributeError on `.shape`.
        if not isinstance(value, torch.Tensor):
            skipped.append(key)
            continue
        components[_component_for_key(key)][key] = {
            "shape": list(value.shape),
            # str(torch.float16) == "torch.float16" -> keep "float16"
            "dtype": str(value.dtype).split(".")[-1],
        }
    if skipped:
        print(f"Skipping {len(skipped)} non-tensor entries: {', '.join(map(str, skipped))}")

    config = {
        "identifier": identifier,
        # Only add non-empty components
        "components": {name: entries for name, entries in components.items() if entries},
    }

    total_keys = sum(len(entries) for entries in components.values())
    print(f"Writing configuration with {total_keys} keys...")
    output_dir = os.path.dirname(output_yaml_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # safe_dump: the config holds only dicts/lists/str/int, so the output is
    # plain YAML with no python-specific tags.
    with open(output_yaml_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(config, f, default_flow_style=False)

    print(f"Created configuration file: {output_yaml_path}")
    print(f"Keys per component: {', '.join(f'{k}: {len(v)}' for k, v in components.items() if v)}")
    return config
# Example usage: runs only when executed as a script, not on import. The paths
# default to the original placeholders and can be overridden on the command line:
#   python this_script.py /path/to/model.ckpt /path/to/out.yaml
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Create an SD-Mecha config listing every key of a checkpoint."
    )
    parser.add_argument(
        "checkpoint_path",
        nargs="?",
        default="/path/to/checkpoint/v1-5-pruned.ckpt",
    )
    parser.add_argument(
        "output_yaml_path",
        nargs="?",
        default="/path/to/config/sd1-ldm-complete.yaml",
    )
    args = parser.parse_args()
    create_complete_config(
        checkpoint_path=args.checkpoint_path,
        output_yaml_path=args.output_yaml_path,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment