Created
July 26, 2023 01:26
-
-
Save jaretburkett/cc269915dc7f29c0f64e0d161b81bd78 to your computer and use it in GitHub Desktop.
Interrogate an image folder with BLIP-2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Interrogator using BLIP2\n | |
""" | |
# use the first GPU only (CUDA device 0)
# set pci bus order | |
import os | |
from clip_interrogator.clip_interrogator import _truncate_to_fit | |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
import argparse | |
from tqdm import tqdm | |
from PIL import Image | |
from clip_interrogator import Config, Interrogator, LabelTable, load_list | |
import json | |
class CustomHelpFormatter(
    argparse.ArgumentDefaultsHelpFormatter,
    argparse.RawDescriptionHelpFormatter,
):
    """Help formatter combining argument defaults with a raw (unwrapped) description."""
def adjust_path(path):
    """Return *path* as-is when absolute, otherwise joined onto the current working directory."""
    return path if os.path.isabs(path) else os.path.join(os.getcwd(), path)
# Command-line interface definition.
parser = argparse.ArgumentParser(
    # use the module docstring as description but keep its line breaks
    description=__doc__,
    formatter_class=CustomHelpFormatter,
)
add_arg = parser.add_argument
add_arg('input_folder', type=str, help='Path to folder containing images')
add_arg('--ext', type=str, default=".txt", help='Caption extension. Either .txt or .caption')
add_arg('--meta', type=str, default=None,
        help='Path to meta file to create for sd-scripts. If none, ext extension will be used')
# NOTE: the boolean flags below rely on BooleanOptionalAction alone. It stores
# True/False itself and adds the matching --no-<flag> option; passing ``type=``
# alongside it is deprecated since Python 3.12 and removed in 3.14.
add_arg('-r', '--recursive', default=False, action=argparse.BooleanOptionalAction,
        help='Recursively search for images in subfolders')
add_arg('--model', type=str, default='blip2-2.7b',
        choices=['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco'],
        help='Which model to use for captioning')
add_arg('-t', '--terms', type=str, default=None,
        help='Path to terms config file in json or txt format. If none, only captions will be generated')
add_arg('--v2', default=False, action=argparse.BooleanOptionalAction,
        help='Use CLIP for SD 2.1 instead of SD 1.5')
add_arg('--always_save', default=False, action=argparse.BooleanOptionalAction,
        help='Always save captions to file even doing meta')
add_arg('--regen', default=False, action=argparse.BooleanOptionalAction,
        help='Always generate new captions even if they exist')
# Parse the command line, then resolve every user-supplied path via adjust_path().
args = parser.parse_args()

# file extensions (including the dot) that count as images
img_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.png', '.webp', '.tiff', '.tif']

args.input_folder = adjust_path(args.input_folder)
input_folder = args.input_folder
caption_ext = args.ext
args.meta = None if args.meta is None else adjust_path(args.meta)
meta_path = args.meta
args.terms = None if args.terms is None else adjust_path(args.terms)
terms_path = args.terms
recursive = args.recursive

print("Interrogating images...")
print("")
print("Args:")
# echo each parsed argument together with its (resolved) value
for arg_name, arg_value in vars(args).items():
    print(f" --{arg_name}")
    print(f" - {arg_value}")
print("")
# Load the optional terms file given with --terms:
#   *.json -> parsed into ``terms_config``
#   *.txt  -> path remembered in ``terms_txt_path``
terms_config = None
terms_txt_path = None
if args.terms is not None:
    terms_suffix = os.path.splitext(args.terms)[1].lower()
    if terms_suffix == '.json':
        print(f"Reading terms config from {terms_path}")
        # read JSON as UTF-8 instead of the platform's locale encoding
        with open(terms_path, 'r', encoding='utf-8') as f:
            terms_config = json.load(f)
    elif terms_suffix == '.txt':
        print(f"Reading terms from {terms_path}")
        terms_txt_path = terms_path
    else:
        # an unsupported extension used to be ignored without any hint
        print(f"Warning: unsupported terms file extension '{terms_suffix}' "
              f"(expected .json or .txt); only captions will be generated")
# Add a "." in front of the caption extension if it is not already there.
# startswith() also copes with an empty --ext value, where indexing [0]
# raised IndexError.
if not caption_ext.startswith('.'):
    caption_ext = '.' + caption_ext
# Collect every image in the input folder, descending into subfolders when
# --recursive is set. A file counts as an image when its lower-cased
# extension is listed in img_ext.
if recursive:
    listing = (
        (folder, name)
        for folder, _subdirs, names in os.walk(input_folder)
        for name in names
    )
else:
    listing = ((input_folder, name) for name in os.listdir(input_folder))

images = [
    os.path.join(folder, name)
    for folder, name in listing
    if os.path.splitext(name)[1].lower() in img_ext
]

print(f"Found {len(images)} images in {input_folder}")
if not images:
    raise Exception(f"No images found in {input_folder}")
# CAPTION_MODELS = { | |
# 'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB | |
# 'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB | |
# 'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB | |
# 'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB | |
# 'git-large-coco': 'microsoft/git-large-coco', # 1.58GB | |
# } | |
print('Loading Models. Takes a bit...')

# --v2 switches to the CLIP model for SD 2.1; the default is the one for SD 1.5
clip_model_name = "ViT-H-14/laion2b_s32b_b79k" if args.v2 else "ViT-L-14/openai"

# cache lives in <script dir>/data/interrogate_cache
data_dir = os.path.join(os.path.dirname(__file__), "data")
cache_dir = os.path.join(data_dir, "interrogate_cache")
for directory in (data_dir, cache_dir):
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)

interrogator_config = Config(
    clip_model_name=clip_model_name,
    caption_model_name=args.model,
    cache_path=cache_dir,
    quiet=True,  # no status bar, we have our own
)
interrogator = Interrogator(interrogator_config)
def clean_caption(cap):
    """Normalize a comma-separated caption string.

    Each comma-separated piece is stripped of surrounding whitespace; empty
    pieces and repeated pieces are dropped (the first occurrence wins and the
    original order is kept). The remaining pieces are re-joined with ", ".
    """
    stripped = (piece.strip() for piece in cap.split(","))
    # dict keys preserve insertion order, which gives an ordered de-duplication
    unique_pieces = dict.fromkeys(piece for piece in stripped if piece != "")
    return ", ".join(unique_pieces)
# image path -> {"caption": <caption text>} for every successfully processed image
captionDict = {}

# LabelTables do not depend on the image, so each one is built once on first
# use instead of once per image. Keys: index into terms_config["terms"], or
# "txt" for the plain-text terms list.
_label_tables = {}


def _term_table(index, term_obj):
    """Return the cached LabelTable for entry *index* of terms_config["terms"]."""
    if index not in _label_tables:
        # use template and insert value for each value in values
        term_list = [term_obj['template'].replace("[value]", value) for value in term_obj['values']]
        _label_tables[index] = LabelTable(term_list, term_obj['name'], interrogator)
    return _label_tables[index]


def _txt_table():
    """Return the cached LabelTable built from the plain-text terms file."""
    if "txt" not in _label_tables:
        _label_tables["txt"] = LabelTable(load_list(terms_txt_path), 'terms', interrogator)
    return _label_tables["txt"]


def _caption_from_terms_config(image):
    """Build a caption from the JSON terms config.

    The result is the optional model caption (when "include_caption" is set)
    followed by the best matches of every entry in "terms". Keys read from
    each entry: "name", "template", "values", "max_tokens", and optionally
    "ignore_token" and "caption_replace" (a list of "find|replace" strings).
    """
    result_list = []
    if terms_config["include_caption"]:
        result_list.append(interrogator.generate_caption(image).strip())
    image_features = interrogator.image_to_features(image)
    for index, term_obj in enumerate(terms_config["terms"]):
        table = _term_table(index, term_obj)
        best_match = table.rank(image_features, top_count=term_obj['max_tokens'])
        # remove ignore token if present; a missing key means "nothing to ignore"
        ignore_token = term_obj.get('ignore_token')
        if ignore_token is not None:
            ignored = {
                ignore_token,
                ' ',
                '',
                term_obj['template'].replace("[value]", ignore_token),
            }
            best_match = [match for match in best_match if match not in ignored]
        for rep_string in term_obj.get("caption_replace") or ():
            to_find, to_replace = rep_string.split("|")
            best_match = [match.replace(to_find, to_replace) for match in best_match]
        # add to result list
        result_list.extend(best_match)
    return ", ".join(result_list)


def _caption_from_terms_txt(image):
    """Build a caption from the model caption plus the top 32 terms of the txt list."""
    caption = interrogator.generate_caption(image)
    features = interrogator.image_to_features(image)
    best_match = _txt_table().rank(features, top_count=32)
    return _truncate_to_fit(caption + ", " + ", ".join(best_match), interrogator.tokenize)


def _generate_caption(image_path):
    """Open *image_path* and generate a fresh caption for it."""
    # the context manager closes the file handle; convert() returns a new image
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if terms_config is not None:
        return _caption_from_terms_config(image)
    if terms_txt_path is not None:
        return _caption_from_terms_txt(image)
    return interrogator.generate_caption(image)


for image_path in tqdm(images):
    try:
        caption_path = os.path.splitext(image_path)[0] + caption_ext
        caption = None
        # reuse an existing caption file unless --regen was given
        if not args.regen and os.path.exists(caption_path):
            with open(caption_path, 'r') as f:
                caption = f.read()
        if caption is None:
            caption = _generate_caption(image_path)
        # clean up caption
        caption = clean_caption(caption)
        captionDict[image_path] = {
            "caption": caption
        }
        # create txt file with same name as image with caption and caption extension
        # only if we don't have a meta path
        if meta_path is None or args.always_save:
            with open(caption_path, 'w') as f:
                # remove unicode characters
                caption = caption.encode('ascii', 'ignore').decode('ascii')
                f.write(caption)
    except Exception as e:
        # best effort: report the failure and continue with the next image
        print(f"Error processing {image_path}: {e}")
# When a meta path was given, dump all collected captions into it as JSON.
if meta_path is not None:
    with open(meta_path, 'w') as f:
        f.write(json.dumps(captionDict, indent=4))
    print(f"Meta file written to {meta_path}")
print("Done")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment