Convert the Sentence Transformers multilingual CLIP model to ONNX
from os import makedirs
from os.path import dirname, abspath
from pathlib import Path

from onnxruntime.quantization import quantize_dynamic, QuantType
from sentence_transformers import SentenceTransformer
import torch
from torch import nn
# CLIP image encoder and the multilingual text encoder that was aligned to it
img_model = SentenceTransformer('clip-ViT-B-32')
text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')

BASE_DIR = Path(dirname(abspath(__file__)))
output_dir = BASE_DIR / 'onnx'
if not output_dir.exists():
    makedirs(output_dir)
def export_textual_model_to_onnx(model, tokenizer, path):
    with torch.no_grad():
        # Dummy token ids padded to the tokenizer's maximum length
        default_input = torch.ones(1, tokenizer.model_max_length, dtype=torch.int32).to('cpu')
        symbolic_names = {0: "batch_size"}
        # The wrapper's forward takes a single dict argument named "input";
        # torch.onnx.export interprets a trailing dict as keyword arguments.
        dummy_input = {
            "input": {
                "input_ids": default_input,
                "attention_mask": default_input,
            }
        }
        torch.onnx.export(
            model,
            dummy_input,
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": symbolic_names,
                "attention_mask": symbolic_names,
                "output": symbolic_names,
            },
        )
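# Note: only axis 0 (batch_size) is declared dynamic above, so the exported
# text model expects inputs padded to exactly tokenizer.model_max_length
# tokens; see the usage sketch at the end of this file.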
def export_visual_model_to_onnx(model, path):
    with torch.no_grad():
        symbolic_names = {0: "batch_size"}
        # image_text_info=[0] marks the single batch entry as an image for
        # sentence-transformers' CLIP wrapper.
        dummy_input = {
            "input": {
                "pixel_values": torch.rand(1, 3, 224, 224),
                "image_text_info": [0],
            }
        }
        torch.onnx.export(
            model,
            dummy_input,
            path,
            opset_version=11,
            do_constant_folding=True,
            input_names=["pixel_values", "image_text_info"],
            output_names=["output"],
            dynamic_axes={
                "pixel_values": symbolic_names,
                "output": symbolic_names,
            },
        )
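# Note: at inference time, pixel_values must already be preprocessed the way
# CLIP expects (224x224 input, CLIP mean/std normalization). image_text_info
# is a plain Python list, so tracing most likely folds it into the graph as a
# constant; inspect the exported model's inputs before relying on it.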
def quantization(input_file: Path, output_file: Path):
    # Dynamic quantization: weights are stored as int8, activations stay float
    quantize_dynamic(input_file, output_file, weight_type=QuantType.QInt8)
# model
# Wrappers that unwrap SentenceTransformer's dict output so the ONNX graph
# returns a plain tensor instead of a dict.
output_value_name = 'sentence_embedding'

class TextualModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        return self.model.forward(input)[output_value_name]

class VisualModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        return self.model.forward(input)[output_value_name]
# export
if __name__ == '__main__':
    textual_model = TextualModel(text_model)
    export_textual_model_to_onnx(textual_model, text_model.tokenizer, output_dir / 'textual.onnx')

    visual_model = VisualModel(img_model)
    export_visual_model_to_onnx(visual_model, output_dir / 'visual.onnx')

    # quantization(output_dir / 'textual.onnx', output_dir / 'textual-q.onnx')
    # quantization(output_dir / 'visual.onnx', output_dir / 'visual-q.onnx')
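
Usage sketch (not part of the original gist): once textual.onnx exists, it can be run with onnxruntime. The tokenizer name, the file path, and the int32 casts below are assumptions mirroring the export code above; padding to model_max_length is required because only the batch axis was exported as dynamic.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Assumed names/paths mirroring the export script above
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/clip-ViT-B-32-multilingual-v1')
session = ort.InferenceSession('onnx/textual.onnx', providers=['CPUExecutionProvider'])

# Pad to model_max_length: the exported graph has a fixed sequence length
encoded = tokenizer(
    ['a photo of a cat', 'une photo de chat'],
    padding='max_length',
    truncation=True,
    return_tensors='np',
)
embeddings = session.run(
    ['output'],
    {
        'input_ids': encoded['input_ids'].astype(np.int32),
        'attention_mask': encoded['attention_mask'].astype(np.int32),
    },
)[0]
print(embeddings.shape)  # (batch_size, embedding_dim)

# The visual model can be driven the same way with CLIP-preprocessed
# pixel_values; check session.get_inputs() first, since image_text_info was
# probably folded into the graph as a constant during tracing.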