Generate every combination of prompt parts and encode all of the prompts in batches to avoid running out of memory. Alternatively, only keep the min/max channel values and min/max token norms and randomly generate new conditionings from noise. Intended for Stable Diffusion, but it can be used with anything built on CLIP by swapping out the model.get_learned_conditioning call.
import itertools
import math

import torch


def prompt_combinations(prompt_parts):
    '''
    Provide a list of lists of prompt parts, like:
    [ ["A ","An "], ["anteater","feather duster"] ]
    '''
    # Cartesian product of the part lists, then join each tuple into one prompt string.
    opt_prompt = list(itertools.product(*prompt_parts))
    opt_prompt = [''.join(p) for p in opt_prompt]
    return opt_prompt
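
# A quick illustration of the expansion (hypothetical parts, not from the example below):
# prompt_combinations([["A ", "An "], ["anteater", "feather duster"]])
# -> ["A anteater", "A feather duster", "An anteater", "An feather duster"]
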
def encode_all_prompts(opt_prompt):
    # Encode every prompt with the model's CLIP text encoder, 64 prompts per batch,
    # and return all of the conditionings concatenated into one tensor.
    with torch.no_grad():
        with torch.autocast("cuda", cache_enabled=True):
            with model.ema_scope():
                c_all = []
                for b in range(math.ceil(len(opt_prompt)/64)):
                    c_all.append(model.get_learned_conditioning(opt_prompt[b*64:(b+1)*64]))
                c_all = torch.cat(c_all)#.cpu()
                return c_all
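
# `model` above is assumed to be an already-loaded Stable Diffusion LatentDiffusion
# instance. A minimal loading sketch, assuming the CompVis stable-diffusion repo is
# installed (config and checkpoint paths are placeholders):
#
#   from omegaconf import OmegaConf
#   from ldm.util import instantiate_from_config
#
#   config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
#   model = instantiate_from_config(config.model)
#   model.load_state_dict(torch.load("sd-v1-4.ckpt", map_location="cpu")["state_dict"], strict=False)
#   model = model.cuda().eval()
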
def encode_all_prompts_stats(opt_prompt):
    # Encode the prompts in batches of 64 but only keep running min/max per-channel
    # values and min/max token norms, so the full set of conditionings never has to
    # stay in memory.
    with torch.no_grad():
        with torch.autocast("cuda", cache_enabled=True):
            with model.ema_scope():
                max_ch = None
                min_ch = None
                max_norm = None
                min_norm = None
                for b in range(math.ceil(len(opt_prompt)/64)):
                    x = model.get_learned_conditioning(opt_prompt[b*64:(b+1)*64])
                    # Fold the previous extremes into the batch so the running min/max carries over.
                    if max_ch is not None:
                        x = torch.cat([x, max_ch, min_ch], 0)
                    max_ch = x.max(0, keepdim=True).values
                    min_ch = x.min(0, keepdim=True).values
                    norm_token = x.norm(2, -1, keepdim=True)
                    if max_norm is not None:
                        norm_token = torch.cat([norm_token, max_norm, min_norm], 0)
                    max_norm = norm_token.max(0, keepdim=True).values
                    min_norm = norm_token.min(0, keepdim=True).values
                return max_ch, min_ch, max_norm, min_norm

def token_stats(x):
    max_ch = x.max(0, keepdim=True).values
    min_ch = x.min(0, keepdim=True).values
    norm_token = x.norm(2, -1, keepdim=True)
    max_norm = norm_token.max(0, keepdim=True).values
    min_norm = norm_token.min(0, keepdim=True).values
    return max_ch, min_ch, max_norm, min_norm

def match_token_stats(x, max_ch, min_ch, max_norm, min_norm, eps=1e-6):
    # Interpolate each channel between its observed min and max using x as the weights,
    # normalize each token, then rescale it to a random norm drawn between the observed
    # min and max token norms.
    ch_out = torch.lerp(min_ch, max_ch, x)
    ch_out_norm = ch_out.norm(2, -1, keepdim=True)
    ch_out = ch_out / ch_out_norm.add(eps)
    ch_out_norm = torch.rand([x.shape[0], x.shape[1], 1], device=x.device).to(x.dtype)
    ch_out_norm = torch.lerp(min_norm, max_norm, ch_out_norm)
    ch_out = ch_out * ch_out_norm
    return ch_out

def match_token_stats_simple(x):
    # x is the (max_ch, min_ch, max_norm, min_norm) tuple for a 77-token, 768-channel CLIP conditioning.
    return match_token_stats(torch.rand([1,77,768], device=x[0].device, dtype=x[0].dtype), *x)

#Example: match_token_stats(torch.rand([4,3,6]), *token_stats(torch.randn([4,3,6])))
#Example: match_token_stats_simple(token_stats(torch.randn([4,77,768])))

import clip

tokenizer = clip.simple_tokenizer.SimpleTokenizer()

def token_check(x):
    # Show how CLIP's BPE tokenizer splits a prompt: the raw token ids, the round-trip
    # decode, each token decoded on its own, and the total token count.
    word_tokens = tokenizer.encode(x)
    print(word_tokens)
    print(tokenizer.decode(word_tokens))
    if len(word_tokens) > 1:
        print(list(str(i) + " " + tokenizer.decode([word_tokens[i]]) for i in range(len(word_tokens))))
    print(len(word_tokens))

#Example: token_check("john carpenter, David Cronenberg, David Lynch, Clive Barker")

#Example:
opt_prompt = prompt_combinations([
    ["eldritch creature, practical effects, imdb"],
    [", iso "],
    ["100","200","400","800"],
    [", danny devito, body horror, metamorphosis carapace cocoon"],
    [", john carpenter",", Cronenberg",", David Lynch",", Clive Barker"],
])
c_all = encode_all_prompts(opt_prompt)
c_stats = token_stats(c_all)
c = match_token_stats_simple(c_stats)

# Or to avoid encoding all the prompts and keeping them all in memory, just keep the stats from the prompts
c_stats = encode_all_prompts_stats(opt_prompt)
c = match_token_stats_simple(c_stats)
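
# The resulting `c` can be used anywhere a normal prompt conditioning would go.
# A minimal sampling sketch, assuming the standard PLMS sampler from the CompVis
# stable-diffusion repo (step count, guidance scale, and latent shape are placeholders):
from ldm.models.diffusion.plms import PLMSSampler

sampler = PLMSSampler(model)
uc = model.get_learned_conditioning(c.shape[0] * [""])  # empty-prompt unconditional conditioning
samples, _ = sampler.sample(S=50, conditioning=c, batch_size=c.shape[0],
                            shape=[4, 64, 64], verbose=False,
                            unconditional_guidance_scale=7.5,
                            unconditional_conditioning=uc, eta=0.0)
x_samples = model.decode_first_stage(samples)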