Last active: September 21, 2024 10:35
-
-
Save dranger003/daff444ebf04951d4279b5b2dee71ab4 to your computer and use it in GitHub Desktop.
Phi-3-Vision-128K-Instruct Quick Local API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# server.py
# Run with: uvicorn server:app --reload
import base64
import queue
import threading
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    # BitsAndBytesConfig,
    TextStreamer,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the default model set before serving.

    ``load`` and ``LoadRequest`` are module-level names looked up only when
    the server starts, so they may be defined further down the file. ``load``
    is a plain synchronous function, so startup blocks until loading finishes.
    Nothing is torn down explicitly at shutdown.
    """
    # Startup: a single replica on the first GPU.
    default_request = LoadRequest(num_gpus=1, models_per_gpu=1)
    load(default_request)
    yield
    # Shutdown: nothing to release explicitly.


app = FastAPI(lifespan=lifespan)
class TextStreamerEx(TextStreamer):
    """``TextStreamer`` that forwards decoded text to a queue-like sink.

    Every finalized text fragment is ``put`` on ``output``; once the stream
    ends, ``None`` is ``put`` as an end-of-stream marker.
    """

    def __init__(self, tokenizer, output):
        super().__init__(
            tokenizer,
            skip_prompt=False,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        # Sink for decoded text; only its ``put`` method is used.
        self.output = output

    def put(self, value):
        # Tensors with more than one dimension are dropped and only 1-D token
        # tensors are decoded. Presumably this skips the batched prompt ids that
        # ``generate`` hands over first (hence skip_prompt=False above) -- confirm.
        if len(value.shape) <= 1:
            super().put(value)

    def on_finalized_text(self, text, stream_end=False):
        sink = self.output
        sink.put(text)
        if stream_end:
            sink.put(None)
class Model:
    """One Phi-3-Vision processor/model pair pinned to a single CUDA device.

    Instances start unloaded (``loaded`` is False); call :meth:`load` before
    :meth:`run` or :meth:`prompt`, which rely on ``self.processor`` and
    ``self.model`` set there.
    """

    def __init__(self, model_id, index):
        # Hugging Face model identifier passed to ``from_pretrained``.
        self.model_id = model_id
        # CUDA device ordinal; weights are placed on ``cuda:{index}``.
        self.index = index
        # Becomes True once ``load`` has completed.
        self.loaded = False

    def load(self):
        """Load the processor and bfloat16 weights onto ``cuda:{index}``."""
        self.processor = AutoProcessor.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map=f"cuda:{self.index}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        self.loaded = True

    def run(self, input, output):
        """Generate greedily, streaming decoded text chunks into ``output``.

        Args:
            input: ``(image, text)`` tuple; ``image`` may be None for a
                text-only prompt, ``text`` is the fully templated prompt.
            output: queue receiving text chunks, then ``None`` when the
                streamer reports end of stream.
        """
        image, text = input
        images = None if image is None else [image]
        # The processor returns a batch of tensors (token ids plus any image
        # features), not just input ids; move all of them to the model's device.
        inputs = self.processor(images=images, text=text, return_tensors="pt").to(
            self.model.device
        )
        self.model.generate(
            **inputs,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            max_new_tokens=4096,
            do_sample=False,
            repetition_penalty=1.2,
            streamer=TextStreamerEx(self.processor.tokenizer, output),
        )

    def _run_and_report(self, request, output):
        """Thread target: call :meth:`run` and forward any exception via ``output``."""
        try:
            self.run(request, output)
        except Exception as exc:
            # Without this the end-of-stream ``None`` is never queued and the
            # consumer in ``prompt`` would wait on the queue forever.
            output.put(exc)

    def prompt(self, image, text):
        """Yield the model's reply to ``text`` (and optional ``image``) chunk by chunk.

        Generation runs on a background thread that feeds a queue; this
        generator yields each text chunk as it arrives and stops at the
        ``None`` end-of-stream marker.

        Raises:
            Exception: whatever :meth:`run` raised on the worker thread.
        """
        output = queue.Queue()
        worker = threading.Thread(
            target=self._run_and_report, args=((image, text), output)
        )
        worker.start()
        while True:
            chunk = output.get()
            if chunk is None:
                break
            if isinstance(chunk, Exception):
                worker.join()
                raise chunk
            yield chunk
        worker.join()
# Models currently served; ``PromptRequest.model`` indexes into this list.
models = []


class LoadRequest(BaseModel):
    """Body of ``POST /load``: how many model replicas to load and where."""

    # Number of GPUs to use: replicas go on cuda:0 .. cuda:{num_gpus - 1}.
    num_gpus: int = 1
    # Replicas to load on each of those GPUs.
    models_per_gpu: int = 1


@app.post("/load")
def load(request: LoadRequest):
    """Load a fresh set of Phi-3-Vision replicas and make it the served set.

    Replicas are ordered GPU by GPU, so index ``g * models_per_gpu + k`` is
    the k-th replica on ``cuda:g``. The previous set is replaced rather than
    appended to, so repeated calls (including the one made at startup) do
    not pile up replicas or shift the index-to-GPU mapping. If any load
    fails, the exception propagates and the previously served set is kept.
    """
    loaded = []
    for gpu_index in range(request.num_gpus):
        for _ in range(request.models_per_gpu):
            model = Model("microsoft/Phi-3-vision-128k-instruct", gpu_index)
            model.load()
            loaded.append(model)
    # Mutate in place so every reader of the module-level ``models`` sees the
    # new set; this list no longer holds references to the old replicas.
    models[:] = loaded
    return {"status": "200 OK"}
class PromptRequest(BaseModel):
    """Body of ``POST /prompt``."""

    # Index into the list of loaded models.
    model: int = 0
    # Base64-encoded image bytes. Omitted or explicit null means a text-only
    # prompt; ``Optional`` is what lets ``"image": null`` validate.
    image: Optional[str] = None
    # User message text.
    text: str
@app.post("/prompt")
def prompt(request: PromptRequest):
    """Answer one text (or text + image) prompt as a server-sent event stream.

    Each ``data:`` event carries one chunk of generated text, base64-encoded
    from UTF-8 so newlines inside the text cannot break SSE framing.

    Raises:
        HTTPException: 400 if ``request.model`` is not the index of a loaded
            model, or if ``request.image`` is not decodable base64 image data.
    """
    # Reject bad indices explicitly rather than letting models[...] raise
    # IndexError (a 500) or silently wrap around for negative values.
    if not 0 <= request.model < len(models):
        raise HTTPException(
            status_code=400,
            detail=f"model index {request.model} is not loaded "
            f"({len(models)} available)",
        )
    model = models[request.model]

    if request.image is None:
        image = None
        content = request.text
    else:
        try:
            image = Image.open(BytesIO(base64.b64decode(request.image)))
            # Decode pixels now so corrupt data fails here with a 400 instead
            # of later inside the generation thread.
            image.load()
        except (ValueError, OSError) as exc:
            # binascii.Error is a ValueError; Pillow's unidentified/truncated
            # image errors are OSErrors.
            raise HTTPException(
                status_code=400,
                detail="image is not valid base64-encoded image data",
            ) from exc
        # The <|image_1|> tag marks where the single attached image goes.
        content = f"<|image_1|>\n{request.text}"

    templatized_prompt = model.processor.tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/sample_inference.py#L97
    # Drop exactly one trailing <|endoftext|>. str.rstrip() would treat the
    # argument as a set of characters and could eat into preceding tokens.
    end_of_text = "<|endoftext|>"
    if templatized_prompt.endswith(end_of_text):
        templatized_prompt = templatized_prompt[: -len(end_of_text)]

    def stream():
        for text in model.prompt(image, templatized_prompt):
            encoded = base64.b64encode(text.encode("utf-8")).decode("utf-8")
            yield f"data: {encoded}\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")
Nice!
@dranger003 Do you also have a version that can run Phi-3 Vision on macOS?
+1
I don't have a Mac, but that code should run. Have you tried it?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice!
@dranger003 Do you also have a version that can run Phi-3 Vision on macOS?