Test your checkpoints by creating images for selected checkpoints.
A custom script for AUTOMATIC1111/stable-diffusion-webui.
Download this script and put it into your Stable Diffusion scripts folder.
""" | |
Copyright 2023 -- Zhang Xiaoke [email protected] | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
https://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
""" | |
test your checkpoints | |
creating images for selected checkpoints | |
a custom script for AUTOMATIC1111/stable-diffusion-webui | |
Script: https://gist.github.com/mcxiaoke/42faaf3baa77870f31df386d150c710c | |
Version: 1.0.0 | |
Created at 2023.12.12 | |
Created by https://github.com/mcxiaoke | |
""" | |
import os | |
import sys | |
import pathlib | |
from datetime import datetime | |
from modules.processing import Processed, process_images, images | |
from modules import sd_models, processing, shared | |
import modules.scripts as scripts | |
import gradio as gr | |
from collections import namedtuple | |
from random import randint | |
import itertools | |
import operator | |
import functools | |
import random | |
# Name returned by Script.title() and used as the prefix of every
# message this script prints or raises.
UI_TITLE = "Test Checkpoints"
# File extensions passed to shared.walk_files when looking for checkpoints.
MODEL_EXT = [".ckpt", ".safetensors"]
# Checkpoint root folder, taken from the webui's sd_models module.
model_path = sd_models.model_path
# https://realpython.com/python-flatten-list/ | |
def flatten_concatenation(matrix):
    """Flatten one level of nesting and return the items as a new list.

    Every element of ``matrix`` must itself be iterable; its items are
    added to the result in order. Deeper nesting levels are left untouched.
    """
    return [item for row in matrix for item in row]
def flatten_extend(matrix):
    """Return a new list holding the items of every row of ``matrix``.

    Only a single level of nesting is removed; each row must be iterable.
    """
    return list(itertools.chain.from_iterable(matrix))
def flatten_reduce_iconcat(matrix):
    """Flatten one level of nesting by folding the rows into a fresh list."""
    merged = []
    for row in matrix:
        # iconcat performs ``merged += row`` and hands back the result.
        merged = operator.iconcat(merged, row)
    return merged
def flatten_list_iter(nested_list):
    """Lazily yield the leaf items of an arbitrarily nested list.

    Only ``list`` instances are descended into; any other value (tuples and
    strings included) is yielded unchanged.

    Args:
        nested_list: An iterable whose items may themselves be lists.

    Yields:
        The non-list items, depth first, in their original order.
    """
    for item in nested_list:
        if isinstance(item, list):
            # Recurse into this generator itself; the previous call target
            # ``flatten_list`` is not defined anywhere and raised NameError.
            yield from flatten_list_iter(item)
        else:
            yield item
def get_files(paths):
    """Collect the path of every file found below the given folder(s).

    ``paths`` is either one folder or a list of folders. Each folder is
    walked recursively (symlinked directories are followed) and every file
    is reported as ``os.path.join(<containing dir>, <file name>)``.
    """
    roots = paths if isinstance(paths, list) else [paths]
    found = []
    for root in roots:
        for folder, _subdirs, names in os.walk(root, followlinks=True):
            found.extend(os.path.join(folder, name) for name in names)
    return found
def get_all_files(root_dir):
    """Recursively list every file below ``root_dir``.

    Args:
        root_dir: Folder to scan (``str`` or ``pathlib.Path``).

    Returns:
        A list of file paths relative to ``root_dir``. Files in nested
        folders keep their sub-directory prefix (for example ``sub/b.txt``).
        Previously the prefix was dropped for nested files, because each
        recursive call computed the path relative to the sub-directory
        instead of the original root.
    """

    def _collect(directory, prefix):
        # Gather files of ``directory``; ``prefix`` is its path relative to root_dir.
        found = []
        for entry in pathlib.Path(directory).iterdir():
            relative = os.path.join(prefix, entry.name)
            if entry.is_file():
                found.append(relative)
            elif entry.is_dir():
                # Scan the link target for symlinked folders, but keep
                # reporting the files under the link's own name.
                target = entry.resolve() if entry.is_symlink() else entry
                found.extend(_collect(target, relative))
        return found

    return _collect(root_dir, "")
def get_subdirectories_w(root_dir):
    """Yield every sub-directory below ``root_dir``, at any depth.

    The tree is traversed with ``os.walk`` (symlinked directories are
    followed) and each directory is reported relative to ``root_dir``.
    """
    for parent, child_names, _files in os.walk(root_dir, followlinks=True):
        for child in child_names:
            yield os.path.relpath(os.path.join(parent, child), start=root_dir)
def get_subdirectories_p(root_dir):
    """Yield the immediate sub-directories of ``root_dir`` as relative paths.

    Each directory that is reported is also printed (absolute path,
    prefixed with ``++``).

    Args:
        root_dir: Folder to inspect (``str`` or ``pathlib.Path``).

    Yields:
        Paths of directory entries relative to ``root_dir``. Entries that
        are symlinks to directories are reported like ordinary directories,
        because ``Path.is_dir()`` follows symlinks.
    """
    for entry in pathlib.Path(root_dir).iterdir():
        if entry.is_dir():
            print("++", os.path.abspath(entry))
            yield os.path.relpath(entry, start=root_dir)
        elif entry.is_symlink():
            # os.path.realpath returns a plain string; wrap it in a Path so
            # the is_dir() call below works. Before this, any symlink that
            # was not a directory (link to a file, dangling link) raised
            # AttributeError.
            target = pathlib.Path(os.path.realpath(entry))
            if target.is_dir():
                print("++", os.path.abspath(target))
                yield from get_subdirectories_p(target)
def get_model_name(filename):
    """Derive a flat model name from a checkpoint file path.

    If the absolute path starts with ``model_path``, that text is removed
    from it; otherwise only the base name of ``filename`` is used. One
    leading path separator is dropped, the remaining ``/`` and ``\\``
    characters become underscores, and the file extension is removed.
    """
    full_path = os.path.abspath(filename)
    if not full_path.startswith(model_path):
        relative = os.path.basename(filename)
    else:
        relative = full_path.replace(model_path, "")
    if relative.startswith(("\\", "/")):
        relative = relative[1:]
    flattened = relative.replace("/", "_").replace("\\", "_")
    stem, _extension = os.path.splitext(flattened)
    return stem
def get_model_filename(filename):
    """Return the base name of ``filename`` without its final extension."""
    _folder, tail = os.path.split(filename)
    stem, _extension = os.path.splitext(tail)
    return stem
def get_model_list(selected_dirs):
    """Resolve the selected checkpoint folders to checkpoint objects.

    Args:
        selected_dirs: Folder names relative to ``model_path``, as chosen
            in the UI, or ``None``. The special entry ``"All"`` selects the
            whole ``model_path`` tree.

    Returns:
        The objects returned by ``sd_models.get_closet_checkpoint_match``
        for every checkpoint file found, sorted by their ``name``
        attribute. An empty list when nothing is selected.
    """
    if not selected_dirs:
        return []
    # "All" may sit anywhere in the selection, not only first, so test for
    # membership instead of looking at index 0 only.
    if "All" in selected_dirs:
        selected_dirs = ["."]
    # Collect absolute file paths in a dict to drop duplicates while keeping
    # discovery order; the same file can be reached through more than one
    # selected folder (presumably walk_files descends into sub-folders --
    # verify against the webui helper).
    unique_paths = {}
    for folder in selected_dirs:
        for found in shared.walk_files(os.path.join(model_path, folder), MODEL_EXT):
            unique_paths.setdefault(os.path.abspath(found), None)
    models = []
    for path in unique_paths:
        model = sd_models.get_closet_checkpoint_match(get_model_name(path))
        # Files that cannot be matched to a checkpoint are skipped.
        if model is not None:
            models.append(model)
    return sorted(models, key=lambda x: x.name)
class Script(scripts.Script):
    """Custom script that renders the current prompt with many checkpoints.

    The user picks one or more checkpoint folders; for every checkpoint
    found there, ``batch_size`` images are generated with one
    ``process_images`` call per image.
    """

    def title(self):
        """Return the name of this script."""
        return UI_TITLE

    def ui(self, is_img2img):
        """Build the Gradio controls for this script.

        Returns:
            ``[selected_dirs, batch_size, random_seed]`` -- the components in
            the same order as the extra parameters of ``run``.
        """
        # "All" comes first; real folders use forward slashes and carry a
        # trailing slash.
        model_dirs = ["All"]
        model_dirs.extend(
            subdir.replace("\\", "/") + "/"
            for subdir in get_subdirectories_w(model_path)
        )
        selected_dirs = gr.CheckboxGroup(
            choices=model_dirs, label="Choose checkpoint folders"
        )
        batch_size = gr.Number(value=1, label="Batch size for every checkpoint")
        random_seed = gr.Checkbox(label="All Random Seed", info="Use random seed for every test?")
        return [selected_dirs, batch_size, random_seed]

    def run(self, p, selected_dirs, batch_size, random_seed):
        """Generate images for every checkpoint in the selected folders.

        Args:
            p: The processing object handed in by the webui; its prompt and
                seed are read, and its seed, checkpoint override and
                ``do_not_save_grid`` flag are set for every image.
            selected_dirs: Folder names chosen in the UI.
            batch_size: Number of images to create per checkpoint.
            random_seed: When truthy, every image gets its own random seed;
                otherwise each checkpoint uses consecutive seeds starting
                from ``p.seed`` (a random start when ``p.seed`` is -1).

        Returns:
            The result of the first ``process_images`` call, with the first
            image and infotext of every later call inserted at the front.
            An empty ``Processed`` when the job is interrupted before the
            first image.

        Raises:
            ValueError: If the prompt is empty, no folder is selected, no
                checkpoint is found, or the batch size is below 1.
        """
        if isinstance(p.prompt, list):
            # An empty prompt list counts as an empty prompt.
            positive_prompt = p.prompt[0] if p.prompt else ""
        else:
            positive_prompt = p.prompt
        if not positive_prompt:
            raise ValueError(f"{UI_TITLE}: Empty positive prompt!")
        if not selected_dirs:
            raise ValueError(f"{UI_TITLE}: No checkpoint folders selected!")
        models = get_model_list(selected_dirs)
        if not models:
            raise ValueError(f"{UI_TITLE}: No checkpoints found!")

        initial_seed = p.seed
        if initial_seed == -1:
            initial_seed = random.randrange(4294967294)

        # Reject an empty or non-positive batch size with a clear message;
        # otherwise the loop below never runs and there is nothing to return.
        try:
            b_size = int(batch_size)
        except (TypeError, ValueError):
            b_size = 0
        if b_size < 1:
            raise ValueError(f"{UI_TITLE}: Batch size must be at least 1!")

        model_names = [m.name for m in models]
        all_model_names = []
        all_seeds = []
        for m in model_names:
            for i in range(b_size):
                all_model_names.append(m)
                if random_seed:
                    all_seeds.append(random.randrange(4294967294))
                else:
                    all_seeds.append(initial_seed + i)
        total_count = len(all_model_names)
        print(
            f"{UI_TITLE}: total {len(model_names)} checkpoints in folder: {selected_dirs}."
        )
        print(f"{UI_TITLE}: create {total_count} images in {len(model_names)} batches.")
        if shared.state.job_count == -1:
            shared.state.job_count = total_count

        processed = None
        for i, (model_name, seed) in enumerate(zip(all_model_names, all_seeds)):
            if shared.state.interrupted:
                break
            shared.state.job = f"{UI_TITLE} job {i+1} out of {total_count}"
            p.override_settings["sd_model_checkpoint"] = model_name
            p.seed = seed
            p.do_not_save_grid = True
            print(
                f"{UI_TITLE}: processing model:{model_name} seed:{seed} ({i+1}/{total_count})"
            )
            result = process_images(p)
            if processed is None:
                processed = result
            elif result.images:
                # Newest result goes to the front, as before. A run that
                # produced no image is skipped instead of raising IndexError.
                processed.images.insert(0, result.images[0])
                if result.infotexts:
                    processed.infotexts.insert(0, result.infotexts[0])

        if processed is None:
            # Interrupted before the first image: hand back an empty result
            # instead of failing on an unassigned variable.
            return Processed(p, [])
        return processed