Export a BERT ONNX model with dynamic_axes
import time

import numpy as np
import torch
from transformers import BertForSequenceClassification


class MyBertForSequenceClassification(BertForSequenceClassification):
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # Return only the logits so the exported graph has a single tensor output.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return outputs.logits


def get_dummy_input(seq_length=512):
    input_ids = torch.tensor([list(range(seq_length))], dtype=torch.long)
    attention_mask = torch.ones((1, seq_length), dtype=torch.long)
    # First half marked as sentence A (0), second half as sentence B (1).
    token_type_ids = torch.tensor(
        [[0] * (seq_length // 2) + [1] * (seq_length - seq_length // 2)],
        dtype=torch.long,
    )
    return input_ids, attention_mask, token_type_ids


def export_onnx(model, tokenizer, onnx_path, seq_length=512):
    dummy_inputs = get_dummy_input(seq_length)
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_inputs,
            onnx_path,
            # verbose=True,
            opset_version=16,
            input_names=['input_ids', 'attention_mask', 'token_type_ids'],
            output_names=['output'],
            # Mark batch and sequence dimensions as dynamic so the exported
            # model accepts any batch size and sequence length at inference.
            dynamic_axes={
                'input_ids': {0: 'batch', 1: 'seq_len'},
                'attention_mask': {0: 'batch', 1: 'seq_len'},
                'token_type_ids': {0: 'batch', 1: 'seq_len'},
                'output': {0: 'batch'},
            },
        )


def test_onnx(onnx_path):
    print()
    print('---------------------- test onnx: ', onnx_path)
    # `BertOnnxWrapper`, `tokenizer`, `text`, `torch_outputs` and `to_numpy`
    # are assumed to be defined elsewhere in the original script.
    onnx_model = BertOnnxWrapper(onnx_path, 1)
    inputs = tokenizer(
        text,
        return_tensors='np',
        truncation=True,
        max_length=512,
        padding='longest',
    )
    start_time = time.time()
    onnx_outputs = onnx_model(inputs)
    end_time = time.time()
    print(
        f'length text is {len(text)}, total time is {end_time - start_time}, per text is {(end_time - start_time) / len(text)}'
    )
    # print('############## onnx model output: ', onnx_outputs)
    np.testing.assert_allclose(onnx_outputs, to_numpy(torch_outputs), rtol=1e-05, atol=1e-08)
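The test above relies on a `BertOnnxWrapper` helper that the gist never defines. A minimal sketch of such a wrapper built on `onnxruntime` (the class name, constructor signature, and thread argument are assumptions inferred from the call site, not part of the original gist):

import onnxruntime as ort


class BertOnnxWrapper:
    """Sketch of an ONNX Runtime wrapper for the exported BERT classifier."""

    def __init__(self, onnx_path, num_threads=1):
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = num_threads
        self.session = ort.InferenceSession(
            onnx_path, sess_options=opts, providers=['CPUExecutionProvider']
        )

    def __call__(self, inputs):
        # `inputs` is the dict produced by tokenizer(..., return_tensors='np');
        # feed only the tensors the exported graph declares as inputs.
        feed = {
            name: inputs[name]
            for name in ('input_ids', 'attention_mask', 'token_type_ids')
            if name in inputs
        }
        # The graph has a single output named 'output' (the logits).
        return self.session.run(['output'], feed)[0]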
import numpy as np
import torch


def export_model(model, onnx_path, quantize=False):
    # Dummy waveform batch: 5 clips of 20 s at a 16 kHz sample rate (16000 * 20 samples).
    input_tensor = torch.randn(5, 320000)
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,  # model being run
            input_tensor,  # model input (or a tuple for multiple inputs)
            onnx_path,  # where to save the model (can be a file or file-like object)
            opset_version=16,
            input_names=["input"],  # the model's input names
            output_names=["output"],  # the model's output names
            dynamic_axes={
                'input': {
                    0: 'batch_size',
                    1: 'sequence_length',
                },  # variable-length axes
                'output': {0: 'batch_size'},
            },
        )
    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic

        # `quant_onnx_path` is assumed to be defined elsewhere in the script.
        quantize_dynamic(
            onnx_path,
            quant_onnx_path,
            op_types_to_quantize=['MatMul'],
            weight_type=QuantType.QUInt8,
        )


def test_export_model():
    # `xxxModel`, `model_name_or_path`, `config`, `OnnxWrapper` and
    # `onnx_fp32_path` are assumed to be defined elsewhere in the script.
    model = xxxModel.from_pretrained(
        model_name_or_path,
        config=config,
    )
    model.eval()
    onnx_fp32_model = OnnxWrapper(onnx_fp32_path)
    input_tensor = torch.randn(5, 320000)
    with torch.no_grad():
        torch_outs = model(input_tensor)
    onnx_fp32_outs = onnx_fp32_model(input_tensor.numpy())
    np.testing.assert_allclose(
        torch_outs.numpy(), onnx_fp32_outs, rtol=1e-03, atol=1e-05
    )
import numpy as np
import onnx
import torch


def convert_onnx_float32_to_float16(fp32_model_path, fp16_model_path):
    from onnxmltools.utils import save_model
    from onnxmltools.utils.float16_converter import convert_float_to_float16

    model = onnx.load(fp32_model_path)
    onnx.checker.check_model(model)
    # keep_io_types=False also casts the graph inputs/outputs to float16,
    # so callers must feed float16 tensors.
    new_onnx_model = convert_float_to_float16(model, keep_io_types=False)
    save_model(new_onnx_model, fp16_model_path)


def test_convert_onnx_float32_to_float16():
    # `OnnxWrapper`, `onnx_fp16_path`, `xxxModel`, `model_name_or_path` and
    # `config` are assumed to be defined elsewhere in the script.
    input_tensor = torch.randn(5, 320000)
    fp16_model = OnnxWrapper(onnx_fp16_path)
    fp16_outs = fp16_model(input_tensor.numpy().astype(np.float16))
    print(
        'fp16 outputs: ', fp16_outs, '\t', fp16_outs[0].shape, '\t', fp16_outs[0].dtype
    )
    model = xxxModel.from_pretrained(
        model_name_or_path,
        config=config,
    )
    model.eval()
    with torch.no_grad():
        torch_outs = model(input_tensor)
    np.testing.assert_allclose(torch_outs.numpy(), fp16_outs, rtol=1e-03, atol=1e-05)
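Putting the pieces together, a possible end-to-end driver for the BERT path (the checkpoint name and file paths here are placeholders, not taken from the gist):

from transformers import BertTokenizer

# Placeholder names for illustration only.
model_name_or_path = 'bert-base-uncased'
onnx_fp32_path = 'bert_fp32.onnx'
onnx_fp16_path = 'bert_fp16.onnx'

tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
model = MyBertForSequenceClassification.from_pretrained(model_name_or_path)

# Export with dynamic batch/sequence axes, then downcast the weights to fp16.
export_onnx(model, tokenizer, onnx_fp32_path)
convert_onnx_float32_to_float16(onnx_fp32_path, onnx_fp16_path)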