clip-interrogator
# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic)
#
# Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!
#
# <br>
#
# For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.
#
# This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models rank terms.
#
# You can also run this on HuggingFace and Replicate<br>
# [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)
#
# <br>
#
# If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool AI stuff. 🙂
#
# And if you're looking for more AI art tools check out my [AI generative art tools list](https://pharmapsychotic.com/tools.html).
#
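# The cells below install the dependencies, download the preprocessed cache files, build a small Gradio demo, and finally batch-process a folder of images. As a rough sketch, the library can also be called directly (assuming the `clip-interrogator` package is already installed and `example.jpg` is a placeholder path):
#
#     from PIL import Image
#     from clip_interrogator import Config, Interrogator
#     ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
#     print(ci.interrogate(Image.open("example.jpg").convert("RGB")))
#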
import os
import subprocess
import urllib.request

def setup():
    # Install the notebook's dependencies: the Gradio UI, OpenCLIP,
    # the clip-interrogator package, and the BLIP captioning model.
    install_cmds = [
        ['pip', 'install', 'gradio'],
        ['pip', 'install', 'open_clip_torch'],
        ['pip', 'install', 'clip-interrogator'],
        ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],
    ]
    for cmd in install_cmds:
        print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))

setup()
clip_model_name = 'ViT-L-14/openai'  # @param ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"]

print("Downloading preprocessed cache files...")
CACHE_URLS = [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
] if clip_model_name == 'ViT-L-14/openai' else [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]

# Download the precomputed term-embedding caches into a local directory so the
# interrogator doesn't have to recompute them on first run.
try:
    if not os.path.exists("cache"):
        os.makedirs("cache")
except OSError as e:
    print("Error creating cache directory:", e)

for url in CACHE_URLS:
    file_name = url.split('/')[-1]
    file_path = os.path.join("cache", file_name)
    try:
        urllib.request.urlretrieve(url, file_path)
    except (urllib.error.HTTPError, urllib.error.URLError) as e:
        print("Error retrieving data from url:", e)
import gradio as gr
from clip_interrogator import Config, Interrogator

config = Config()
config.blip_num_beams = 64
config.blip_offload = False
config.clip_model_name = clip_model_name
ci = Interrogator(config)

def inference(image, mode, best_max_flavors=32):
    # ViT-H is the heavier CLIP model, so use smaller chunk sizes for it.
    ci.config.chunk_size = 2048 if ci.config.clip_model_name == "ViT-L-14/openai" else 1024
    ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == "ViT-L-14/openai" else 1024
    image = image.convert('RGB')
    if mode == 'best':
        return ci.interrogate(image, max_flavors=int(best_max_flavors))
    elif mode == 'classic':
        return ci.interrogate_classic(image)
    else:
        return ci.interrogate_fast(image)

inputs = [
    gr.inputs.Image(type='pil'),
    gr.Radio(['best', 'fast'], label='', value='best'),
    gr.Number(value=16, label='best mode max flavors'),
]
outputs = [
    gr.outputs.Textbox(label="Output"),
]

io = gr.Interface(
    inference,
    inputs,
    outputs,
    allow_flagging=False,
)
io.launch(debug=False, server_port=7861)
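# Note: launch() serves the UI on localhost; on a hosted runtime such as Colab,
# Gradio's standard share=True option can be passed to launch() to get a public
# link instead of relying on the local port.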
import csv
import os
from IPython.display import clear_output, display
from PIL import Image
from tqdm import tqdm

folder_path = "/content/my_images"  # @param {type:"string"}
prompt_mode = 'best'  # @param ["best","fast"]
output_mode = 'rename'  # @param ["desc.csv","rename"]
max_filename_len = 128  # @param {type:"integer"}
best_max_flavors = 16  # @param {type:"integer"}

def sanitize_for_filename(prompt: str, max_len: int) -> str:
    name = "".join(c for c in prompt if (c.isalnum() or c in ",._-! "))
    name = name.strip()[:(max_len - 4)]  # extra space for extension
    return name
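# Illustrative example (not executed):
#   sanitize_for_filename("a portrait of a cat, oil on canvas?!", 64)
#   -> "a portrait of a cat, oil on canvas!"
# Disallowed characters such as "?" are dropped, and the result is truncated to
# leave room for the file extension.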
ci.config.quiet = True

files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []
prompts = []
for idx, file in enumerate(tqdm(files, desc='Generating prompts')):
    # Periodically clear the notebook output so the printed thumbnails don't pile up.
    if idx > 0 and idx % 100 == 0:
        clear_output(wait=True)

    image = Image.open(os.path.join(folder_path, file)).convert('RGB')
    prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)
    prompts.append(prompt)

    print(prompt)
    thumb = image.copy()
    thumb.thumbnail([256, 256])
    display(thumb)

    if output_mode == 'rename':
        name = sanitize_for_filename(prompt, max_filename_len)
        ext = os.path.splitext(file)[1]
        filename = name + ext
        # Append a numeric suffix if another image already produced the same prompt.
        suffix = 1
        while os.path.exists(os.path.join(folder_path, filename)):
            print(f'File {filename} already exists, trying {name}_{suffix}{ext}...')
            filename = f"{name}_{suffix}{ext}"
            suffix += 1
        os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))

if len(prompts):
    if output_mode == 'desc.csv':
        csv_path = os.path.join(folder_path, 'desc.csv')
        with open(csv_path, 'w', encoding='utf-8', newline='') as f:
            w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
            w.writerow(['image', 'prompt'])
            for file, prompt in zip(files, prompts):
                w.writerow([file, prompt])

        print(f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!")
    else:
        print(f"\n\n\n\nGenerated {len(prompts)} prompts and renamed your files, enjoy!")
else:
    print(f"Sorry, I couldn't find any images in {folder_path}")