@ShadowPower
Last active October 28, 2022 02:38
Convert a Sentence Transformers multilingual CLIP model to ONNX
from os import makedirs
from os.path import dirname, abspath
from pathlib import Path

import torch
from torch import nn
from onnxruntime.quantization import quantize_dynamic, QuantType
from sentence_transformers import SentenceTransformer

# CLIP image encoder plus the multilingual text encoder distilled to match it
img_model = SentenceTransformer('clip-ViT-B-32')
text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')

BASE_DIR = Path(dirname(abspath(__file__)))
output_dir = BASE_DIR / 'onnx'
if not output_dir.exists():
    makedirs(output_dir)


def export_textual_model_to_onnx(model, tokenizer, path):
    with torch.no_grad():
        # Dummy ids/attention mask at the tokenizer's maximum sequence length;
        # only the batch axis is declared dynamic, so the sequence length is fixed
        default_input = torch.ones(1, tokenizer.model_max_length, dtype=torch.int32).to('cpu')
        symbolic_names = {0: "batch_size"}
        # The wrapper's forward() takes a single dict argument named "input"
        dummy_input = {
            "input": {
                "input_ids": default_input,
                "attention_mask": default_input,
            }
        }
        torch.onnx.export(
            model,
            dummy_input,
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": symbolic_names,
                "attention_mask": symbolic_names,
                "output": symbolic_names,
            },
        )


def export_visual_model_to_onnx(model, path):
    with torch.no_grad():
        symbolic_names = {0: "batch_size"}
        # One random 224x224 RGB image; image_text_info is traced as a constant
        dummy_input = {
            "input": {
                "pixel_values": torch.rand(1, 3, 224, 224),
                "image_text_info": [0],
            }
        }
        torch.onnx.export(
            model,
            dummy_input,
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["pixel_values", "image_text_info"],
            output_names=["output"],
            dynamic_axes={
                "pixel_values": symbolic_names,
                "output": symbolic_names,
            },
        )


def quantization(input_file: Path, output_file: Path):
    # Dynamic int8 quantization of the exported ONNX weights
    quantize_dynamic(input_file, output_file, weight_type=QuantType.QInt8)


# Wrappers that return only the sentence embedding, so ONNX sees a single output
output_value_name = 'sentence_embedding'


class TextualModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        return self.model.forward(input)[output_value_name]


class VisualModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        return self.model.forward(input)[output_value_name]


# export
if __name__ == '__main__':
    textual_model = TextualModel(text_model)
    export_textual_model_to_onnx(textual_model, text_model.tokenizer, output_dir / 'textual.onnx')

    visual_model = VisualModel(img_model)
    export_visual_model_to_onnx(visual_model, output_dir / 'visual.onnx')

    # Optionally produce int8-quantized variants
    # quantization(output_dir / 'textual.onnx', output_dir / 'textual-q.onnx')
    # quantization(output_dir / 'visual.onnx', output_dir / 'visual-q.onnx')
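
A minimal sanity check for the exported text model, not part of the original gist: it assumes the onnx/textual.onnx path and the input/output names used in the export above, pads to the tokenizer's maximum length because only the batch axis was declared dynamic, and casts the ids to int32 to match the dummy input's dtype.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/clip-ViT-B-32-multilingual-v1')
session = ort.InferenceSession('onnx/textual.onnx', providers=['CPUExecutionProvider'])

# Pad to model_max_length: the sequence axis was not exported as dynamic
tokens = tokenizer('a photo of a cat', padding='max_length', truncation=True, return_tensors='np')
embedding = session.run(
    ['output'],
    {
        'input_ids': tokens['input_ids'].astype(np.int32),
        'attention_mask': tokens['attention_mask'].astype(np.int32),
    },
)[0]
print(embedding.shape)  # expected (1, 512): the shared CLIP embedding space

The visual model can be checked the same way with a 1x3x224x224 float32 array of CLIP-preprocessed pixel values; note that image_text_info was a plain Python list at export time, so it is most likely folded into the graph as a constant during tracing rather than exposed as a runtime input.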