Last active: March 10, 2025 20:25
-
-
Save saftle/0e527ac2842900deb190768929d5a558 to your computer and use it in GitHub Desktop.
SD-Mecha State Dict Config Creator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import yaml | |
import os | |
# Checkpoint key prefix -> SD-Mecha component name. Prefixes are checked in
# order and the first match wins; keys matching none of them go to "other".
_COMPONENT_PREFIXES = (
    ("cond_stage_model", "clip_l"),
    ("first_stage_model", "vae"),
    ("model.diffusion_model", "unet"),
    ("model_ema", "ema"),
)


def _component_for_key(key):
    """Return the component name a state-dict key belongs to, by key prefix."""
    for prefix, component in _COMPONENT_PREFIXES:
        if key.startswith(prefix):
            return component
    return "other"


def _tensor_spec(tensor):
    """Return the {"shape", "dtype"} entry recorded in the config for one tensor."""
    return {
        "shape": list(tensor.shape),
        # str(torch.float16) is "torch.float16"; keep only the part after the dot.
        "dtype": str(tensor.dtype).split(".")[-1],
    }


def create_complete_config(
    checkpoint_path,
    output_yaml_path,
    *,
    identifier="sd1-ldm_complete",
    weights_only=False,
):
    """Create a complete SD Mecha configuration that preserves all keys.

    Loads the checkpoint onto the CPU and groups every tensor key into a
    component by key prefix. It then writes a YAML file that records the
    shape and dtype of each key.

    Args:
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        output_yaml_path: Path of the YAML file to write (overwritten if present).
        identifier: Value stored under the config's ``"identifier"`` key.
        weights_only: Forwarded to ``torch.load``. The default ``False``
            unpickles arbitrary objects, which can execute code embedded in
            the file. Only use the default on checkpoints you trust.

    Returns:
        The config dict that was written to ``output_yaml_path``.

    Raises:
        TypeError: If the loaded checkpoint (after unwrapping a top-level
            ``"state_dict"`` entry) is not a mapping.
    """
    from collections.abc import Mapping

    print(f"Loading checkpoint from {checkpoint_path}...")
    # SECURITY: with weights_only=False this is a full pickle load; never run
    # it on checkpoints from untrusted sources.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=weights_only)
    # When the checkpoint has a top-level "state_dict" entry, the weights live there.
    if isinstance(state_dict, Mapping) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    if not isinstance(state_dict, Mapping):
        raise TypeError(
            f"Expected the checkpoint to contain a mapping of tensors, got {type(state_dict).__name__}"
        )

    all_keys = list(state_dict.keys())
    print(f"Found {len(all_keys)} keys in checkpoint")

    # Build the component dict in a fixed order so the summary prints in a stable order.
    components = {name: [] for _, name in _COMPONENT_PREFIXES}
    components["other"] = []
    skipped = []
    for key in all_keys:
        # Non-tensor entries have no shape/dtype to record, so report and skip them.
        if not torch.is_tensor(state_dict[key]):
            skipped.append(key)
            continue
        components[_component_for_key(key)].append(key)
    if skipped:
        print(f"Skipping {len(skipped)} non-tensor entries: {', '.join(map(str, skipped))}")

    config = {
        "identifier": identifier,
        "components": {
            name: {key: _tensor_spec(state_dict[key]) for key in keys}
            for name, keys in components.items()
            if keys  # only non-empty components are emitted
        },
    }

    written = sum(len(keys) for keys in components.values())
    print(f"Writing configuration with {written} keys...")
    with open(output_yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False)
    print(f"Created configuration file: {output_yaml_path}")
    print(f"Keys per component: {', '.join(f'{k}: {len(v)}' for k, v in components.items() if v)}")
    return config
# Example usage: run this file as a script. Either pass the checkpoint path and
# the output YAML path as command-line arguments, or edit the placeholder
# defaults below. The guard keeps importing this module free of side effects.
if __name__ == "__main__":
    import sys

    _args = sys.argv[1:]
    create_complete_config(
        checkpoint_path=_args[0] if len(_args) > 0 else "/path/to/checkpoint/v1-5-pruned.ckpt",
        output_yaml_path=_args[1] if len(_args) > 1 else "/path/to/config/sd1-ldm-complete.yaml",
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.