Created
May 9, 2025 00:15
-
-
Save hotchpotch/c6d83949a7587595606c1f54b0c41072 to your computer and use it in GitHub Desktop.
cross_encoder_to_onnx_pr.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sentence_transformers import CrossEncoder, export_dynamic_quantized_onnx_model, export_optimized_onnx_model | |
# Hub repository ID of the cross-encoder that gets converted.
MODEL_NAME = "hotchpotch/japanese-reranker-xsmall-v2"

# Load the base model on the CPU with the ONNX backend.
model = CrossEncoder(
    MODEL_NAME,
    backend="onnx",
    device="cpu",
)

# 1. Base model (model.onnx)
# Push it to the Hub, opening a PR if needed, by uncommenting:
# model.push_to_hub(MODEL_NAME, create_pr=True)
# 2. Export one optimized model per optimization level
#    (model_O1.onnx ... model_O4.onnx), each pushed to the Hub with
#    create_pr=True.
optimization_levels = ["O1", "O2", "O3", "O4"]
for level in optimization_levels:
    # The suffix is taken from the level name itself rather than from a
    # running index, so editing or reordering the list above can never
    # produce a file whose name disagrees with its optimization level.
    export_optimized_onnx_model(
        model,
        level,
        MODEL_NAME,
        push_to_hub=True,
        create_pr=True,
        file_suffix=level,  # model_O1.onnx, model_O2.onnx, model_O3.onnx, model_O4.onnx
    )
    if level == "O3":
        # O3 is additionally exported without an explicit suffix, i.e. under
        # the library's default file name.
        # NOTE(review): confirm the default name differs from model_O3.onnx;
        # otherwise this second export/PR is redundant.
        export_optimized_onnx_model(
            model,
            level,
            MODEL_NAME,
            push_to_hub=True,
            create_pr=True,
        )
# 3. Dynamically quantized models, one per target architecture.
#    Each export is named model_qint8_<target>.onnx and pushed to the Hub
#    with create_pr=True. The order matches the original script:
#    arm64, avx512, avx512_vnni, avx2.
for quantization_target in ("arm64", "avx512", "avx512_vnni", "avx2"):
    export_dynamic_quantized_onnx_model(
        model,
        quantization_target,
        MODEL_NAME,
        push_to_hub=True,
        create_pr=True,
        file_suffix=f"qint8_{quantization_target}",
    )

# Completion message (Japanese for "All model conversions are complete!").
print("すべてのモデル変換が完了しました!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment