Created
March 28, 2024 12:15
-
-
Save iamironz/9cd1cbef3dbc344562d7a93a69baef8a to your computer and use it in GitHub Desktop.
imagetotext-kosmos.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import os | |
import re | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
def log_message(message, log_path="log.txt"):
    """Print *message* and append it as one line to a log file.

    Args:
        message: Anything printable; it is formatted with an f-string, so
            non-string values are accepted.
        log_path: File to append to. Defaults to ``log.txt`` relative to the
            current working directory (the original, hard-coded behavior).
    """
    print(message)
    # Explicit UTF-8: captions and file paths may contain non-ASCII text that
    # the platform's default encoding cannot always represent.
    with open(log_path, "a", encoding="utf-8") as log_file:
        log_file.write(f"{message}\n")
# Grounding markup found in Kosmos-2 output: `<phrase>` spans are unwrapped
# (their text is kept); every other tag below is dropped entirely.
_PHRASE_RE = re.compile(r"<phrase>(.*?)</phrase>")
_MARKUP_RE = re.compile(
    r"<grounding>|<object>|</object>|<patch_index_\d+>|</delimiter_of_multi_objects/>"
)


def _strip_markup(text):
    """Return *text* with ``<phrase>`` tags unwrapped and other grounding tags removed."""
    return _MARKUP_RE.sub("", _PHRASE_RE.sub(r"\1", text))


def clean_tags(text, prompt=None):
    """Turn raw Kosmos-2 generation text into a plain caption.

    Unwraps ``<phrase>`` spans, removes ``<grounding>``, ``<object>``,
    ``</object>``, ``<patch_index_N>`` and multi-object delimiter tags,
    removes the echoed prompt, and strips surrounding whitespace.

    Args:
        text: Decoded generation that still contains markup and the prompt
            echo (the script passes the result of
            ``post_process_generation(..., cleanup_and_extract=False)``).
        prompt: Prompt text to remove. When omitted, the module-level
            ``prompt`` global (set by the script entry point) is used, or
            ``""`` (remove nothing) if that global does not exist, so the
            function no longer raises ``NameError`` when imported.

    Returns:
        The cleaned caption string.
    """
    if prompt is None:
        prompt = globals().get("prompt", "")
    cleaned = _strip_markup(text)
    # Compare tag-free forms: the decoded echo may not reproduce the raw
    # prompt verbatim (e.g. "<grounding> Describe ..." with a space after the
    # tag vs. "<grounding>Describe ..."), so a raw replace can silently miss.
    prompt_text = _strip_markup(prompt).strip()
    if prompt_text:
        cleaned = cleaned.replace(prompt_text, "")
    return cleaned.strip()
def process_images(image_paths, processor, model, max_new_tokens=128):
    """Generate a cleaned Kosmos-2 caption for each image path.

    Generation is conditioned on the module-level ``prompt`` global, which the
    script entry point defines; it must exist before this is called.

    Args:
        image_paths: Paths of the images to caption, processed in order.
        processor: Kosmos-2 processor (the script loads it with
            ``AutoProcessor.from_pretrained``).
        model: Kosmos-2 model (the script loads it with
            ``AutoModelForVision2Seq.from_pretrained``).
        max_new_tokens: Upper bound on newly generated tokens per image.

    Returns:
        One caption string per input path, in the same order.
    """
    log_message(f"Processing {len(image_paths)} images")
    captions = []
    for image_path in image_paths:
        # Close the file handle Image.open holds; preprocess inside the block so
        # the pixel data is read before the file is closed.
        with Image.open(image_path) as image:
            inputs = processor(text=prompt, images=image, return_tensors="pt")
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            use_cache=True,
            max_new_tokens=max_new_tokens,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Keep the raw markup here; clean_tags strips it and the prompt echo.
        processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
        captions.append(clean_tags(processed_text))
    return captions
def main(input_path, processor, model):
    """Caption one image file, or every image found recursively under a directory.

    Each result is logged as ``"<path>: <caption>"``. Files are recognised as
    images by extension (case-insensitive). Any other input is logged as
    invalid.

    Args:
        input_path: Path to an image file or to a directory to walk.
        processor: Passed through to ``process_images``.
        model: Passed through to ``process_images``.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    log_message(f"Input path: {input_path}")
    if os.path.isdir(input_path):
        log_message("Input path is a directory")
        # Sorted so output order does not depend on the filesystem's walk order.
        image_paths = sorted(
            os.path.join(root, name)
            for root, _dirs, files in os.walk(input_path)
            for name in files
            if name.lower().endswith(image_extensions)
        )
        captions = process_images(image_paths, processor, model)
        for path, caption in zip(image_paths, captions):
            log_message(f"{path}: {caption.strip()}")
    elif os.path.isfile(input_path) and input_path.lower().endswith(image_extensions):
        log_message("Input path is a file")
        caption = process_images([input_path], processor, model)[0]
        # Same line format as the directory branch (no stray blank line).
        log_message(f"{input_path}: {caption.strip()}")
    else:
        log_message("Invalid input path")
if __name__ == "__main__":
    # Kept as a module-level global on purpose: clean_tags() and
    # process_images() read `prompt` from module scope.
    prompt = "<grounding>Describe this image (and all object especially via comma) in detail:"
    # Validate arguments before the (slow, possibly downloading) model load.
    if len(sys.argv) != 2:
        log_message("Usage: script.py <path_to_file_or_folder>")
        sys.exit(1)
    model_id = "microsoft/kosmos-2-patch14-224"
    log_message("Loading model and processor...")
    model = AutoModelForVision2Seq.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    main(sys.argv[1], processor, model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment