Created
August 11, 2023 13:16
-
-
Save waundme/e1ee2939b97f2f1d22fd1680b77071e7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from os import environ | |
from psutil import cpu_count | |
from pathlib import Path | |
from typing import Optional | |
from typing import Union | |
import logging | |
from transformers import Pipeline | |
import torch.nn.functional as F | |
import torch | |
from fastapi import FastAPI | |
# copied from the model card
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions kept by ``attention_mask``.

    Args:
        model_output: Model output whose first element holds the per-token
            embeddings (assumed ``(batch, seq, hidden)`` — the mask below is
            expanded to that tensor's size and summed over dim 1).
        attention_mask: Integer/boolean mask, presumably ``(batch, seq)``;
            1 marks a real token, 0 marks padding.

    Returns:
        A float tensor with the sequence dimension (dim 1) reduced away: the
        masked sum of token embeddings divided by the number of kept tokens.
    """
    hidden = model_output[0]
    # Broadcast the mask across the hidden dimension and make it float so it
    # can weight the embeddings.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    # Clamp the token count so an all-padding row divides by 1e-9, not 0.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
class SentenceEmbeddingPipeline(Pipeline):
    """Transformers pipeline that maps text to L2-normalized sentence embeddings.

    Stages:
        preprocess  -- tokenize with padding and truncation into PyTorch tensors.
        _forward    -- run the model, carrying the attention mask alongside.
        postprocess -- mean-pool token embeddings with the mask, then
                       L2-normalize each pooled vector along dim 1.
    """

    def _sanitize_parameters(self, **kwargs):
        # No call-time hyperparameters are supported: every stage receives an
        # empty kwargs dict and anything passed in ``kwargs`` is ignored.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` into padded, truncated PyTorch tensors."""
        return self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')

    def _forward(self, model_inputs):
        """Run the model; keep the attention mask because pooling needs it."""
        result = {"outputs": self.model(**model_inputs)}
        result["attention_mask"] = model_inputs["attention_mask"]
        return result

    def postprocess(self, model_outputs):
        """Pool token embeddings into one vector per input and L2-normalize it."""
        pooled = mean_pooling(model_outputs["outputs"], model_outputs["attention_mask"])
        return F.normalize(pooled, p=2, dim=1)
# FastAPI application object; the startup hook and the /embed route in this
# file are registered on it through their decorators.
app = FastAPI()
@app.on_event("startup")
async def startup_event():
    """Export, optimize and quantize the embedding model, then build the pipeline.

    Runs once when the app starts:

    1. Export ``model_id`` to ONNX (once) and save it plus the tokenizer to
       ``onnx/``.
    2. Apply ONNX Runtime graph optimizations (level 99) to that export,
       writing ``onnx/model_optimized.onnx``.
    3. Dynamically quantize the *optimized* graph with the AVX-512 VNNI
       preset, writing ``vnni/model_optimized_quantized.onnx``.
    4. Load the quantized model and publish the module-level globals
       ``model``, ``tokenizer`` and ``vanilla_emb`` that the ``/embed``
       route reads.
    """
    logging.info("loading model")
    # NOTE(review): torch is imported at module load, before these are set, so
    # they likely do not affect torch's thread pool; they are set before
    # onnxruntime is imported below. Confirm which runtime they target.
    environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
    environ["OMP_WAIT_POLICY"] = 'ACTIVE'

    # Imported here, after the environment variables above are set.
    from optimum.onnxruntime import (
        ORTModelForFeatureExtraction,
        ORTOptimizer,
        ORTQuantizer,
    )
    from optimum.onnxruntime.configuration import (
        AutoQuantizationConfig,
        OptimizationConfig,
    )
    from transformers import AutoTokenizer

    global model, tokenizer, vanilla_emb

    model_id = "aari1995/German_Semantic_STS_V2"
    onnx_path = Path("onnx")
    quantized_path = Path("vnni")

    # 1. Convert the vanilla transformers checkpoint to ONNX once and keep a
    #    copy (plus tokenizer) on disk; later steps reuse this export instead
    #    of exporting again.
    model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model.save_pretrained(onnx_path)
    tokenizer.save_pretrained(onnx_path)

    # 2. Graph-optimize the exported model (level 99 enables all optimizations).
    optimizer = ORTOptimizer.from_pretrained(model)
    optimization_config = OptimizationConfig(optimization_level=99)
    optimizer.optimize(save_dir=onnx_path, optimization_config=optimization_config)

    # 3. Quantize the *optimized* graph so the optimization above actually
    #    reaches the served model (previously a second fresh export was
    #    quantized and the optimized file was never used).
    quantizer = ORTQuantizer.from_pretrained(onnx_path, file_name="model_optimized.onnx")
    dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer.quantize(save_dir=quantized_path, quantization_config=dqconfig)

    # 4. Serve the quantized model. The file is named explicitly so the load
    #    does not depend on which other .onnx files happen to sit in vnni/.
    model = ORTModelForFeatureExtraction.from_pretrained(
        quantized_path, file_name="model_optimized_quantized.onnx"
    )
    vanilla_emb = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
@app.get("/embed")
async def get(text: Union[str, None] = None):
    """Return the sentence embedding of ``text`` as a plain list of floats.

    ``text`` keeps its ``None`` default for interface compatibility, but a
    missing value is now rejected with HTTP 422. Previously it was handed to
    the tokenizer, which rejects ``None`` and surfaced as an unhandled 500.

    Raises:
        HTTPException: 422 when ``text`` is not supplied.
    """
    from fastapi import HTTPException

    if text is None:
        raise HTTPException(status_code=422, detail="query parameter 'text' is required")
    # The pipeline returns a batch of embeddings; a single string gives one row.
    embeddings = vanilla_emb(text)
    return embeddings[0].tolist()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment