Skip to content

Instantly share code, notes, and snippets.

@jaretburkett
Created July 26, 2023 01:26
Show Gist options
  • Save jaretburkett/cc269915dc7f29c0f64e0d161b81bd78 to your computer and use it in GitHub Desktop.
Interrogate image folder with blip 2
#!/usr/bin/env python3
"""
Interrogator using BLIP2
"""
import os

# Select the GPU *before* importing any CUDA-using library: clip_interrogator
# pulls in torch, and CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER are only read
# when CUDA initialises. The original set these AFTER importing
# clip_interrogator, so the device selection could silently be ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import json

from tqdm import tqdm
from PIL import Image
from clip_interrogator import Config, Interrogator, LabelTable, load_list
# NOTE(review): _truncate_to_fit is a private helper of clip_interrogator and
# may break on library upgrades — confirm against the pinned version.
from clip_interrogator.clip_interrogator import _truncate_to_fit
class CustomHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    """Help formatter combining two stock behaviours: show each option's
    default value, and keep the line breaks of the parser description."""
def adjust_path(path):
    """Return *path* unchanged when it is absolute, otherwise resolve it
    relative to the current working directory."""
    if os.path.isabs(path):
        return path
    return os.path.join(os.getcwd(), path)
# Command-line interface. The module docstring doubles as the --help
# description (CustomHelpFormatter keeps its line breaks and also shows
# each option's default).
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=CustomHelpFormatter,
)
add_arg = parser.add_argument
add_arg('input_folder', type=str, help='Path to folder containing images')
add_arg('--ext', type=str, default=".txt", help='Caption extension. Either .txt or .caption')
add_arg('--meta', type=str, default=None,
        help='Path to meta file to create for sd-scripts. If none, ext extension will be used')
# NOTE: BooleanOptionalAction flags (--flag / --no-flag) consume no value on
# the command line, so a `type=` converter is never applied. The original
# `type=bool` was dead code (and misleading — bool("False") is True); passing
# `type` to BooleanOptionalAction is also deprecated since Python 3.13, so it
# has been removed. CLI behaviour is unchanged.
add_arg('-r', '--recursive', default=False, action=argparse.BooleanOptionalAction,
        help='Recursively search for images in subfolders')
add_arg('--model', type=str, default='blip2-2.7b',
        choices=['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco'],
        help='Which model to use for captioning')
add_arg('-t', '--terms', type=str, default=None,
        help='Path to terms config file in json or txt format. If none, only captions will be generated')
add_arg('--v2', default=False, action=argparse.BooleanOptionalAction,
        help='Use CLIP for SD 2.1 instead of SD 1.5')
add_arg('--always_save', default=False, action=argparse.BooleanOptionalAction,
        help='Always save captions to file even doing meta')
add_arg('--regen', default=False, action=argparse.BooleanOptionalAction,
        help='Always generate new captions even if they exist')
# list of text files to use for captioning
args = parser.parse_args()
# Recognised image file extensions, lower-cased for matching. (The original
# list contained '.png' twice; the duplicate has been removed — membership
# behaviour is identical.)
img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']

# Normalise user-supplied paths to absolute ones and write them back onto
# `args` so the echo below shows the resolved values.
input_folder = args.input_folder = adjust_path(args.input_folder)
caption_ext = args.ext
meta_path = args.meta = adjust_path(args.meta) if args.meta is not None else None
terms_path = args.terms = adjust_path(args.terms) if args.terms is not None else None
recursive = args.recursive

print("Interrogating images...")
print("")
print("Args:")
# Echo every parsed argument so a run is reproducible from its log.
for arg in vars(args):
    print(f" --{arg}")
    print(f" - {getattr(args, arg)}")
print("")

# Terms can come either from a JSON config (structured term objects) or a
# plain text file (one term per line). Any other extension is silently
# ignored, matching the original behaviour.
terms_config = None
terms_txt_path = None
if args.terms is not None:
    if os.path.splitext(args.terms)[1].lower() == '.json':
        print(f"Reading terms config from {terms_path}")
        with open(terms_path, 'r') as f:
            terms_config = json.load(f)
    elif os.path.splitext(args.terms)[1].lower() == '.txt':
        print(f"Reading terms from {terms_path}")
        terms_txt_path = terms_path

# add a . in front of extension if not already there
if caption_ext[0] != '.':
    caption_ext = '.' + caption_ext

# Collect all image paths, optionally walking subfolders.
images = []
if recursive:
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if os.path.splitext(file)[1].lower() in img_ext:
                images.append(os.path.join(root, file))
else:
    for file in os.listdir(input_folder):
        if os.path.splitext(file)[1].lower() in img_ext:
            images.append(os.path.join(input_folder, file))

print(f"Found {len(images)} images in {input_folder}")
if len(images) == 0:
    raise Exception(f"No images found in {input_folder}")
# Caption model reference (name -> Hugging Face checkpoint, approx. size):
#   blip-base         Salesforce/blip-image-captioning-base   990MB
#   blip-large        Salesforce/blip-image-captioning-large  1.9GB
#   blip2-2.7b        Salesforce/blip2-opt-2.7b               15.5GB
#   blip2-flan-t5-xl  Salesforce/blip2-flan-t5-xl             15.77GB
#   git-large-coco    microsoft/git-large-coco                1.58GB
print('Loading Models. Takes a bit...')

# SD 2.1 pairs with OpenCLIP ViT-H; SD 1.5 pairs with OpenAI ViT-L.
clip_model_name = "ViT-H-14/laion2b_s32b_b79k" if args.v2 else "ViT-L-14/openai"

# Cache directory lives next to this script; creating the leaf directory
# with exist_ok=True also creates data_dir, and is idempotent.
data_dir = os.path.join(os.path.dirname(__file__), "data")
cache_dir = os.path.join(data_dir, "interrogate_cache")
os.makedirs(cache_dir, exist_ok=True)

interrogator_config = Config(
    clip_model_name=clip_model_name,
    caption_model_name=args.model,
    cache_path=cache_dir,
    quiet=True,  # suppress the library's own status bar; tqdm is used below
)
interrogator = Interrogator(interrogator_config)
def clean_caption(cap):
    """Normalise a comma-separated caption: trim whitespace around each part,
    drop empty parts, and remove duplicates while keeping first-seen order."""
    cleaned = []
    for part in (piece.strip() for piece in cap.split(",")):
        if part and part not in cleaned:
            cleaned.append(part)
    return ", ".join(cleaned)
# Map of image path -> {"caption": str}; dumped to the meta JSON at the end.
captionDict = {}
for image_path in tqdm(images):
    try:
        caption = None
        # Reuse an existing caption file unless --regen forces regeneration.
        if not args.regen:
            caption_path = os.path.splitext(image_path)[0] + caption_ext
            if os.path.exists(caption_path):
                with open(caption_path, 'r') as f:
                    caption = f.read()
                # clean up caption
                caption = clean_caption(caption)
        if caption is None:
            image = Image.open(image_path).convert('RGB')
            # caption = interrogator.interrogate(image)
            if terms_config is not None:
                # Structured terms config: rank each term group's candidate
                # labels against the image features and join the winners.
                result_list = []
                if terms_config["include_caption"]:
                    caption = interrogator.generate_caption(image)
                    result_list.append(caption.strip())
                image_features = interrogator.image_to_features(image)
                for term_obj in terms_config["terms"]:
                    # Example term object:
                    # {
                    #     "name": "gender simple",
                    #     "template": "[value]",
                    #     "max_tokens": 1,
                    #     "ignore_token": null,
                    #     "values": ["male", "female", "other"]
                    # },
                    # use template and insert value for each value in values
                    term_list = [term_obj['template'].replace("[value]", value) for value in term_obj['values']]
                    table = LabelTable(term_list, term_obj['name'], interrogator)
                    best_match = table.rank(image_features, top_count=term_obj['max_tokens'])
                    # remove ignore token if present (both the bare token and
                    # the token rendered through the template), plus blanks
                    if term_obj['ignore_token'] is not None:
                        best_match = [match for match in best_match if
                                      match != term_obj['ignore_token'] and match != ' ' and match != '' and match != term_obj[
                                          'template'].replace("[value]", term_obj['ignore_token'])]
                    # Optional "find|replace" rewrites applied to each match.
                    if "caption_replace" in term_obj:
                        for rep_string in term_obj["caption_replace"]:
                            to_find, to_replace = rep_string.split("|")
                            best_match = [match.replace(to_find, to_replace) for match in best_match]
                    # add to result list
                    result_list.extend([match for match in best_match])
                caption = ", ".join(result_list)
            elif terms_txt_path is not None:
                # Flat term list: generated caption plus the 32 best-matching
                # terms, truncated to fit the CLIP token limit.
                caption = interrogator.generate_caption(image)
                table = LabelTable(load_list(terms_txt_path), 'terms', interrogator)
                features = interrogator.image_to_features(image)
                best_match = table.rank(features, top_count=32)
                caption = _truncate_to_fit(caption + ", " + ", ".join(best_match), interrogator.tokenize)
            else:
                # No terms supplied: plain generated caption only.
                caption = interrogator.generate_caption(image)
            caption = clean_caption(caption)
        captionDict[image_path] = {
            "caption": caption
        }
        # create txt file with same name as image with caption and caption extension
        # only if we don't have a meta path
        if meta_path is None or args.always_save:
            caption_path = os.path.splitext(image_path)[0] + caption_ext
            with open(caption_path, 'w') as f:
                # remove unicode characters
                caption = caption.encode('ascii', 'ignore').decode('ascii')
                f.write(caption)
    except Exception as e:
        # Best-effort batch processing: log the failure and keep going.
        print(f"Error processing {image_path}: {e}")
# write meta as json file if we have a meta path
if meta_path is not None:
    with open(meta_path, 'w') as f:
        json.dump(captionDict, f, indent=4)
    print(f"Meta file written to {meta_path}")
print("Done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment