Skip to content

Instantly share code, notes, and snippets.

@halr9000
Last active July 1, 2024 14:23
Show Gist options
  • Save halr9000/f07a866d16ccad6d23198b9118ccec16 to your computer and use it in GitHub Desktop.
Save halr9000/f07a866d16ccad6d23198b9118ccec16 to your computer and use it in GitHub Desktop.
Python script that captions images with Microsoft's Florence-2 model (microsoft/Florence-2), running locally via Pinokio and Gradio. Paper page: https://huggingface.co/papers/2311.06242. Model card: https://huggingface.co/microsoft/Florence-2-large. Gradio app: https://pinokio.computer/item?uri=https://github.com/pinokiofactory/florence2.
import argparse
import ast
from gradio_client import Client, handle_file
import json
import logging
# Root logger configuration: timestamped records, ERROR level and above only.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.ERROR,
)

# Florence-2 tasks offered on the command line, each mapped to the special
# prompt token associated with it. In this script only the keys are used
# (as the --task choices, and main() forwards the chosen key unchanged);
# the token values are kept here for reference.
task_prompts = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "Region Proposal": "<REGION_PROPOSAL>",
    "Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
    "Region to Segmentation": "<REGION_TO_SEGMENTATION>",
    "Open Vocabulary Detection": "<OPEN_VOCABULARY_DETECTION>",
    "Region to Category": "<REGION_TO_CATEGORY>",
    "Region to Description": "<REGION_TO_DESCRIPTION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
}
def _strip_newlines(value):
    """Return *value* with newlines removed if it is a string, else unchanged.

    Both real newline characters and literal backslash-n sequences are
    deleted outright (not replaced with spaces). Non-string values (e.g.
    nested dicts or lists) are passed through unchanged rather than raising
    AttributeError on ``.replace``.
    """
    if isinstance(value, str):
        return value.replace('\n', '').replace('\\n', '')
    return value


def _clean_result(parsed):
    """Serialize a parsed prediction to compact JSON with newlines stripped.

    *parsed* is first round-tripped through JSON so only JSON types remain
    (tuples become lists, etc.). A dict is then cleaned value by value, and
    any other value is cleaned directly. Only top-level strings are
    touched.

    Raises:
        TypeError / ValueError: if *parsed* holds something JSON cannot encode
        (for example sets, bytes, or tuple keys, which ast.literal_eval can
        produce).
    """
    normalized = json.loads(json.dumps(parsed, indent=None))
    if isinstance(normalized, dict):
        normalized = {key: _strip_newlines(value) for key, value in normalized.items()}
    else:
        normalized = _strip_newlines(normalized)
    return json.dumps(normalized, indent=None)


def main(image_url: str, task: str, model_id: str, client_url: str, text_input=None):
    """Run one Florence-2 task on an image via a Gradio server and print the result.

    The image is sent to the Gradio app at *client_url*. The textual result,
    expected to be a Python-literal dict, is parsed and its top-level string
    values are stripped of newlines. The result is printed to stdout as
    compact JSON. Failures are logged and the function returns None without
    printing.

    Parameters:
    - image_url: URL of the image to process (passed to gradio_client's handle_file).
    - task: Task name, i.e. a key of ``task_prompts`` such as "Caption". It is
      forwarded unchanged, not translated to its ``<TOKEN>`` value.
    - model_id: Model ID to use for prediction.
    - client_url: URL of the Gradio server.
    - text_input: Optional text for tasks that take one (e.g.
      "Caption to Phrase Grounding"). The default, None, sends no text.
    """
    logger = logging.getLogger(__name__)

    # Forward the task name as-is. NOTE(review): presumably the Gradio app maps
    # names like "Caption" to their <TOKEN> itself -- confirm against the app.
    task_prompt = task
    logger.debug("Using task prompt key: %s", task_prompt)

    # Connecting and predicting share one error path. This is the script's
    # top-level boundary, so any failure is logged and the run stops.
    try:
        client = Client(client_url, verbose=False)
        logger.debug("Client instantiated with URL: %s", client_url)
        result = client.predict(
            image=handle_file(image_url),
            task_prompt=task_prompt,
            text_input=text_input,
            model_id=model_id,
        )
        logger.debug("Prediction result: %s", result)
    except Exception as e:
        logger.exception("Exception occurred during prediction: %s", e)
        return

    # The second element of the returned pair is ignored.
    result_string, _ = result
    if result_string is None or not result_string.strip():
        logger.warning("Result string is empty or None.")
        return

    logger.debug("Evaluating result string: %s", result_string)
    try:
        # literal_eval only evaluates Python literals (no arbitrary code), which
        # is why it is used on server output instead of eval().
        result_dict = ast.literal_eval(result_string)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        logger.error(f"Exception occurred while evaluating result string: {e}", exc_info=True)
        return

    try:
        modified_json_output_str = _clean_result(result_dict)
    except (TypeError, ValueError) as e:
        logger.error("Could not serialize result to JSON: %s", e, exc_info=True)
        return

    print(modified_json_output_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Image Captioning Script')
    parser.add_argument('--image_url', type=str, required=True, help='URL of the image to caption')
    parser.add_argument('--task', type=str, choices=list(task_prompts.keys()), required=True, help='Image processing task supported by the model')
    parser.add_argument('--model_id', type=str, default="microsoft/Florence-2-large", help='Model ID to use for prediction')
    # NOTE(review): this default is a host-specific address; pass --client_url
    # to point at your own server.
    parser.add_argument('--client_url', type=str, default="http://100.107.248.20:42421/", help='Client server URL')
    parser.add_argument('--text_input', type=str, default=None, help='Optional text for tasks that take one (e.g. "Caption to Phrase Grounding")')
    parser.add_argument('--verbose', action='store_true', help='Enable debug logging')
    args = parser.parse_args()

    # Copy so the Namespace itself is left intact; 'verbose' is a CLI-only
    # option and must not be passed to main().
    cli_args = dict(vars(args))
    if cli_args.pop('verbose'):
        # The module-level basicConfig sets the root level to ERROR; lower it
        # so the info/debug messages here and in main() are emitted.
        logging.getLogger().setLevel(logging.DEBUG)

    logger = logging.getLogger(__name__)
    logger.info("Running with arguments: %s", args)
    main(**cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment