Created
February 22, 2023 04:22
-
-
Save watzon/1e15172550a537532b0b52e12703aada to your computer and use it in GitHub Desktop.
Stable Diffusion WebUI to InvokeAI prompt conversion script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import argparse | |
# Tokenizer for WebUI attention syntax. It is compiled in verbose mode (re.X),
# so the newlines inside the pattern are ignored. Alternatives are tried in
# order; the first one that matches at the current position wins:
#   \( \) \[ \] \\    escaped literal characters
#   \                 a lone backslash
#   ( [               opening brackets
#   :<number>)        explicit weight closing a round bracket (group 1 is the
#                     number). Only well-formed decimals such as "1", "1.",
#                     "1.5", ".5" (optionally signed) are accepted, so
#                     float(group(1)) cannot raise; text like ":.)" or
#                     ":1.2.3)" falls through to the plain ":" alternative.
#   ) ]               closing brackets
#   [^\\()\[\]:]+     a run of plain text
#   :                 a stray colon
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?(?:\d+\.?\d*|\.\d+))\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# Splits text on the standalone keyword BREAK (word boundaries on both sides),
# swallowing any whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      BREAK - the standalone keyword becomes the marker pair ['BREAK', -1];
              -1 is a sentinel, not a weight: it is never scaled by enclosing
              brackets and never merged with neighbouring text
      anything else - just text

    Brackets still open at the end of the text apply their multiplier to
    everything after them. Adjacent pairs with identical weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('(a BREAK b)')
    [['a', 1.1], ['BREAK', -1], ['b', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')  # doctest: +NORMALIZE_WHITESPACE
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []
    square_brackets = []
    # Indices in `res` of BREAK marker entries. Entries are only appended
    # until the final merge, so these indices stay valid until then.
    break_positions = set()

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry from start_position to the end, skipping BREAK
        # markers so their -1 sentinel stays recognisable to callers.
        for p in range(start_position, len(res)):
            if p not in break_positions:
                res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)  # separate name so the `text` argument isn't shadowed
        weight = m.group(1)

        if token.startswith('\\'):
            # Escaped character: keep whatever follows the backslash literally.
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (including unmatched closers and stray ':'); every
            # BREAK keyword inside it becomes a separate marker entry.
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    break_positions.add(len(res))
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Brackets left open at the end still apply to everything after them.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # Merge runs of identical weights in one pass, never merging into or out
    # of a BREAK marker.
    merged = []
    prev_is_break = False
    for pos, entry in enumerate(res):
        is_break = pos in break_positions
        if merged and not is_break and not prev_is_break and merged[-1][1] == entry[1]:
            merged[-1][0] += entry[0]
        else:
            merged.append(entry)
        prev_is_break = is_break
    return merged
def prompt_attention_to_invoke_prompt(attention):
    """
    Converts a list of pairs: text and its associated weight to an InvokeAI prompt string.

    The pieces are concatenated with no separator, so any spacing must already
    be part of the text. Each weight is rounded to 2 decimal places first; a
    weight of exactly 1.0 is emitted as plain text.

    Example:
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.4]])
    'normal text (something important)1.4'

    Weights of 1.1 and 1.2 are represented with `+` characters, where each `+`
    represents a weight increase of 0.1. Any other weight above 1.0, including
    one with extra decimal places (eg. 1.22 or 1.13), keeps its numeric value.
    Example:
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.2]])
    'normal text (something important)++'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.22]])
    'normal text (something important)1.22'

    Likewise, weights of 0.9 and 0.8 are represented with `-` characters, where
    each `-` represents a weight decrease of 0.1. Any other weight below 1.0
    keeps its numeric value.
    Example:
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 0.9]])
    'normal text (something important)-'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 0.92]])
    'normal text (something important)0.92'

    Entries with empty text are dropped. BREAK markers as produced by
    parse_prompt_attention (text 'BREAK' with a negative weight) are passed
    through as the plain keyword ' BREAK '; no InvokeAI translation is attempted.

    NOTE(review): the `+`/`-` mapping treats each sign as a 0.1 step, as this
    docstring has always stated. If InvokeAI applies the signs multiplicatively
    (1.1**n / 0.9**n), then `++`/`--` only approximate 1.2/0.8. Confirm this.
    """
    tokens = []
    for text, weight in attention:
        if text == "BREAK" and weight < 0:
            # BREAK marker: the spaces around the keyword were consumed by the
            # parser's split, so put them back to keep neighbouring words apart.
            tokens.append(" BREAK ")
            continue
        if not text:
            # Nothing to weight; avoid emitting empty groups such as "()+".
            continue

        weight = round(weight, 2)
        if weight == 1.0:
            tokens.append(text)
            continue

        # Use round() rather than int(): float error makes (1.2 - 1.0) * 10
        # equal 1.999..., which int() would truncate to 1.
        tenths = round(weight * 10)
        steps = tenths - 10
        # Sign runs can only express whole tenths; 1.22, 0.91, ... stay numeric.
        is_whole_tenths = abs(weight * 10 - tenths) < 1e-9

        if is_whole_tenths and 1 <= steps <= 2:
            tokens.append(f"({text})" + "+" * steps)
        elif is_whole_tenths and -2 <= steps <= -1:
            tokens.append(f"({text})" + "-" * -steps)
        else:
            tokens.append(f"({text}){weight}")
    return "".join(tokens)
# Command-line interface: the only option is -d/--debug, which also prints the
# intermediate [text, weight] list produced by parse_prompt_attention.
parser = argparse.ArgumentParser(description="Parse prompt attention")
parser.add_argument("-d", "--debug", action="store_true", help="enable debug mode")

if __name__ == "__main__":
    args = parser.parse_args()
    # Convert one prompt per line until end of input (Ctrl-D, or the end of a
    # piped file) or Ctrl-C. input() signals these with EOFError and
    # KeyboardInterrupt, so exit quietly instead of with a traceback.
    while True:
        try:
            text = input("Enter prompt: ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        print()
        attention = parse_prompt_attention(text)
        if args.debug:
            print(attention)
            print()
        print(prompt_attention_to_invoke_prompt(attention))
        print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment