Created
May 15, 2024 11:17
-
-
Save kohya-ss/1711f17fe77def811fcaf82877b0bec2 to your computer and use it in GitHub Desktop.
Dart V2を使ってプロンプトを作成
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script that uses Dart v2 to create a prompt file for gen_img.py in sd-scripts
import random | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Rating tag: <|rating:sfw|>, <|rating:general|>, <|rating:sensitive|>, nsfw, <|rating:questionable|>, <|rating:explicit|> | |
# Aspect ratio tag: <|aspect_ratio:ultra_wide|>, <|aspect_ratio:wide|>, <|aspect_ratio:square|>, <|aspect_ratio:tall|>, <|aspect_ratio:ultra_tall|> | |
# Length tag: <|length:very_short|>, <|length:short|>, <|length:medium|>, <|length:long|>, <|length:very_long|> | |
""" | |
prompt = ( | |
f"<|bos|>" | |
f"<copyright>{copyright_tags_here}</copyright>" | |
f"<character>{character_tags_here}</character>" | |
f"<|rating:general|><|aspect_ratio:tall|><|length:long|>" | |
f"<general>{general_tags_here}" | |
) | |
""" | |
def get_prompt(model, num_prompts, rating, aspect_ratio, length, first_tag):
    """Sample ``num_prompts`` tag lists from Dart v2 for one shared condition.

    All samples start from the same conditioning prompt (empty copyright and
    character sections, the given rating / aspect-ratio / length tags, and
    ``first_tag`` as the first general tag); they differ only through sampling.

    Args:
        model: Causal LM used for generation. Inputs are moved to
            ``model.device``, so the model may live on any device.
        num_prompts: Number of prompts to sample in one batch.
        rating: Rating tag, e.g. ``"<|rating:general|>"``.
        aspect_ratio: Aspect-ratio tag, e.g. ``"<|aspect_ratio:tall|>"``.
        length: Length tag, e.g. ``"<|length:long|>"``.
        first_tag: First general tag that seeds the generation, e.g. ``"1girl"``.

    Returns:
        A list with one string per generated sequence, each a ``", "``-joined
        list of the decoded non-blank tokens.

    Note:
        Relies on the module-level ``tokenizer``.
    """
    prompt = (
        "<|bos|>"
        "<copyright></copyright>"
        "<character></character>"
        f"{rating}{aspect_ratio}{length}"
        f"<general>{first_tag}"
    )

    # Every row is the same string, so the batch is built without any padding options.
    input_ids = tokenizer([prompt] * num_prompts, return_tensors="pt").input_ids
    # Follow the model's device instead of assuming "cuda".
    input_ids = input_ids.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
        )

    decoded = []
    for output in outputs:
        # batch_decode on a single 1-D sequence decodes each token id on its own;
        # each decoded token is treated as one tag and blank ones are dropped.
        tags = tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
        decoded.append(", ".join(tag for tag in tags if tag.strip() != ""))
    return decoded
# ---------------------------------------------------------------------------
# Tags that are enumerated exhaustively (every combination is generated)
# ---------------------------------------------------------------------------

# Output resolutions, each paired with the Dart v2 aspect-ratio tag used for it.
#   1024 x 1024   1:1    Square
#   1152 x  896   9:7
#    896 x 1152   7:9
#   1216 x  832   19:13
#    832 x 1216   13:19
#   1344 x  768   7:4    Horizontal
#    768 x 1344   4:7    Vertical
#   1536 x  640   12:5   Horizontal
#    640 x 1536   5:12   Vertical
_RESOLUTION_TABLE = (
    ((1024, 1024), "<|aspect_ratio:square|>"),
    ((1152, 896), "<|aspect_ratio:wide|>"),
    ((896, 1152), "<|aspect_ratio:tall|>"),
    ((1216, 832), "<|aspect_ratio:wide|>"),
    ((832, 1216), "<|aspect_ratio:tall|>"),
    ((1344, 768), "<|aspect_ratio:wide|>"),
    ((768, 1344), "<|aspect_ratio:tall|>"),
    ((1536, 640), "<|aspect_ratio:ultra_wide|>"),
    ((640, 1536), "<|aspect_ratio:ultra_tall|>"),
)
DIMENSIONS = [size for size, _ in _RESOLUTION_TABLE]
ASPECT_RATIO_TAGS = [tag for _, tag in _RESOLUTION_TABLE]

# Rating modifier (goes into the final prompt) paired with the Dart v2 rating tag.
# Currently disabled pairs: ("nsfw", "<|rating:questionable|>") and
# ("explicit, nsfw", "<|rating:explicit|>").
_RATING_TABLE = (
    ("safe", "<|rating:general|>"),
    ("sensitive", "<|rating:sensitive|>"),
)
RATING_MODIFIERS = [modifier for modifier, _ in _RATING_TABLE]
RATING_TAGS = [tag for _, tag in _RATING_TABLE]

# First general tag: "no humans", then 1..6+ of each subject kind
# (singular for a count of 1, plural otherwise).
FIRST_TAGS = ["no humans"] + [
    f"{count}{subject}" + ("" if count == "1" else "s")
    for subject in ("girl", "boy", "other")
    for count in ("1", "2", "3", "4", "5", "6+")
]

# ---------------------------------------------------------------------------
# Tags that are chosen at random
# ---------------------------------------------------------------------------

LENGTH_TAGS = [f"<|length:{name}|>" for name in ("very_short", "short", "medium", "long", "very_long")]

# Year modifiers and the years they stand for:
#   newest  2021 to 2024
#   recent  2018 to 2020
#   mid     2015 to 2017
#   early   2011 to 2014
#   oldest  2005 to 2010
# None means "no year modifier"; "early" and "oldest" are currently disabled.
YEAR_MODIFIERS = [None, "newest", "recent", "mid"]

# Randomly select 0 to 4 of these.
QUALITY_MODIFIERS_AND_AESTHETIC = ["masterpiece", "best quality", "very aesthetic", "absurdres"]

# Negative prompt. "(bad)" and "[abstract]" carry weighting brackets; an
# alternative variant writes them plainly as "bad" and "abstract".
NEGATIVE_PROMPT = ", ".join(
    (
        "nsfw",
        "lowres",
        "(bad)",
        "text",
        "error",
        "fewer",
        "extra",
        "missing",
        "worst quality",
        "jpeg artifacts",
        "low quality",
        "watermark",
        "unfinished",
        "displeasing",
        "oldest",
        "early",
        "chromatic aberration",
        "signature",
        "extra digits",
        "artistic error",
        "username",
        "scan",
        "[abstract]",
    )
)

NUM_PROMPTS_PER_VARIATION = 8
# A larger batch would be nice, but every prompt in a batch shares the same length tag.
BATCH_SIZE = 8
# The prompts generated per variation must split into whole batches.
assert (NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS)) % BATCH_SIZE == 0
# Load the Dart v2 tokenizer and model; get_prompt() reads the module-level `tokenizer`.
MODEL_NAME = "p1atdev/dart-v2-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")

# Make prompts: one line per image, covering every combination of rating,
# resolution, and first tag, with NUM_PROMPTS_PER_VARIATION prompts per year modifier.
PARTITION = "a"  # prefix for the output file

random.seed(42)  # fixes the random choices made below (length, quality modifiers, seeds)
prompts = []

for rating_modifier, rating_tag in zip(RATING_MODIFIERS, RATING_TAGS):
    # Do not put "nsfw" in the negative prompt when the rating itself asks for it.
    negative_prompt = NEGATIVE_PROMPT
    if "nsfw" in rating_modifier:
        negative_prompt = negative_prompt.replace("nsfw, ", "")
    rm_filename = rating_modifier.replace(", ", "_")  # "nsfw, explicit" -> "nsfw_explicit"

    for dimension, aspect_ratio_tag in zip(DIMENSIONS, ASPECT_RATIO_TAGS):
        width, height = dimension

        for first_tag in FIRST_TAGS:
            print(f"rating: {rating_modifier}, aspect ratio: {dimension}, first tag: {first_tag}")

            # Remove "+" ("6+girls") and spaces ("no humans") for use in file names.
            ft_filename = first_tag.replace("+", "").replace(" ", "")

            # year_modifier is not an input to Dart v2, so the prompts for all year
            # modifiers are generated together here to get a larger batch.
            dart_prompts = []
            for _ in range(0, NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS), BATCH_SIZE):
                # Every prompt in one batch shares the same length tag; it would be
                # nice to vary it within a batch, but that looks difficult.
                length = random.choice(LENGTH_TAGS)
                dart_prompts += get_prompt(model, BATCH_SIZE, rating_tag, aspect_ratio_tag, length, first_tag)

            num_prompts_for_each_year_modifier = NUM_PROMPTS_PER_VARIATION
            for j, year_modifier in enumerate(YEAR_MODIFIERS):
                ym = f", {year_modifier}" if year_modifier else ""
                ym_filename = year_modifier if year_modifier else "none"

                start = j * num_prompts_for_each_year_modifier
                for prompt in dart_prompts[start : start + num_prompts_for_each_year_modifier]:
                    # Escape `(` and `)`, like "star (symbol)" -> "star \(symbol\)".
                    prompt = prompt.replace("(", "\\(").replace(")", "\\)")

                    # Pick 0 to 4 quality/aesthetic modifiers to put in front of the prompt.
                    quality_modifiers = random.sample(QUALITY_MODIFIERS_AND_AESTHETIC, random.randint(0, 4))
                    quality_modifiers = ", ".join(quality_modifiers)
                    qm = f"{quality_modifiers}, " if quality_modifiers else ""

                    # The running index of the prompt makes each file name unique.
                    image_index = len(prompts)
                    image_filename = (
                        f"{PARTITION}{image_index:08d}_{rm_filename}_{width:04d}x{height:04d}_{ym_filename}_{ft_filename}.webp"
                    )

                    seed = random.randint(0, 2**32 - 1)
                    final_prompt = f"{qm}{prompt}, {rating_modifier}{ym} --n {negative_prompt} --w {width} --h {height} --ow {width} --oh {height} --d {seed} --f {image_filename}"
                    prompts.append(final_prompt)

# Output to a file. The encoding is explicit so the result does not depend on the
# platform's locale encoding.
with open(f"prompts_{PARTITION}.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(prompts))

print(f"Done. {len(prompts)} prompts are written to prompts_{PARTITION}.txt.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment