Last active
September 21, 2024 10:23
-
-
Save dcbark01/b329e170d0473bbdfd012e04c17bcfd3 to your computer and use it in GitHub Desktop.
Embed Mistral FastAPI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
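"""
FastAPI service that exposes intfloat/e5-mistral-7b-instruct text embeddings at POST /embed/.
# See https://huggingface.co/intfloat/e5-mistral-7b-instruct for the model's reference inference code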
## Quickstart | |
Install requirements | |
```bash | |
pip install fastapi uvicorn torch transformers | |
``` | |
```bash | |
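# Activate the Python 3.11 environment that has the requirements installed (e.g. conda base, or a venv as below)
source venv/bin/activate
# Start the server; the uvicorn target is `<module name>:app`, so use the line that matches this file's name
uvicorn app:app --reload --host 0.0.0.0 --port 8000      # file saved as app.py
uvicorn app_gpu:app --reload --host 0.0.0.0 --port 8000  # file saved as app_gpu.py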
``` | |
Then send an inference request:
```python
# Example client (uses the `requests` library)
import requests | |
# Server URL | |
url = 'http://localhost:8000/embed/' | |
headers = {'accept': 'application/json', 'Content-Type': 'application/json'} | |
data = { | |
"queries": ["how much protein should a female eat", "summit define"], | |
"documents": [ | |
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", | |
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments." | |
] | |
} | |
response = requests.post(url, json=data, headers=headers)
embeddings = response.json()["embeddings"]  # one vector per query, then one per document
```
"""
import logging
from typing import List

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Preferred inference device: "cpu" or "cuda". A "cuda" request is
# downgraded to "cpu" below when torch reports no usable CUDA device.
DEVICE = "cuda"
app = FastAPI()

# Tokenizer and weights are loaded once, at import time; the request handler
# reuses these module-level objects for every call.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct')

if DEVICE == "cuda" and not torch.cuda.is_available():
    print("CUDA not available in current environment. Defaulting to CPU.")
    DEVICE = "cpu"

model = model.to(DEVICE)
print(f"Using Device: {DEVICE}")
class QueryDocument(BaseModel):
    """Request body for ``POST /embed/``.

    Both fields are required. The endpoint embeds every text in one batch,
    queries first and then documents, so response rows follow that order.
    """

    # Query strings to embed.
    queries: list[str]
    # Document/passage strings to embed.
    documents: list[str]
def last_token_pool(last_hidden_states: torch.Tensor,
                    attention_mask: torch.Tensor) -> torch.Tensor:
    """Pick, for each sequence, the hidden state at its last attended position.

    Args:
        last_hidden_states: model output indexed as ``[batch, position, ...]``
            (presumably ``(batch, seq_len, hidden)`` — confirm with caller).
        attention_mask: ``(batch, seq_len)`` mask whose row sum is the number
            of real tokens; assumes 1 marks a real token and 0 marks padding.

    Returns:
        One row of ``last_hidden_states`` per batch element, taken at the last
        real token.
    """
    batch_size = last_hidden_states.shape[0]

    # When the final column is attended in every row, the batch is
    # left-padded (or unpadded), so the last position is each sequence's end.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]

    # Right padding: the final real token sits at (number of real tokens - 1).
    final_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(batch_size, device=last_hidden_states.device)
    return last_hidden_states[row_index, final_positions]
@app.post("/embed/")
def create_embeddings(query_document: QueryDocument):
    """Embed every query and document in the request.

    Queries come first, then documents, each in the order supplied. Every
    text is tokenized (truncated to ``max_length - 1`` tokens), terminated
    with the EOS token, padded into one batch, run through the model, pooled
    at its last real token and L2-normalized.

    NOTE(review): queries and documents are embedded identically here. The
    model card linked in the module docstring recommends prefixing queries
    with an ``Instruct: ...\\nQuery: ...`` prompt; confirm whether callers are
    expected to add it themselves.

    Returns:
        ``{"embeddings": [[float, ...], ...]}``, one unit-length vector per
        input text (``len(queries) + len(documents)`` rows). A request with no
        texts yields ``{"embeddings": []}``.

    Raises:
        HTTPException: status 500 carrying the error message if tokenization
            or inference fails (the full traceback is logged server-side).
    """
    input_texts = query_document.queries + query_document.documents
    if not input_texts:
        # Nothing to embed; an empty batch would otherwise fail inside the
        # tokenizer/model path and surface as a misleading 500.
        return {"embeddings": []}

    # Token budget per text, including the EOS token appended below.
    max_length = 4096
    try:
        batch_dict = tokenizer(input_texts, max_length=max_length - 1,
                               return_attention_mask=False, padding=False,
                               truncation=True)
        # Pooling reads the final real token, so every sequence must end in EOS.
        batch_dict['input_ids'] = [ids + [tokenizer.eos_token_id]
                                   for ids in batch_dict['input_ids']]
        batch_dict = tokenizer.pad(batch_dict, padding=True,
                                   return_attention_mask=True,
                                   return_tensors='pt')
        batch_dict = batch_dict.to(DEVICE)

        # Pure inference: disable autograd so intermediate activations of the
        # model are not kept alive for a backward pass that never happens.
        with torch.inference_mode():
            outputs = model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state,
                                         batch_dict['attention_mask'])
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return {"embeddings": embeddings.cpu().tolist()}
    except Exception as e:
        # Boundary handler: record the traceback (HTTPException alone would
        # lose it) and report a 500 to the client.
        logging.getLogger(__name__).exception("Embedding request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment