@mht-sharma
Last active November 24, 2022 11:04
Profiling Whisper Model - Hugging Face
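Compares the latency of the PyTorch Whisper pipeline with the ONNX Runtime export produced by Optimum, and profiles the individual exported ONNX components (encoder, decoder, decoder-with-past).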
import time
import numpy as np
import onnxruntime
import torch
from datasets import load_dataset
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
sess_options = onnxruntime.SessionOptions()
# Set graph optimization level
# sess_options.log_severity_level = 2
# sess_options.intra_op_num_threads = 1
# sess_options.graph_optimization_level = (
# onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# )
# device = torch.device("cuda:0")
device = torch.device("cpu")
model_name = "openai/whisper-tiny.en"
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
input = ds[0]["audio"]["array"]
###########################################################################################
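# Baseline: run the PyTorch Whisper model through the transformers ASR pipeline.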
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
speech_recognition_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    device=device,
)
result = speech_recognition_pipeline(input)
print(result)
###########################################################################################
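# Optimum / ONNX Runtime: export the same checkpoint to ONNX and run it through an identical pipeline.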
from datasets import load_dataset
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
processor = AutoProcessor.from_pretrained(model_name)
model_optimum = ORTModelForSpeechSeq2Seq.from_pretrained(
    model_name, from_transformers=True, use_cache=True, session_options=sess_options
)
speech_recognition_pipeline_optimum = pipeline(
    "automatic-speech-recognition",
    model=model_optimum,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    device=device,
)
result = speech_recognition_pipeline_optimum(input)
print(result)
###########################################################################################
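# Latency comparison: warm up each pipeline, then average over repeated timed runs.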
iterations = 20
def measure_latency(pipe):
    latencies = []
    # warm up
    for _ in range(10):
        _ = pipe(input)
    # Timed run
    for _ in range(iterations):
        start_time = time.perf_counter()
        _ = pipe(input)
        latency = time.perf_counter() - start_time
        latencies.append(latency)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    return f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}"
print(f"Torch model {measure_latency(speech_recognition_pipeline)}")
print(f"Optimum model {measure_latency(speech_recognition_pipeline_optimum)}")
import torch
import json
import pandas as pd
import matplotlib.pyplot as plt
import os
from pathlib import Path
from datasets import load_dataset
from transformers import AutoProcessor
import onnxruntime
from tqdm import tqdm
import time
model_name = "openai/whisper-tiny.en"
def get_options(model_path=None, profile=False):
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 0
    if profile:
        options.optimized_model_filepath = os.path.join(
            os.path.dirname(model_path), Path(model_path).stem + "_optimized.onnx"
        )
        options.enable_profiling = True
    return options
def get_provider(device="cpu"):
providers = ["CPUExecutionProvider"]
if device == "gpu":
providers.append(
(
"CUDAExecutionProvider",
{
"device_id": 0,
# 'arena_extend_strategy': 'kNextPowerOfTwo',
# 'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
# 'cudnn_conv_algo_search': 'EXHAUSTIVE',
# 'do_copy_in_default_stream': True,
},
)
)
return providers
def get_session(model_path, profile=False, device="cpu"):
    ort_session = onnxruntime.InferenceSession(
        model_path,
        sess_options=get_options(model_path, profile=profile),
        providers=get_provider(device=device),
    )
    return ort_session
def generate_encoder_input(path=None):
    processor = AutoProcessor.from_pretrained(model_name)
    ds = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    onnx_inputs = {
        "input_features": inputs["input_features"].cpu().detach().numpy(),
    }
    return onnx_inputs
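# Build inputs for the first decoder pass: a single start token plus the encoder
# hidden states. 50257 is <|startoftranscript|> for the English-only Whisper
# checkpoints (the multilingual checkpoints use a different id).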
def generate_decoder_input(path):
    encoder_hidden_states = run_encoder(path)[0]
    input_ids = torch.ones((1, 1), dtype=torch.int64)
    input_ids[0][0] = 50257
    onnx_inputs = {
        "input_ids": input_ids.cpu().detach().numpy(),
        "encoder_hidden_states": encoder_hidden_states,
    }
    return onnx_inputs
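# Build inputs for the decoder-with-past model: run the encoder and one plain
# decoder pass first, then feed the returned present key/values back in as the
# past key/value inputs.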
def generate_decoder_with_past_input(
    path, session_outputs, session_output_names, key_value_input_names
):
    encoder_hidden_states = run_encoder(path)[0]
    decoder_outputs = run_decoder(path)
    input_ids = torch.ones((1, 1), dtype=torch.int64)
    input_ids[0][0] = 50257
    onnx_inputs = {
        "input_ids": input_ids.cpu().detach().numpy(),
        "encoder_hidden_states": encoder_hidden_states,
    }
    past_key_values = tuple(
        decoder_outputs[session_outputs[key]]
        for key in session_output_names
        if "key_values" in key or ".key" in key or ".value" in key
    )
    for input_name, past_key_value in zip(key_value_input_names, past_key_values):
        onnx_inputs[input_name] = past_key_value
    return onnx_inputs
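# Run a session once, or 500 times when profiling; end_profiling() returns the
# path of the JSON trace written by ONNX Runtime.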
def run_model(ort_session, onnx_inputs, profile=False, device="cpu"):
    if profile:
        for _ in tqdm(range(500)):
            ort_session.run(None, onnx_inputs)
        return ort_session.end_profiling()
    else:
        return ort_session.run(None, onnx_inputs)
def run_encoder(path, profile=False, device="cpu", quantize=False):
    model_path = f"{path}/encoder_model.onnx"
    if quantize:
        model_path = f"{path}/encoder_model_quantized.onnx"
    ort_session = get_session(model_path, profile=profile, device=device)
    return run_model(
        ort_session, generate_encoder_input(path), profile=profile, device=device
    )
def run_decoder(path, profile=False, device="cpu", quantize=False):
    model_path = f"{path}/decoder_model.onnx"
    if quantize:
        model_path = f"{path}/decoder_model_quantized.onnx"
    ort_session = get_session(model_path, profile=profile, device=device)
    return run_model(
        ort_session, generate_decoder_input(path), profile=profile, device=device
    )
def run_decoder_with_past(path, profile=False, device="cpu", quantize=False):
    model_path = f"{path}/decoder_with_past_model.onnx"
    if quantize:
        model_path = f"{path}/decoder_with_past_model_quantized.onnx"
    ort_session = get_session(model_path, profile=profile, device=device)
    session_inputs = {
        input_key.name: idx for idx, input_key in enumerate(ort_session.get_inputs())
    }
    session_outputs = {
        output_key.name: idx for idx, output_key in enumerate(ort_session.get_outputs())
    }
    session_input_names = list(session_inputs.keys())
    session_output_names = list(session_outputs.keys())
    key_value_input_names = [
        key
        for key in session_input_names
        if ("key_values" in key or ".key" in key or ".value" in key)
    ]
    onnx_inputs = generate_decoder_with_past_input(
        path, session_outputs, session_output_names, key_value_input_names
    )
    return run_model(ort_session, onnx_inputs, profile=profile, device=device)
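# Flatten the ONNX Runtime profiling JSON into a DataFrame, aggregate duration and
# occurrence counts per op type, and save the result as a bar-chart PNG.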
def save_profile(json_path, model_name, device="cpu"):
    with open(json_path, "r") as f:
        js = json.load(f)

    def process_profiling(js):
        """
        Flattens json returned by onnxruntime profiling.

        :param js: json
        :return: list of dictionaries
        """
        rows = []
        for row in js:
            if "args" in row and isinstance(row["args"], dict):
                for k, v in row["args"].items():
                    row[f"args_{k}"] = v
                del row["args"]
            rows.append(row)
        return rows

    df = pd.DataFrame(process_profiling(js))
    gr_dur = (
        df[["dur", "args_op_name"]].groupby("args_op_name").sum().sort_values("dur")
    )
    gr_n = (
        df[["dur", "args_op_name"]].groupby("args_op_name").count().sort_values("dur")
    )
    gr_n = gr_n.loc[gr_dur.index, :]
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    gr_dur.plot.barh(ax=ax[0])
    gr_n.plot.barh(ax=ax[1])
    ax[0].set_title("duration")
    ax[1].set_title("n occurrences")
    save_path = f"{time.strftime('%Y%m%d-h%Hm%Ms%S')}_profile_{model_name}_{device}.png"
    plt.savefig(save_path)  # saved in the current working directory
    return save_path
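# Convenience wrappers: profile each exported component and save the per-op breakdown plot.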
def profile_encoder(path, device="cpu", quantize=False):
    model_name = "encoder_model"
    json_path = run_encoder(path, profile=True, device=device, quantize=quantize)
    print(save_profile(json_path, model_name, device))


def profile_decoder(path, device="cpu", quantize=False):
    model_name = "decoder_model"
    json_path = run_decoder(path, profile=True, device=device, quantize=quantize)
    print(save_profile(json_path, model_name, device))


def profile_decoder_with_past_model(path, device="cpu", quantize=False):
    model_name = "decoder_with_past_model"
    json_path = run_decoder_with_past(
        path, profile=True, device=device, quantize=quantize
    )
    print(save_profile(json_path, model_name, device))
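# Export the Whisper checkpoint to ONNX with Optimum and, optionally, apply dynamic
# quantization (restricted here to Add and Conv ops) to the encoder, decoder and
# decoder-with-past models.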
def save_model(quantize=False):
    from pathlib import Path

    from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

    model = ORTModelForSpeechSeq2Seq.from_pretrained(
        model_name, from_transformers=True
    )
    onnx_path = Path("results_speech_seq2seq")
    model.save_pretrained(onnx_path)
    if quantize:
        # ORT_FULLY_CONNECTED_OPERATORS = ["Add", "Div", "Gather", "MatMul", "Mul", "Reshape", "Transpose", "Conv"]
        ORT_FULLY_CONNECTED_OPERATORS = ["Add", "Conv"]
        from optimum.onnxruntime.configuration import QuantizationConfig
        from optimum.onnxruntime.quantization import (
            QuantFormat,
            QuantizationMode,
            QuantType,
        )
        from optimum.onnxruntime import ORTQuantizer

        # Load the quantization configuration detailing the quantization we wish to apply
        # dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False, operators_to_quantize=ORT_FULLY_CONNECTED_OPERATORS)
        dqconfig = QuantizationConfig(
            is_static=False,
            format=QuantFormat.QOperator,
            mode=QuantizationMode.IntegerOps,
            activations_dtype=QuantType.QUInt8,
            activations_symmetric=False,
            weights_dtype=QuantType.QUInt8,
            weights_symmetric=True,
            per_channel=False,
            reduce_range=False,
            nodes_to_quantize=[],
            nodes_to_exclude=[],
            operators_to_quantize=ORT_FULLY_CONNECTED_OPERATORS,
        )
        encoder_quantizer = ORTQuantizer.from_pretrained(
            onnx_path, file_name="encoder_model.onnx"
        )
        decoder_quantizer = ORTQuantizer.from_pretrained(
            onnx_path, file_name="decoder_model.onnx"
        )
        decoder_wp_quantizer = ORTQuantizer.from_pretrained(
            onnx_path, file_name="decoder_with_past_model.onnx"
        )
        quantizers = [encoder_quantizer, decoder_quantizer, decoder_wp_quantizer]
        # Apply dynamic quantization and save the resulting models
        for q in quantizers:
            q.quantize(save_dir=onnx_path, quantization_config=dqconfig)
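# Example usage (uncomment as needed):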
# save_model(quantize=True)
# print(run_encoder("results_speech_seq2seq/", quantize=True))
# profile_encoder("results_speech_seq2seq/", quantize=False)
# profile_decoder("results_speech_seq2seq/", quantize=False)
# profile_decoder_with_past_model("results_speech_seq2seq/", quantize=False)