Created
November 3, 2022 04:59
-
-
Save Mikubill/e9da6f3fbf3b89916aa341912a3e71ed to your computer and use it in GitHub Desktop.
Extract tags from image using deepdanbooru (and save to txt)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
import re | |
import tempfile | |
import argparse | |
import glob | |
import zipfile | |
import deepdanbooru as dd | |
import tensorflow as tf | |
import numpy as np | |
from basicsr.utils.download_util import load_file_from_url | |
from PIL import Image | |
from tqdm import tqdm | |
re_special = re.compile(r"([\\()])") | |
def get_deepbooru_tags_model(): | |
model_path = os.path.abspath(os.path.join(tempfile.gettempdir(), "deepbooru")) | |
if not os.path.exists(os.path.join(model_path, "project.json")): | |
# there is no point importing these every time | |
load_file_from_url( | |
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", | |
model_path, | |
) | |
with zipfile.ZipFile( | |
os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r" | |
) as zip_ref: | |
zip_ref.extractall(model_path) | |
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) | |
tags = dd.project.load_tags_from_project(model_path) | |
model = dd.project.load_model_from_project(model_path, compile_model=False) | |
return model, tags | |
def get_deepbooru_tags_from_model( | |
model, | |
tags, | |
pil_image, | |
threshold, | |
alpha_sort=False, | |
use_spaces=True, | |
use_escape=True, | |
include_ranks=False, | |
): | |
width = model.input_shape[2] | |
height = model.input_shape[1] | |
image = np.array(pil_image) | |
image = tf.image.resize( | |
image, | |
size=(height, width), | |
method=tf.image.ResizeMethod.AREA, | |
preserve_aspect_ratio=True, | |
) | |
image = image.numpy() # EagerTensor to np.array | |
image = dd.image.transform_and_pad_image(image, width, height) | |
image = image / 255.0 | |
image_shape = image.shape | |
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) | |
y = model.predict(image)[0] | |
result_dict = {} | |
for i, tag in enumerate(tags): | |
result_dict[tag] = y[i] | |
unsorted_tags_in_theshold = [] | |
result_tags_print = [] | |
for tag in tags: | |
if result_dict[tag] >= threshold: | |
if tag.startswith("rating:"): | |
continue | |
unsorted_tags_in_theshold.append((result_dict[tag], tag)) | |
result_tags_print.append(f"{result_dict[tag]} {tag}") | |
# sort tags | |
result_tags_out = [] | |
sort_ndx = 0 | |
if alpha_sort: | |
sort_ndx = 1 | |
# sort by reverse by likelihood and normal for alpha, and format tag text as requested | |
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) | |
for weight, tag in unsorted_tags_in_theshold: | |
tag_outformat = tag | |
if use_spaces: | |
tag_outformat = tag_outformat.replace("_", " ") | |
if use_escape: | |
tag_outformat = re.sub(re_special, r"\\\1", tag_outformat) | |
if include_ranks: | |
tag_outformat = f"({tag_outformat}:{weight:.3f})" | |
result_tags_out.append(tag_outformat) | |
print("\n".join(sorted(result_tags_print, reverse=True))) | |
return ", ".join(result_tags_out) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--path", type=str, default=".") | |
parser.add_argument("--threshold", type=int, default=0.75) | |
parser.add_argument("--alpha_sort", type=bool, default=False) | |
parser.add_argument("--use_spaces", type=bool, default=True) | |
parser.add_argument("--use_escape", type=bool, default=True) | |
parser.add_argument("--include_ranks", type=bool, default=False) | |
args = parser.parse_args() | |
types = ('*.jpg', '*.png', '*.jpeg', '*.gif', '*.webp', '*.bmp') # the tuple of file types | |
files_grabbed = [] | |
for files in types: | |
files_grabbed.extend(glob.glob(files)) | |
model, tags = get_deepbooru_tags_model() | |
for image_path in tqdm(files_grabbed, desc="Processing"): | |
image = Image.open(image_path).convert("RGB") | |
prompt = get_deepbooru_tags_from_model( | |
model, | |
tags, | |
image, | |
args.threshold, | |
alpha_sort=args.alpha_sort, | |
use_spaces=args.use_spaces, | |
use_escape=args.use_escape, | |
include_ranks=args.include_ranks, | |
) | |
image_name = os.path.splitext(os.path.basename(image_path))[0] | |
txt_filename = os.path.join(args.path, f"{image_name}.txt") | |
print(f"writing {txt_filename}: {prompt}") | |
with open(txt_filename, 'w') as f: | |
f.write(prompt) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment