|
import argparse |
|
import ast |
|
from gradio_client import Client, handle_file |
|
import json |
|
import logging |
|
|
|
# Configure the root logger once at import time. Only ERROR and above are
# emitted, so the debug/info/warning calls in this script stay silent unless
# this level is lowered.
logging.basicConfig(
    level=logging.ERROR,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
|
|
|
# Supported tasks. The keys are the values accepted by the --task command-line
# option; each maps to its angle-bracketed prompt tag. Defined before the
# argument parser is built because the parser reads the keys as its choices.
task_prompts = {
    "Caption": "<CAPTION>",
    "Detailed Caption": "<DETAILED_CAPTION>",
    "More Detailed Caption": "<MORE_DETAILED_CAPTION>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "Region Proposal": "<REGION_PROPOSAL>",
    "Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
    "Region to Segmentation": "<REGION_TO_SEGMENTATION>",
    "Open Vocabulary Detection": "<OPEN_VOCABULARY_DETECTION>",
    "Region to Category": "<REGION_TO_CATEGORY>",
    "Region to Description": "<REGION_TO_DESCRIPTION>",
    "OCR": "<OCR>",
    "OCR with Region": "<OCR_WITH_REGION>",
}
|
|
|
def main(image_url: str, task: str, model_id: str, client_url: str):
    """
    Run one image-processing task on a Gradio server and print the result.

    The image, task name and model ID are sent to the Gradio app at
    ``client_url``. The reply is expected to be a pair whose first item is the
    text of a Python literal (normally a dict). That literal is parsed,
    newline sequences are removed from every string value it contains, and
    the cleaned structure is printed to stdout as single-line JSON.

    Prediction, parsing and serialization failures are logged and the
    function returns without printing anything.

    Parameters:
    - image_url: URL of the image to process.
    - task: Task name, forwarded unchanged as the ``task_prompt`` argument.
    - model_id: Model ID to use for prediction.
    - client_url: URL of the Gradio server.

    Returns:
    None.
    """
    logger = logging.getLogger(__name__)

    # The task name itself is what gets sent to the server as the prompt.
    task_prompt = task
    logger.debug("Using task prompt key: %s", task_prompt)

    client = Client(client_url, verbose=False)
    logger.debug("Client instantiated with URL: %s", client_url)

    # Perform the prediction.
    try:
        result = client.predict(
            image=handle_file(image_url),
            task_prompt=task_prompt,
            text_input=None,
            model_id=model_id,
        )
    except Exception as e:
        # Boundary with the remote service: report any failure and stop.
        logger.error("Exception occurred during prediction: %s", e, exc_info=True)
        return
    logger.debug("Prediction result: %s", result)

    # Only the first element of the reply is used; the second is discarded.
    try:
        result_string, _ = result
    except (TypeError, ValueError) as e:
        logger.error("Unexpected prediction result: %s", e, exc_info=True)
        return

    if result_string is None or (
        isinstance(result_string, str) and not result_string.strip()
    ):
        logger.warning("Result string is empty or None.")
        return

    # literal_eval only accepts Python literals, so the reply is never executed.
    try:
        logger.debug("Evaluating result string: %s", result_string)
        parsed = ast.literal_eval(result_string)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        logger.error(
            "Exception occurred while evaluating result string: %s", e, exc_info=True
        )
        return

    # Values may be plain strings or nested structures (dicts/lists of labels,
    # boxes, ...), so newline removal walks the whole structure instead of
    # assuming every top-level value is a string.
    cleaned = _strip_newlines(parsed)

    try:
        output = json.dumps(cleaned, indent=None)
    except (TypeError, ValueError) as e:
        logger.error("Result could not be serialized to JSON: %s", e, exc_info=True)
        return

    # Print the resulting single-line JSON string.
    print(output)


def _strip_newlines(value):
    """Return a copy of *value* with newline sequences removed from all strings."""
    if isinstance(value, str):
        # Remove real newlines and the literal two-character sequence "\n".
        return value.replace('\n', '').replace('\\n', '')
    if isinstance(value, dict):
        return {key: _strip_newlines(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        # Tuples become lists, which is how JSON would represent them anyway.
        return [_strip_newlines(item) for item in value]
    return value
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description='Image Captioning Script')
    parser.add_argument(
        '--image_url',
        type=str,
        required=True,
        help='URL of the image to caption',
    )
    parser.add_argument(
        '--task',
        type=str,
        choices=list(task_prompts.keys()),
        required=True,
        help='Image processing task supported by the model',
    )
    parser.add_argument(
        '--model_id',
        type=str,
        default="microsoft/Florence-2-large",
        help='Model ID to use for prediction',
    )
    parser.add_argument(
        '--client_url',
        type=str,
        default="http://100.107.248.20:42421/",
        help='Client server URL',
    )

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.info("Running with arguments: %s", args)

    # The option names match main()'s parameter names, so the parsed
    # namespace can be unpacked straight into the call.
    main(**vars(args))
|
|