Skip to content

Instantly share code, notes, and snippets.

@watzon
Created February 22, 2023 04:22
Show Gist options
  • Save watzon/1e15172550a537532b0b52e12703aada to your computer and use it in GitHub Desktop.
Save watzon/1e15172550a537532b0b52e12703aada to your computer and use it in GitHub Desktop.
Stable Diffusion WebUI to InvokeAI prompt conversion script
import re
import argparse
# Tokenizer for WebUI attention syntax (verbose regex; one alternative per
# line, tried in order): escaped brackets and backslashes first, then a lone
# backslash, opening brackets, an explicit ":weight)" closer whose weight is
# captured in group 1, plain closers, runs of ordinary text, and finally a
# bare ':' that is not part of a weight.
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# Matches the BREAK keyword together with any whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    r"""
    Parse a WebUI prompt into a list of [text, weight] pairs.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      BREAK - emitted as a separate ["BREAK", -1] marker entry
      anything else - just text

    A backslash not followed by a bracket or another backslash is dropped.
    Brackets left open at the end of the prompt apply up to the end. An
    explicit weight that is not a valid number (e.g. "(abc:1.2.3)") is kept
    as plain text instead of raising. BREAK markers always keep the weight -1,
    whatever brackets enclose them, and are never merged with neighbouring
    entries. Adjacent non-marker entries with equal weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('(a BREAK b)')
    [['a', 1.1], ['BREAK', -1], ['b', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')  # doctest: +NORMALIZE_WHITESPACE
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []
    square_brackets = []
    # id()s of the BREAK marker entries created by this call. Markers are
    # recognised by identity rather than value, so a user-written weight of
    # -1 is never mistaken for one. All markers stay alive in `res`.
    break_ids = set()

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def is_break(item):
        """Return True if *item* is a BREAK marker created by this call."""
        return id(item) in break_ids

    def multiply_range(start_position, multiplier):
        """Scale the weight of every entry from *start_position* onward.

        BREAK markers are skipped: their -1 is a sentinel, not a weight.
        """
        for item in res[start_position:]:
            if not is_break(item):
                item[1] *= multiplier

    def parse_weight(raw):
        """Return an explicit weight as a float, or None if absent or malformed."""
        if raw is None:
            return None
        try:
            return float(raw)
        except ValueError:
            # The regex accepts any run of digits and dots, e.g. "." or "1.2.3".
            return None

    def append_text(chunk):
        """Append plain text, turning each BREAK keyword into a marker entry."""
        for i, part in enumerate(re_break.split(chunk)):
            if i > 0:
                marker = ["BREAK", -1]
                break_ids.add(id(marker))
                res.append(marker)
            res.append([part, 1.0])

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = parse_weight(m.group(1))

        if token.startswith('\\'):
            # Escape sequence: keep whatever follows the backslash literally.
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), weight)
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text, unmatched closers, stray ':' and malformed weights.
            append_text(token)

    # Brackets still open at the end of the prompt extend to the end.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # Merge runs of identical weights in one pass; BREAK markers never merge.
    merged = []
    for item in res:
        if (
            merged
            and not is_break(item)
            and not is_break(merged[-1])
            and merged[-1][1] == item[1]
        ):
            merged[-1][0] += item[0]
        else:
            merged.append(item)
    return merged
def prompt_attention_to_invoke_prompt(attention):
    """
    Convert a list of [text, weight] pairs into an InvokeAI prompt string.

    Each weight is rounded to two decimal places. Entries whose weight is 1.0
    are emitted as plain text. Any other weight wraps the text in parentheses
    followed by a weight suffix.

    InvokeAI's ``+`` / ``-`` suffixes multiply the weight by 1.1 / 0.9 per
    sign. The compact form is therefore used only when it reproduces the
    weight exactly: 1.1, 1.21, 1.33 (``+`` to ``+++``) and 0.9, 0.81, 0.73
    (``-`` to ``---``). Every other weight is written out numerically.

    Entries are concatenated without separators, so any spacing must already
    be part of the text (as it is in ``parse_prompt_attention`` output).

    Example:
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.4]])
    'normal text (something important)1.4'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.21]])
    'normal text (something important)++'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something important', 1.2]])
    'normal text (something important)1.2'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something unimportant', 0.9]])
    'normal text (something unimportant)-'
    >>> prompt_attention_to_invoke_prompt([['normal text ', 1.0], ['something unimportant', 0.92]])
    'normal text (something unimportant)0.92'

    BREAK markers (text "BREAK" with a negative sentinel weight) are emitted
    as the bare word BREAK surrounded by spaces, rather than as a weighted
    group. NOTE(review): confirm how the target InvokeAI version treats
    BREAK; the output may need manual editing.

    Entries with empty text are dropped, so no empty "()" groups are produced.

    NOTE(review): literal parentheses inside the text are not escaped and
    will be read as syntax by InvokeAI.
    """
    # Longest run of +/- signs emitted before falling back to a number.
    max_signs = 3

    def weight_suffix(weight):
        """Return the +/- run that reproduces *weight* exactly, else the number."""
        for count in range(1, max_signs + 1):
            if weight == round(1.1 ** count, 2):
                return "+" * count
            if weight == round(0.9 ** count, 2):
                return "-" * count
        return str(weight)

    tokens = []
    for text, weight in attention:
        # The parser's BREAK sentinel is -1 (possibly rescaled by brackets in
        # older parser versions), never a real attention weight.
        if text == "BREAK" and weight < 0:
            tokens.append(" BREAK ")
            continue
        if not text:
            continue
        weight = round(weight, 2)
        if weight == 1.0:
            tokens.append(text)
        else:
            tokens.append(f"({text}){weight_suffix(weight)}")
    return "".join(tokens)
parser = argparse.ArgumentParser(description="Parse prompt attention")
parser.add_argument("-d", "--debug", action="store_true", help="enable debug mode")


def main():
    """Interactively convert WebUI prompts to InvokeAI syntax.

    Reads one prompt per line until EOF (Ctrl-D) or Ctrl-C. With --debug,
    the intermediate [text, weight] list is printed before the converted
    prompt.
    """
    args = parser.parse_args()
    while True:
        try:
            text = input("Enter prompt: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dumping a traceback; the
            # newline keeps the shell prompt off the "Enter prompt:" line.
            print()
            break
        print()
        attention = parse_prompt_attention(text)
        if args.debug:
            print(attention)
            print()
        print(prompt_attention_to_invoke_prompt(attention))
        print()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment