@waundme
Created August 11, 2023 13:16

import logging
from os import environ
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from fastapi import FastAPI
from psutil import cpu_count
from transformers import Pipeline

# copied from the model card
def mean_pooling(model_output, attention_mask):
    # First element of model_output contains all token embeddings
    token_embeddings = model_output[0]
    # Expand the attention mask so padded tokens are excluded from the average
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
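
# Illustrative sketch (not part of the original gist): with two tokens of which
# only the first is real, mean pooling returns that first token's embedding.
# Assumed shapes: token embeddings (1, 2, 3), attention mask (1, 2).
#
#   toks = torch.tensor([[[1.0, 2.0, 3.0], [9.0, 9.0, 9.0]]])
#   mask = torch.tensor([[1, 0]])
#   mean_pooling((toks,), mask)  # -> tensor([[1., 2., 3.]])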


class SentenceEmbeddingPipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # we don't have any hyperparameters to sanitize
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
        return encoded_inputs

    def _forward(self, model_inputs):
        outputs = self.model(**model_inputs)
        # keep the attention mask so postprocess can pool over real tokens only
        return {"outputs": outputs, "attention_mask": model_inputs["attention_mask"]}

    def postprocess(self, model_outputs):
        # Perform pooling
        sentence_embeddings = mean_pooling(model_outputs["outputs"], model_outputs["attention_mask"])
        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
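

# Usage sketch (illustrative, assuming a model and tokenizer loaded the same way
# as in startup_event below):
#
#   emb = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
#   vec = emb("Das ist ein Beispielsatz.")
#   print(vec[0].shape)  # one L2-normalised embedding vector per input text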
app = FastAPI()
@app.on_event("startup")
async def startup_event():
logging.info("loading model")
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = 'ACTIVE'
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
from pathlib import Path
model_id="aari1995/German_Semantic_STS_V2"
onnx_path = Path("onnx")
# load vanilla transformers and convert to onnx
global model
model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True)
global tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# save onnx checkpoint and tokenizer
model.save_pretrained(onnx_path)
tokenizer.save_pretrained(onnx_path)
from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig
# create ORTOptimizer and define optimization configuration
optimizer = ORTOptimizer.from_pretrained(model_id)
optimization_config = OptimizationConfig(optimization_level=99) # enable all optimizations
# apply the optimization configuration to the model
optimizer.export(
onnx_model_path=onnx_path / "model.onnx",
onnx_optimized_model_output_path=onnx_path / "model-optimized.onnx",
optimization_config=optimization_config,
)
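
    # Note (not in the original gist): the optimized graph above is written to
    # onnx/model-optimized.onnx but is not what gets served; the quantized model
    # built below is. If one wanted to serve the optimized graph instead, optimum
    # can load a specific file from the directory, e.g. (illustrative):
    #
    #   model = ORTModelForFeatureExtraction.from_pretrained(onnx_path, file_name="model-optimized.onnx")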

    from optimum.onnxruntime import ORTQuantizer, ORTModelForFeatureExtraction
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    # export the model again and dynamically quantize it for AVX512-VNNI CPUs
    onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
    quantizer = ORTQuantizer.from_pretrained(onnx_model)
    dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer.quantize(
        save_dir="vnni",
        quantization_config=dqconfig,
    )

    import onnx
    from onnx.external_data_helper import load_external_data_for_model

    # load the quantized graph together with its external data files from ./vnni
    onnx_model = onnx.load("./vnni/model_quantized.onnx", load_external_data=False)
    load_external_data_for_model(onnx_model, "./vnni")

    # serve the quantized model through the custom embedding pipeline
    model = ORTModelForFeatureExtraction.from_pretrained("./vnni")
    global vanilla_emb
    vanilla_emb = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
@app.get("/embed")
async def get(text: Union[str, None] = None):
v = vanilla_emb(text)
return v[0].tolist()
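

# Usage sketch (illustrative; assumes this file is saved as main.py):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl "http://localhost:8000/embed?text=Das%20ist%20ein%20Beispielsatz."
#
# The endpoint returns the sentence embedding as a JSON array of floats.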