Last active
June 16, 2023 06:42
-
-
Save jcrubino/307e57f1d4a1a539120c51e069643f68 to your computer and use it in GitHub Desktop.
Simple Flask Web Server for Image Captioning and Document Embedding Generation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import logging | |
from io import BytesIO | |
import torch | |
import torch.nn.functional as F | |
from flask import Flask, request | |
from PIL import Image | |
from torch import Tensor | |
from transformers import AutoModel, AutoTokenizer | |
# Flask application object; the @app.route decorators in this file register
# their view functions on it, and the __main__ guard runs it.
app = Flask(__name__)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the number of kept positions in each row.

    Args:
        last_hidden_states: Hidden states to pool. Assumed to be laid out as
            (batch, sequence, hidden), which is how embedding_text passes
            them -- TODO confirm for other callers.
        attention_mask: Mask aligned with the leading dims of
            ``last_hidden_states``; non-zero marks a position to keep.

    Returns:
        The per-row mean over kept positions. A row whose mask is entirely
        zero yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    # Clamp so a fully masked row divides by 1 rather than 0 (0/0 -> NaN,
    # which would also serialise to invalid JSON downstream).
    token_counts = attention_mask.sum(dim=1).clamp(min=1)[..., None]
    return summed / token_counts
# Model identifier handed to AutoTokenizer / AutoModel for text embeddings.
TEXT_EMBED_MODEL = "intfloat/e5-large-v2"
# NOTE(review): no logging configuration is visible in this file; with the
# stdlib default root level (WARNING) these INFO records would be dropped --
# confirm logging is configured elsewhere.
logging.info(f"Text Embedding Model: {TEXT_EMBED_MODEL}")
# Loaded once at import time; embedding_text uses both objects per request.
embedding_tokenizer = AutoTokenizer.from_pretrained(TEXT_EMBED_MODEL)
embedding_model = AutoModel.from_pretrained(TEXT_EMBED_MODEL)
# Model identifier for image captioning; load_model chooses the loader
# classes by substring-matching this value.
PROD_MODEL = "microsoft/git-base-coco"
logging.info(f"Image 2 Text model: {PROD_MODEL}")
def process_image(pil_image, prompt=""):
    """Placeholder that always raises NotImplementedError.

    NOTE(review): a second ``process_image`` with the same signature is
    defined later in this file and rebinds the name at import time, so this
    stub is not the function the route handlers end up calling.
    """
    raise NotImplementedError()
def load_model(PROD_MODEL):
    """Load the image-to-text model and its processor for a model id.

    The loader classes are chosen by substring match on the id:
    ids containing ``"Salesforce"`` use the BLIP classes, ids containing
    ``"microsoft/"`` use ``AutoModelForCausalLM`` / ``AutoProcessor``.
    The transformers classes are imported lazily so only the selected
    family is imported.

    Args:
        PROD_MODEL: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded with
        ``load_in_8bit=True``.

    Raises:
        ValueError: If the id matches neither supported family.
    """
    if "Salesforce" in PROD_MODEL:
        from transformers import BlipForConditionalGeneration, BlipProcessor

        processor = BlipProcessor.from_pretrained(PROD_MODEL)
        model = BlipForConditionalGeneration.from_pretrained(
            PROD_MODEL, load_in_8bit=True
        )
    elif "microsoft/" in PROD_MODEL:
        from transformers import AutoModelForCausalLM, AutoProcessor

        processor = AutoProcessor.from_pretrained(PROD_MODEL)
        model = AutoModelForCausalLM.from_pretrained(PROD_MODEL, load_in_8bit=True)
    else:
        # ValueError is still an Exception, so existing broad handlers keep
        # working while callers can now catch the specific failure.
        raise ValueError(f"Unknown Image 2 Text Model: {PROD_MODEL}")
    return model, processor
# Load the captioning model and processor once at import time; process_image
# reads these two module-level names on every call.
model, processor = load_model(PROD_MODEL)
def process_image(pil_image, prompt=""):
    """Generate a caption for an image using the module-level model.

    Args:
        pil_image: Image passed to the module-level ``processor``.
        prompt: Optional text to condition generation on. An empty string
            (the default) requests an unconditioned caption.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.

    Raises:
        TypeError: If ``processor`` is neither a BLIP nor a GIT processor.
    """
    # Imported here because load_model only imports the processor classes
    # into its own local scope; they are not available at module level.
    # GitProcessor is checked instead of AutoProcessor: AutoProcessor is a
    # factory whose from_pretrained returns a concrete processor class, so an
    # isinstance check against it can never succeed.
    from transformers import BlipProcessor, GitProcessor

    # Follow the model's actual device rather than assuming a fixed one.
    device = model.device

    if isinstance(processor, BlipProcessor):
        if prompt != "":
            inputs = processor(pil_image, prompt, return_tensors="pt")
        else:
            inputs = processor(pil_image, return_tensors="pt")
        inputs = inputs.to(device, torch.float16)
    elif isinstance(processor, GitProcessor):
        pixel_values = processor(images=pil_image, return_tensors="pt").pixel_values
        inputs = {
            "pixel_values": pixel_values.to(device, torch.float16),
            "max_length": 50,
        }
        if prompt != "":
            input_ids = processor(text=prompt, add_special_tokens=False).input_ids
            # Prepend the CLS token id manually since special tokens were
            # disabled in the tokenisation call above.
            input_ids = [processor.tokenizer.cls_token_id] + input_ids
            inputs["input_ids"] = torch.tensor(input_ids, device=device).unsqueeze(0)
    else:
        raise TypeError("Unknown processor type")

    generated_ids = model.generate(**inputs)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption.strip()
@app.route("/")
def hello_world():
    """Serve the static landing page that points at /base64caption."""
    landing_page = """
<h1>Base64 Caption Server</h1>
<p>Post an HTML base64 encoded image to /base64caption and receive a description.</p>
"""
    return landing_page
@app.route("/base64caption", methods=["POST"])
def base64caption():
    """Caption a base64-encoded image posted as JSON.

    Expects a JSON object with a string field ``"base64"`` holding either
    bare base64 data or a data URL; everything up to the last comma is
    discarded before decoding.

    Returns:
        ``{"caption": <text>}`` on success, or ``({"error": <message>}, 400)``
        when the body is not the expected JSON or the payload cannot be
        decoded into an image.
    """
    # silent=True yields None for a missing/invalid JSON body so the client
    # gets the same JSON error shape as every other validation failure.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict) or not isinstance(
        request_data.get("base64"), str
    ):
        return {"error": 'Expected a JSON object with a string "base64" field.'}, 400

    base64_data = request_data["base64"].split(",")[-1]
    try:
        image_data = base64.b64decode(base64_data)
        pil_image = Image.open(BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        # binascii.Error (bad base64) is a ValueError; Pillow reports
        # unreadable image data with OSError subclasses.
        return {"error": "Payload is not a decodable base64 image."}, 400

    caption = process_image(pil_image)
    return {"caption": caption}
@app.route("/embedding", methods=["GET"])
def embedding_base():
    """Serve a short HTML page listing the embedding endpoint and its payload."""
    usage_page = """\
\r<h1>Embeddings</h1>
\r<p></p>
\r<p>Available Endpoints: /embedding/text</p>
\r<p>post data: {"batch":[list, of, strings], "prefix":(query|passage), "normalize":(True|False)}</p>
"""
    return usage_page
@app.route("/embedding/text", methods=["POST"])
def embedding_text():
    """Embed a batch of texts with the module-level embedding model.

    Expects a JSON object with:
        ``batch``: list of items to embed,
        ``prefix``: value prepended to each item as ``"<prefix>: <item>"``,
        ``normalize``: boolean; when true the embeddings are L2-normalised.

    Returns:
        A dict echoing ``batch``, ``prefix`` and ``normalize`` together with
        ``embeddings`` (one list of floats per batch item), or
        ``({"error": <message>}, 400)`` when the request is malformed.
    """
    request_data = request.get_json(silent=True)
    logging.info(request_data)
    if not isinstance(request_data, dict):
        return {"error": "Expected a JSON object."}, 400

    missing = [
        key for key in ("batch", "prefix", "normalize") if key not in request_data
    ]
    if missing:
        return {"error": f"Missing required field(s): {', '.join(missing)}"}, 400

    batch_data = request_data["batch"]
    prefix = request_data["prefix"]
    normalize = request_data["normalize"]

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    # and would otherwise surface as a 500.
    if not isinstance(batch_data, list):
        return {"error": '"batch" must be a list.'}, 400
    if normalize not in (True, False):
        return {"error": '"normalize" must be true or false.'}, 400

    if not batch_data:
        # Nothing to tokenize; answer with an empty result.
        return {
            "embeddings": [],
            "batch": batch_data,
            "prefix": prefix,
            "normalize": normalize,
        }

    input_texts = [f"{prefix}: {item}" for item in batch_data]
    batch_dict = embedding_tokenizer(
        input_texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
    )
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = embedding_model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
    if normalize:
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return {
        "embeddings": embeddings.tolist(),
        "batch": batch_data,
        "prefix": prefix,
        "normalize": normalize,
    }
if __name__ == "__main__":
    import os

    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute Python code, so debug mode is opt-in (FLASK_DEBUG=1) and, when
    # enabled, the server binds to loopback only instead of all interfaces.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    host = "127.0.0.1" if debug else "0.0.0.0"
    app.run(host=host, port=5000, debug=debug)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment