Skip to content

Instantly share code, notes, and snippets.

@axsddlr
Created January 29, 2023 20:30
Show Gist options
  • Save axsddlr/b36847b58abfb6ab35f00a8ecac321b8 to your computer and use it in GitHub Desktop.
Save axsddlr/b36847b58abfb6ab35f00a8ecac321b8 to your computer and use it in GitHub Desktop.
clip-interrogator
# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic)
#
# Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!
#
# <br>
#
# For Stable Diffusion 1.X, choose the **ViT-L** CLIP model; for Stable Diffusion 2.0+, choose the **ViT-H** CLIP model.
#
# This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between the generated text prompt and the source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models rank terms.
#
# You can also run this on HuggingFace and Replicate<br>
# [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)
#
# <br>
#
# If this notebook is helpful to you, please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool AI stuff. 🙂
#
# And if you're looking for more AI art tools, check out my [AI generative art tools list](https://pharmapsychotic.com/tools.html).
#
import os
import subprocess
import urllib.request
def setup():
    """Install the third-party packages this script needs, using pip.

    Each package is installed with the pip that belongs to the running
    interpreter (``sys.executable -m pip``). That way the packages land in the
    same environment that later imports them. A bare ``pip`` on PATH may
    belong to a different Python.

    pip's stdout is echoed. A non-zero exit status is reported instead of
    being silently ignored. Failures do not abort the script, so this step
    stays best-effort; a package that is genuinely missing will surface at
    the import that needs it.
    """
    import sys

    packages = [
        'gradio',
        'open_clip_torch',
        'clip-interrogator',
        'git+https://github.com/pharmapsychotic/BLIP.git',
    ]
    for package in packages:
        cmd = [sys.executable, '-m', 'pip', 'install', package]
        result = subprocess.run(cmd, stdout=subprocess.PIPE)
        # errors='replace' keeps unusual bytes in pip's output from crashing setup.
        print(result.stdout.decode('utf-8', errors='replace'))
        if result.returncode != 0:
            print(f"Warning: {' '.join(cmd)} exited with status {result.returncode}")


setup()
import shutil
import urllib.error

clip_model_name = 'ViT-L-14/openai' # @param ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"]

print("Download preprocessed cache files...")

# Local directory that the preprocessed pickles are written into.
CACHE_DIR = "cache"
_CACHE_BASE_URL = "https://huggingface.co/pharma/ci-preprocess/resolve/main"
_CACHE_LABELS = ("artists", "flavors", "mediums", "movements", "trendings")
# Hosted file names follow "<model name with '/' replaced by '_'>_<label>.pkl",
# e.g. "ViT-L-14_openai_artists.pkl" or "ViT-H-14_laion2b_s32b_b79k_flavors.pkl".
CACHE_URLS = [
    f"{_CACHE_BASE_URL}/{clip_model_name.replace('/', '_')}_{label}.pkl"
    for label in _CACHE_LABELS
]


def _download_file(url, dest_path, timeout=60):
    """Download ``url`` to ``dest_path`` without leaving a truncated file behind.

    The response is streamed into ``dest_path + '.part'``. That file is only
    moved to ``dest_path`` (via ``os.replace``) after the copy completes. An
    interrupted download may leave a stale ``.part`` file, which the next
    attempt overwrites.

    Raises:
        OSError: on network, HTTP (``urllib.error.URLError`` is a subclass)
            or filesystem errors.
    """
    part_path = dest_path + ".part"
    with urllib.request.urlopen(url, timeout=timeout) as response, open(part_path, "wb") as out:
        shutil.copyfileobj(response, out)
    os.replace(part_path, dest_path)


try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except OSError as e:
    print("Error creating cache directory: ", e)

for url in CACHE_URLS:
    file_name = url.rsplit('/', 1)[-1]
    file_path = os.path.join(CACHE_DIR, file_name)
    # Files fetched completely on an earlier run are reused, not re-downloaded.
    if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
        continue
    try:
        _download_file(url, file_path)
    except OSError as e:  # covers urllib.error.HTTPError/URLError and timeouts
        print("Error retrieving data from url: ", url, e)
import gradio as gr
from clip_interrogator import Config, Interrogator
# Settings applied on top of a default Config before the Interrogator is built.
_config_overrides = {
    "clip_model_name": clip_model_name,
    "blip_offload": False,
    "blip_num_beams": 64,
}
config = Config()
for _field, _value in _config_overrides.items():
    setattr(config, _field, _value)
ci = Interrogator(config)
def inference(image, mode, best_max_flavors=32):
    """Return a text prompt describing ``image`` using the global ``ci`` interrogator.

    Before interrogating, ``ci.config.chunk_size`` and
    ``ci.config.flavor_intermediate_count`` are set to 2048 for the
    "ViT-L-14/openai" model and to 1024 for any other model.

    Args:
        image: object exposing ``.convert('RGB')``; assumed to be a PIL image
            (the UI and batch code both pass one).
        mode: 'best' -> ``ci.interrogate``, 'classic' -> ``ci.interrogate_classic``;
            any other value -> ``ci.interrogate_fast``.
        best_max_flavors: forwarded (cast with ``int``) as ``max_flavors``,
            used only in 'best' mode.
    """
    size = 2048 if ci.config.clip_model_name == "ViT-L-14/openai" else 1024
    ci.config.chunk_size = size
    ci.config.flavor_intermediate_count = size

    rgb_image = image.convert('RGB')
    if mode == 'best':
        return ci.interrogate(rgb_image, max_flavors=int(best_max_flavors))
    if mode == 'classic':
        return ci.interrogate_classic(rgb_image)
    return ci.interrogate_fast(rgb_image)
import inspect

# Gradio UI wrapping inference(image, mode, best_max_flavors).
# The top-level gr.Image / gr.Textbox components are used because the
# gr.inputs / gr.outputs namespaces were removed in Gradio 4, and setup()
# installs whatever Gradio version is current.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(['best', 'fast'], label='', value='best'),
    gr.Number(value=16, label='best mode max flavors'),
]
outputs = [
    gr.Textbox(label="Output"),
]

# Disable flagging. Gradio <=4 spells this allow_flagging="never"; Gradio 5
# renamed it to flagging_mode. Pick whichever keyword this install accepts.
_interface_params = inspect.signature(gr.Interface.__init__).parameters
_flagging_kwargs = (
    {"flagging_mode": "never"}
    if "flagging_mode" in _interface_params
    else {"allow_flagging": "never"}
)
io = gr.Interface(
    inference,
    inputs,
    outputs,
    **_flagging_kwargs,
)
# NOTE: per Gradio's docs, launch() blocks the main thread when run as a plain
# script (but not inside a notebook), so the batch section below only runs
# once the server has been stopped.
io.launch(debug=False, server_port=7861)
import csv
import os
from IPython.display import clear_output, display
from PIL import Image
from tqdm import tqdm
# Batch-mode settings (the "# @param" comments are Colab form markers).
# Folder containing the images to caption.
folder_path = "/content/my_images" # @param {type:"string"}
# Interrogation mode passed to inference(): "best" or "fast".
prompt_mode = 'best' # @param ["best","fast"]
# "desc.csv": write (image, prompt) rows to <folder_path>/desc.csv;
# "rename": rename each image file after its generated prompt.
output_mode = 'rename' # @param ["desc.csv","rename"]
# Maximum file-name length handed to sanitize_for_filename() in "rename" mode.
max_filename_len = 128 # @param {type:"integer"}
# Forwarded to inference() as best_max_flavors (only used in "best" mode).
best_max_flavors = 16 # @param {type:"integer"}
def sanitize_for_filename(prompt: str, max_len: int) -> str:
    """Turn a generated prompt into a string usable as a file-name stem.

    Keeps only alphanumeric characters plus ``,._-!`` and spaces. Surrounding
    whitespace is trimmed, and the result is truncated so that the stem plus
    a 4-character extension (e.g. ``.jpg``) fits in ``max_len`` characters.

    Args:
        prompt: Text to sanitize.
        max_len: Maximum length of the final file name, extension included.

    Returns:
        The sanitized stem. It may be empty if nothing survives filtering or
        ``max_len`` leaves no room beyond the extension.
    """
    allowed_punct = ",._-! "
    name = "".join(c for c in prompt if c.isalnum() or c in allowed_punct)
    # Reserve 4 characters for the extension. Clamp at 0 so a small max_len
    # can't become a negative slice bound, which would drop characters from
    # the end instead of limiting the length.
    limit = max(max_len - 4, 0)
    # rstrip after truncating so a cut between words doesn't leave a trailing space.
    return name.strip()[:limit].rstrip()
ci.config.quiet = True

# Collect the images to caption. Extension matching is case-insensitive, so
# files such as "photo.JPG" or "photo.jpeg" are not silently skipped. A
# missing folder, or a path that is not a directory, yields no files.
files = [
    f for f in os.listdir(folder_path)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
] if os.path.isdir(folder_path) else []

prompts = []
for idx, file in enumerate(tqdm(files, desc='Generating prompts')):
    # Clear the cell output every 100 images so prompts/thumbnails don't pile up.
    if idx > 0 and idx % 100 == 0:
        clear_output(wait=True)
    src_path = os.path.join(folder_path, file)
    # Close the source file before it may be renamed below. convert() returns
    # a new, fully loaded image, so it stays usable after the file is closed.
    with Image.open(src_path) as img:
        image = img.convert('RGB')
    prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)
    prompts.append(prompt)
    print(prompt)
    thumb = image.copy()
    thumb.thumbnail((256, 256))
    display(thumb)

    if output_mode == 'rename':
        name = sanitize_for_filename(prompt, max_filename_len)
        ext = os.path.splitext(file)[1]
        filename = name + ext
        suffix = 1
        # Find a free name by appending _1, _2, ... The file's own current name
        # is not a collision (e.g. when re-running over already-renamed files).
        while filename != file and os.path.exists(os.path.join(folder_path, filename)):
            candidate = f"{name}_{suffix}{ext}"
            print(f'File {filename} already exists, trying {candidate}...')
            filename = candidate
            suffix += 1
        if filename != file:
            os.rename(src_path, os.path.join(folder_path, filename))
# Report the results. In "desc.csv" mode, also write one (image, prompt) row
# per processed file.
if not prompts:
    print(f"Sorry, I couldn't find any images in {folder_path}")
elif output_mode == 'desc.csv':
    csv_path = os.path.join(folder_path, 'desc.csv')
    with open(csv_path, 'w', encoding='utf-8', newline='') as csv_file:
        writer = csv.writer(csv_file, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['image', 'prompt'])
        writer.writerows(zip(files, prompts))
    print(f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!")
else:
    print(f"\n\n\n\nGenerated {len(prompts)} prompts and renamed your files, enjoy!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment