Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created May 15, 2024 11:17
Show Gist options
  • Save kohya-ss/1711f17fe77def811fcaf82877b0bec2 to your computer and use it in GitHub Desktop.
Save kohya-ss/1711f17fe77def811fcaf82877b0bec2 to your computer and use it in GitHub Desktop.
Create prompts using Dart V2
# Script that uses Dart v2 to build a prompt file for sd-scripts' gen_img.py
import random
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Dart v2 condition tags used throughout this script:
# Rating tag: <|rating:sfw|>, <|rating:general|>, <|rating:sensitive|>, nsfw, <|rating:questionable|>, <|rating:explicit|>
# Aspect ratio tag: <|aspect_ratio:ultra_wide|>, <|aspect_ratio:wide|>, <|aspect_ratio:square|>, <|aspect_ratio:tall|>, <|aspect_ratio:ultra_tall|>
# Length tag: <|length:very_short|>, <|length:short|>, <|length:medium|>, <|length:long|>, <|length:very_long|>
# Reference layout of a full Dart v2 conditioning prompt. This bare string is
# documentation only — it is never executed; get_prompt() below builds the
# same structure with empty <copyright>/<character> sections.
"""
prompt = (
f"<|bos|>"
f"<copyright>{copyright_tags_here}</copyright>"
f"<character>{character_tags_here}</character>"
f"<|rating:general|><|aspect_ratio:tall|><|length:long|>"
f"<general>{general_tags_here}"
)
"""
def get_prompt(model, num_prompts, rating, aspect_ratio, length, first_tag):
    """Generate ``num_prompts`` comma-separated tag strings with a Dart v2 model.

    Builds one conditioning prompt from the given rating / aspect-ratio /
    length condition tags plus the leading general tag, replicates it
    ``num_prompts`` times, and samples all completions in a single batch.

    NOTE: uses the module-level ``tokenizer`` global for encoding/decoding.

    Args:
        model: Dart v2 causal LM supporting ``generate``.
        num_prompts: number of prompts to sample in one batch.
        rating: rating condition tag, e.g. ``"<|rating:general|>"``.
        aspect_ratio: aspect-ratio condition tag, e.g. ``"<|aspect_ratio:tall|>"``.
        length: length condition tag, e.g. ``"<|length:long|>"``.
        first_tag: first tag inside ``<general>``, e.g. ``"1girl"``.

    Returns:
        list[str] of length ``num_prompts``; each element is a ", "-joined
        tag list with special tokens and empty tags removed.
    """
    prompt = (
        "<|bos|>"
        "<copyright></copyright>"
        "<character></character>"
        f"{rating}{aspect_ratio}{length}"
        f"<general>{first_tag}"
    )
    # Every row in the batch shares the same conditioning prompt, so no
    # padding is needed when tokenizing.
    batch = [prompt] * num_prompts
    input_ids = tokenizer(batch, return_tensors="pt").input_ids
    # Follow the model's device rather than hard-coding "cuda".
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
        )
    decoded = []
    for row in outputs:
        # batch_decode over a 1-D id tensor yields one string per token,
        # so joining with ", " produces the comma-separated tag list.
        tags = tokenizer.batch_decode(row.cpu(), skip_special_tokens=True)
        decoded.append(", ".join(tag for tag in tags if tag.strip()))
    return decoded
# Tag categories that are enumerated exhaustively — every combination of
# (rating, dimension, first_tag) below yields its own group of prompts.
"""
1024 x 1024 1:1 Square
1152 x 896 9:7
896 x 1152 7:9
1216 x 832 19:13
832 x 1216 13:19
1344 x 768 7:4 Horizontal
768 x 1344 4:7 Vertical
1536 x 640 12:5 Horizontal
640 x 1536 5:12 Vertical
"""
# Output (width, height) pairs; entry i is paired with ASPECT_RATIO_TAGS[i].
DIMENSIONS = [(1024, 1024), (1152, 896), (896, 1152), (1216, 832), (832, 1216), (1344, 768), (768, 1344), (1536, 640), (640, 1536)]
# Dart v2 aspect-ratio condition tag for each DIMENSIONS entry (same order).
ASPECT_RATIO_TAGS = [
    "<|aspect_ratio:square|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:wide|>",
    "<|aspect_ratio:tall|>",
    "<|aspect_ratio:ultra_wide|>",
    "<|aspect_ratio:ultra_tall|>",
]
# Human-readable rating words appended to the positive prompt;
# index-aligned with RATING_TAGS (zipped together in the main loop).
RATING_MODIFIERS = ["safe", "sensitive"]  # , "nsfw", "explicit, nsfw"]
# Dart v2 rating condition tags, paired with RATING_MODIFIERS by position.
RATING_TAGS = ["<|rating:general|>", "<|rating:sensitive|>"]  # , "<|rating:questionable|>", "<|rating:explicit|>"]
# Subject-count tag used as the first <general> tag to steer generation.
FIRST_TAGS = [
    "no humans",
    "1girl",
    "2girls",
    "3girls",
    "4girls",
    "5girls",
    "6+girls",
    "1boy",
    "2boys",
    "3boys",
    "4boys",
    "5boys",
    "6+boys",
    "1other",
    "2others",
    "3others",
    "4others",
    "5others",
    "6+others",
]
# Tag categories chosen at random (per batch or per prompt) rather than enumerated.
# Dart v2 length condition tags; one is sampled per generation batch.
LENGTH_TAGS = ["<|length:very_short|>", "<|length:short|>", "<|length:medium|>", "<|length:long|>", "<|length:very_long|>"]
"""
newest 2021 to 2024
recent 2018 to 2020
mid 2015 to 2017
early 2011 to 2014
oldest 2005 to 2010
"""
# Year-style modifiers appended to the positive prompt (None = no modifier).
YEAR_MODIFIERS = [None, "newest", "recent", "mid"]  # , "early", "oldest"]
# Randomly select 0 to 4 of these quality/aesthetic modifiers per prompt.
QUALITY_MODIFIERS_AND_AESTHETIC = ["masterpiece", "best quality", "very aesthetic", "absurdres"]
# Negative prompt shared by all entries; the leading "nsfw, " is stripped
# in the main loop when the rating itself requests nsfw.
NEGATIVE_PROMPT = (
    "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
    + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"
)
# NEGATIVE_PROMPT = (
# "nsfw, lowres, bad, text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, "
# + "oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, abstract"
# )
# Number of prompts generated for each (rating, dimension, first_tag, year) combination.
NUM_PROMPTS_PER_VARIATION = 8
BATCH_SIZE = 8  # would like it larger, but every prompt in a batch shares the same length tag
# The per-variation total must split evenly into generation batches.
assert NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS) % BATCH_SIZE == 0
MODEL_NAME = "p1atdev/dart-v2-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")
# Make prompts: iterate every (rating, dimension, first_tag) combination,
# sample tag strings from Dart v2, then attach year/quality modifiers and the
# gen_img.py command-line options (--n/--w/--h/--ow/--oh/--d/--f).
PARTITION = "a"  # prefix for the output file
random.seed(42)
prompts = []
for rating_modifier, rating_tag in zip(RATING_MODIFIERS, RATING_TAGS):
    negative_prompt = NEGATIVE_PROMPT
    if "nsfw" in rating_modifier:
        # NEGATIVE_PROMPT starts with "nsfw, "; drop it when nsfw is requested
        negative_prompt = negative_prompt.replace("nsfw, ", "")
    for dimension, aspect_ratio_tag in zip(DIMENSIONS, ASPECT_RATIO_TAGS):
        for first_tag in FIRST_TAGS:
            print(f"rating: {rating_modifier}, aspect ratio: {dimension}, first tag: {first_tag}")
            # year_modifier is not a Dart v2 condition, so generate the prompts
            # for all year modifiers here in one pass to increase the effective
            # batch size, then split the results per year modifier below.
            dart_prompts = []
            for i in range(0, NUM_PROMPTS_PER_VARIATION * len(YEAR_MODIFIERS), BATCH_SIZE):
                # every prompt within one batch ends up with the same length
                # tag; varying it per prompt would be nicer but seems hard
                length = random.choice(LENGTH_TAGS)
                dart_prompts += get_prompt(model, BATCH_SIZE, rating_tag, aspect_ratio_tag, length, first_tag)
            num_prompts_for_each_year_modifier = NUM_PROMPTS_PER_VARIATION
            for j, year_modifier in enumerate(YEAR_MODIFIERS):
                # slice the j-th chunk of generated prompts for this year modifier
                for prompt in dart_prompts[j * num_prompts_for_each_year_modifier : (j + 1) * num_prompts_for_each_year_modifier]:
                    # escape `(` and `)`, like "star (symbol)" -> "star \(symbol\)"
                    prompt = prompt.replace("(", "\\(").replace(")", "\\)")
                    # select quality modifiers and aesthetic
                    quality_modifiers = random.sample(QUALITY_MODIFIERS_AND_AESTHETIC, random.randint(0, 4))
                    quality_modifiers = ", ".join(quality_modifiers)
                    # combine quality modifiers and aesthetic, year modifier and rating modifier
                    qm = f"{quality_modifiers}, " if quality_modifiers else ""
                    ym = f", {year_modifier}" if year_modifier else ""
                    # build final prompt; filename encodes index, rating, size, year and subject
                    image_index = len(prompts)
                    width, height = dimension
                    rm_filename = rating_modifier.replace(", ", "_")  # "nsfw, explicit" -> "nsfw_explicit"
                    ym_filename = year_modifier if year_modifier else "none"
                    ft_filename = first_tag.replace("+", "")  # remove "+" from "6+girls" etc.
                    ft_filename = ft_filename.replace(" ", "")  # remove space from "no humans" etc.
                    image_filename = (
                        f"{PARTITION}{image_index:08d}_{rm_filename}_{width:04d}x{height:04d}_{ym_filename}_{ft_filename}.webp"
                    )
                    seed = random.randint(0, 2**32 - 1)
                    final_prompt = f"{qm}{prompt}, {rating_modifier}{ym} --n {negative_prompt} --w {width} --h {height} --ow {width} --oh {height} --d {seed} --f {image_filename}"
                    prompts.append(final_prompt)
                # break  # test
            # break
        # break
# Persist the generated prompts to prompts_<PARTITION>.txt, one per line
# (no trailing newline after the last prompt).
output_path = f"prompts_{PARTITION}.txt"
content = "\n".join(prompts)
with open(output_path, "w") as out_file:
    out_file.write(content)
print(f"Done. {len(prompts)} prompts are written to prompts_{PARTITION}.txt.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment