A simple LLaVA(R) API and query example
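The gist contains two files: a small Flask server that loads a LLaVA checkpoint once at startup and serves it over HTTP, and a command-line client that sends an image plus a text query to that server. Running the server requires torch, flask, transformers, Pillow, the llava package, a CUDA-capable GPU, and a local or Hugging Face model path substituted for the <HF_MODEL_PATH> placeholder.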
from flask import Flask, request, jsonify
import torch
import base64
import io
from PIL import Image
import requests
from io import BytesIO

# Import the necessary modules from the llava directory
from transformers import AutoTokenizer
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.model.utils import KeywordsStoppingCriteria
from llava.model import LlavaLlamaForCausalLM, LlavaMPTForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor

# Constants
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

app = Flask(__name__)

# Model parameters
model_path = "<HF_MODEL_PATH>"

# Load the model once when the script starts
disable_torch_init()
tokenizer = AutoTokenizer.from_pretrained(model_path)
if "mpt" in model_path.lower():
    model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
else:
    model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()

image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Ensure the vision tower is on the correct device
vision_tower = model.get_model().vision_tower[0]
if vision_tower.device.type == 'meta':
    vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
    model.get_model().vision_tower[0] = vision_tower
else:
    vision_tower.to(device='cuda', dtype=torch.float16)

vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2


def load_image_from_base64(base64_str):
    base64_img_bytes = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(base64_img_bytes)).convert('RGB')
    return image


@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    query = data['query']
    base64_image = data['image']
    conv_mode = data.get('conv_mode', None)  # get conv_mode from the POST request data

    # Load and process the image
    image = load_image_from_base64(base64_image)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    # Process the query
    if model.config.mm_use_im_start_end:
        query = query + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    else:
        query = query + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len

    inferred_conv_mode = "multimodal"  # default conv mode
    if "v1" in model_path.lower():
        inferred_conv_mode = "llava_v1"
    elif "mpt" in model_path.lower():
        inferred_conv_mode = "mpt_multimodal"

    if conv_mode is not None and inferred_conv_mode != conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `conv_mode` is {}, using {}'.format(inferred_conv_mode, conv_mode, conv_mode))
    else:
        conv_mode = inferred_conv_mode

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], query)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).to("cuda")

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # Generate a response
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            stopping_criteria=[stopping_criteria]
        )

    # Decode and format the output
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    response = {"response": outputs}
    return jsonify(response)


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8000)
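The server above listens on port 8000 and exposes a single POST /predict endpoint. The request body is JSON with a 'query' string, a base64-encoded 'image', and an optional 'conv_mode'; the response is JSON of the form {"response": "..."}. As a minimal sketch (the file name photo.jpg and the localhost URL are placeholders, and the server must already be running), a request can be issued directly with requests:

    import base64, requests
    with open("photo.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode()
    r = requests.post("http://localhost:8000/predict",
                      json={"query": "Describe this image.", "image": img_b64})
    print(r.json()["response"])

The click-based client below does the same thing, but first re-encodes the image as an RGB JPEG via Pillow before base64-encoding it.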
import base64
import requests
import click
from PIL import Image
import io


@click.command()
@click.option('--image_path', prompt='Image path', help='Path to the image file.')
@click.option('--query', prompt='Query', help='Text query associated with the image.')
def call_api(image_path, query):
    # Open the image, convert it to RGB, re-encode it as JPEG, and base64-encode the bytes
    image = Image.open(image_path).convert('RGB')
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # Define the API endpoint (must match the host and port the Flask server listens on)
    url = 'http://localhost:8000/predict'

    # Define the data payload
    data = {
        'image': img_str,
        'query': query
    }

    # Make the POST request and print the response
    response = requests.post(url, json=data)
    print(response.json())


if __name__ == '__main__':
    call_api()
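Assuming the client above is saved as query.py (the gist does not name the file), it can be invoked as python query.py --image_path /path/to/image.jpg --query "Describe this image."; click prompts for either option if it is omitted. The URL in the script must point at the host and port where the Flask server is actually running.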