Created
February 17, 2025 09:09
-
-
Save Norod/68d0121ddec47c7aaef86e1e01ee170e to your computer and use it in GitHub Desktop.
Iterates over an input folder of images; for each processed image, the script generates a corresponding .txt file (with the same base name as the image) containing the detailed caption generated by Florence-2.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os

# Usage:
#   pip install torch transformers pillow
#   python Florence-2-batch_caption.py ./input_images ./captions_output_folder

# Enable PyTorch's CPU fallback for operators not implemented on Apple MPS.
# PyTorch reads this variable when its native library is loaded (i.e. at
# `import torch`), so it must be set BEFORE torch (or transformers, which
# imports torch) is imported; setting it afterwards has no effect.
# setdefault keeps any value the user has already exported explicitly.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # noqa: E402  (must follow the environment setup above)
from PIL import Image  # noqa: E402
from transformers import AutoModelForCausalLM, AutoProcessor  # noqa: E402
def get_device_type():
    """Pick the torch device string to run the model on.

    CUDA is preferred; otherwise Apple MPS is used when it is both available
    and built into this torch install; otherwise fall back to the CPU.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = torch.backends.mps
    has_mps = mps_backend.is_available() and mps_backend.is_built()
    return "mps" if has_mps else "cpu"
# Florence-2 checkpoint used for captioning. Loading happens at module import
# time, so every import of this file pays the model-load cost.
model_id = 'microsoft/Florence-2-base-ft'

# trust_remote_code=True lets transformers execute the model/processor code
# hosted in the checkpoint repository. The model is moved to the best
# available device right after loading; nn.Module.to() returns the module.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
device = get_device_type()
model = model.to(device)
def run_example(task_prompt, image, text_input=None, *, max_new_tokens=1024, num_beams=3):
    """Run the Florence-2 model on *image* for a single task prompt.

    Args:
        task_prompt: Florence-2 task token, e.g. '<MORE_DETAILED_CAPTION>'.
            It is also passed as ``task`` to the processor's post-processing.
        image: Image to analyse; process_folder passes an RGB PIL image. Its
            ``width``/``height`` are forwarded to post-processing.
        text_input: Optional text appended directly after ``task_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens
            (keyword-only; default 1024, the previous hard-coded value).
        num_beams: Beam-search width (keyword-only; default 3). Decoding is
            deterministic because sampling is disabled.

    Returns:
        The value of ``processor.post_process_generation`` for the task.
        Callers in this file read it as a dict keyed by the task prompt.
    """
    # Build the prompt; the optional text input is concatenated verbatim.
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Tokenize/preprocess and move every tensor to the model's device.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=False,
        num_beams=num_beams,
    )

    # Special tokens are kept on purpose: post_process_generation parses the
    # raw decoded string for the given task.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
def process_folder(
    input_folder,
    output_folder,
    *,
    valid_extensions=('.png', '.jpg', '.jpeg', '.bmp', '.gif'),
):
    """Caption every image in *input_folder* and save one .txt per image.

    For each file whose name ends with one of ``valid_extensions``
    (case-insensitive), run the '<MORE_DETAILED_CAPTION>' task and write the
    caption to ``<output_folder>/<base name>.txt`` (UTF-8). Files that cannot
    be opened, or that yield an empty caption, are reported and skipped.

    Args:
        input_folder: Directory to scan (non-recursive).
        output_folder: Directory for the caption files; created if missing.
        valid_extensions: Keyword-only iterable of file extensions to treat
            as images.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder, exist_ok=True)

    # str.endswith needs a tuple; normalise to lowercase to match
    # filename.lower() below.
    extensions = tuple(ext.lower() for ext in valid_extensions)
    # Output path -> image filename that produced it during this run. Used to
    # detect base-name collisions such as "a.png" and "a.jpg".
    written = {}

    # os.listdir order is arbitrary; sort for a reproducible processing order.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith(extensions):
            continue
        image_path = os.path.join(input_folder, filename)
        print(f"Processing {image_path} ...")

        # Open the image and ensure it is in RGB format. The context manager
        # closes the source file; Pillow can otherwise keep it open after
        # convert() (e.g. multi-frame GIFs) until garbage collection.
        try:
            with Image.open(image_path) as source:
                image = source.convert("RGB")
        except Exception as e:
            print(f"Could not open {image_path}: {e}")
            continue

        result = run_example('<MORE_DETAILED_CAPTION>', image)
        caption = result.get('<MORE_DETAILED_CAPTION>', '')
        if not caption:
            print(f"No caption returned for {filename}.")
            continue

        # Save the caption to a text file with the same base name.
        base_name, _ = os.path.splitext(filename)
        output_path = os.path.join(output_folder, base_name + ".txt")
        if output_path in written:
            print(
                f"Warning: {filename} and {written[output_path]} share the base "
                f"name '{base_name}'; overwriting {output_path}"
            )
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(caption)
        written[output_path] = filename
        print(f"Saved caption to {output_path}")
def main():
    """Command-line entry point: caption all images in a folder.

    Positional arguments are the input image folder and the output folder for
    the caption .txt files. Exits with a usage error (status 2) if the input
    folder is not an existing directory.
    """
    parser = argparse.ArgumentParser(
        description="Generate '<MORE_DETAILED_CAPTION>' captions for all images in a folder using Florence‑2."
    )
    parser.add_argument("input_folder", type=str, help="Path to the folder containing input images.")
    parser.add_argument("output_folder", type=str, help="Path to the folder where captions will be saved.")
    args = parser.parse_args()

    # Report a bad input path as a clean usage error instead of letting
    # os.listdir() inside process_folder() raise a traceback.
    if not os.path.isdir(args.input_folder):
        parser.error(f"input_folder is not a directory: {args.input_folder}")

    process_folder(args.input_folder, args.output_folder)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment