Last active
March 9, 2023 15:36
-
-
Save afrendeiro/23c591606437dfd668673928fe4cc654 to your computer and use it in GitHub Desktop.
Run experiments on different quantization settings and report speedup and error.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Run the experiments for the quantization speedup and plot results. | |
""" | |
import ast
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Architecture names passed to the profiling script via ``--model``,
# laid out by model family. Order matters: it fixes the colour assignment
# and legend order in the plots below.
models = [
    # AlexNet / VGG
    "alexnet", "vgg11", "vgg16",
    # ResNet
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    # DenseNet
    "densenet121", "densenet161", "densenet201",
    # Inception / ShuffleNet
    "inception_v3", "shufflenet_v2_x0_5", "shufflenet_v2_x1_5",
    # ConvNeXt
    "convnext_tiny", "convnext_small", "convnext_base", "convnext_large",
    # ViT / MaxViT
    "vit_b_32", "vit_l_32", "maxvit_t",
]
# Run the profiling script once per (dtype, batch size, model) combination in
# a fresh subprocess and collect one result row per run.
_res = []
for qdtype in ["qint8"]:
    for n_batch in 2 ** np.arange(10):
        for model in models:
            print(f"Running {model} with batch size {n_batch}.")
            resp = subprocess.check_output(
                [
                    # Use the interpreter running this script so the child
                    # sees the same environment and installed packages.
                    sys.executable,
                    "quantization_experiments.run.py",
                    "--model",
                    model,
                    "--n-batch",
                    str(n_batch),
                    "--q-dtype",
                    qdtype,
                ]
            )
            # The profiling script prints the column labels and then the
            # values as Python list literals. Take the last two lines so that
            # anything printed earlier on stdout does not break parsing.
            lines = resp.decode("utf-8").strip().splitlines()
            if len(lines) < 2:
                raise RuntimeError(
                    f"Unexpected output for {model} (batch size {n_batch}): {lines!r}"
                )
            # literal_eval only accepts Python literals, unlike eval(), which
            # would execute arbitrary expressions found in the output.
            labels, values = (ast.literal_eval(line) for line in lines[-2:])
            _res.append(pd.Series(values, index=labels))

res = pd.DataFrame(_res).convert_dtypes()
# Speedup > 1 means the quantized model ran faster than the original.
res["speedup_dynamic"] = res["time_original"] / res["time_quantized_dynamic"]
res["speedup_static"] = res["time_original"] / res["time_quantized_static"]
res.to_csv("quantization_speedup.csv", index=False)
# Figure: speedup as a function of batch size, one line per model.
# Columns are the quantization modes; the top row uses a log2 x-axis and the
# bottom row a linear one.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)

palette = sns.color_palette("tab20", len(models))
colors = dict(zip(models, palette))
# Give the last model an explicit dark grey (presumably because the "tab20"
# palette would otherwise repeat a colour for it -- confirm).
colors["maxvit_t"] = (0.1, 0.1, 0.1)

fig, axes = plt.subplots(2, 2, figsize=(16, 10))
for axs, var in zip(axes.T, ["speedup_static", "speedup_dynamic"]):
    mode_title = var.replace("speedup_", "") + " quantization"
    for i, ax in enumerate(axs):
        # Reference line: speedup of 1 means no change.
        ax.axhline(1, color="black", linestyle="--")
        for model in models:
            p = res.query(f"model == '{model}'").sort_values("n_batch")
            ax.plot(p["n_batch"], p[var], "-o", label=model, color=colors[model])
        ax.set(xlabel="Batch size", ylabel=var)
        if i != 0:
            continue
        # Top row only: title, log2 x-axis and (once) the legend.
        ax.set(title=mode_title)
        ax.set_xscale("log", base=2)
        if var == "speedup_dynamic":
            # Place the legend outside, to the right of the plot.
            ax.legend(loc="upper left", bbox_to_anchor=(1.1, 1.1))
fig.suptitle("Quantization speedup")
fig.savefig("quantization_speedup.svg", dpi=300, bbox_inches="tight")
# Figure: mean speedup per model against its parameter count, using only runs
# with a batch size above 4.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)
res = res.query("n_batch > 4")
g = res.groupby("model").mean()

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for ax, var in zip(axes.T, ["speedup_static", "speedup_dynamic"]):
    ax.scatter(g["num_params"], g[var])
    # Label every point with its model name.
    for name in g.index:
        x_pos = g.loc[name, "num_params"]
        y_pos = g.loc[name, var]
        ax.text(x_pos, y_pos, name, fontsize=8)
    mode_title = var.replace("speedup_", "") + " quantization"
    ax.set(xlabel="Model size", ylabel="Mean speedup", title=mode_title)
fig.savefig("quantization_speedup_vs_complexity.svg", dpi=300, bbox_inches="tight")
# Figure: per-model log ratio of dynamic to static speedup.
# NOTE(review): `res` is re-read here, but this figure only uses `g`, the
# per-model means computed for the previous figure.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)

# Positive values mean dynamic quantization gave the larger speedup.
log_ratio = np.log(g["speedup_dynamic"] / g["speedup_static"])
g = g.assign(fold_diff=log_ratio).sort_values("fold_diff", ascending=False)

fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(g["fold_diff"], g.index)
ax.axvline(0, linestyle="--", color="black")
ax.set(xlabel="log(dynamic speedup / static speedup)", ylabel="Model")
fig.savefig("quantization_dynamic_vs_static.svg", dpi=300, bbox_inches="tight")
# Figure: mean speedup against mean absolute error, per model.
# Top row shows all models; bottom row keeps only models whose static MAE is
# below 3.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)
# Normalise the size and error columns by the batch size of each run.
res["output_size"] /= res["n_batch"]
res["mae_quantized_static"] /= res["n_batch"]
res["mae_quantized_dynamic"] /= res["n_batch"]
g = res.groupby("model").mean()

speedup_cols = ["speedup_static", "speedup_dynamic"]
mae_cols = ["mae_quantized_static", "mae_quantized_dynamic"]
# Exact zeros cannot be drawn on a log axis, so nudge them to a tiny positive
# value. Doing this once, before any filtering, avoids assigning into a
# filtered frame (which can raise pandas' SettingWithCopyWarning).
g[mae_cols] = g[mae_cols].replace(0, 1e-12)

fig, axes = plt.subplots(2, 2, figsize=(14, 8), sharex="col", sharey=True)
for i, axs in enumerate(axes):
    if i == 1:
        # Bottom row: drop models with a large static quantization error.
        g = g.loc[g["mae_quantized_static"] < 3]
    else:
        # NOTE(review): with sharey=True the y-axes are linked, so this log
        # scale presumably carries over to the bottom row too -- confirm.
        for ax in axs:
            ax.set_yscale("log")
    for ax, x, y in zip(axs.T, speedup_cols, mae_cols):
        ax.scatter(g[x], g[y])
        # Label every point with its model name.
        for name in g.index:
            ax.text(g.loc[name, x], g.loc[name, y], name, fontsize=8)
        ax.set(
            xlabel="Speedup",
            ylabel="Mean absolute error",
            title=x.replace("speedup_", "") + " quantization",
        )
fig.savefig("quantization_speedup_vs_error.svg", dpi=300, bbox_inches="tight")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Run a model with different quantization settings and report timing and accuracy. | |
""" | |
from timeit import default_timer as timer | |
import click | |
import torch | |
import torch.nn as nn | |
import torchvision | |
from torch.ao.quantization import QConfigMapping | |
import torch.quantization.quantize_fx as quantize_fx | |
import copy | |
# Module types that dynamic quantization is asked to quantize.
qlayers = {nn.Linear}


def _timed_forward(module, batch):
    """Run one gradient-free forward pass and return ``(output, seconds)``."""
    start = timer()
    with torch.no_grad():
        output = module(batch)
    return output, timer() - start


@click.command()
@click.option("--model", type=str, required=True)
@click.option("--n-batch", type=int, required=True)
@click.option("--q-dtype", type=str, required=True)
def profile(model, n_batch, q_dtype):
    """Time one model in float, dynamically and statically quantized form.

    A single model instance is created and both quantized variants are derived
    from it, so the reported errors compare models with the same weights.
    The labels and the values are printed as two Python list literals on
    consecutive lines so a calling process can parse them.

    Args:
        model: Name of a constructor in ``torchvision.models``.
        n_batch: Number of random 3x224x224 images in the input batch.
        q_dtype: Name of a ``torch`` dtype attribute (e.g. ``qint8``) used for
            dynamic quantization. Static quantization uses the default
            "qnnpack" qconfig regardless of this value.

    Returns:
        The list of result values, in the same order as the printed labels.
    """
    img = torch.rand(n_batch, 3, 224, 224)

    # Reference float model; constructed without arguments, as before.
    float_model = getattr(torchvision.models, model)().eval()
    num_params = sum(p.numel() for p in float_model.parameters() if p.requires_grad)

    # Vanilla inference
    pred_original, time_original = _timed_forward(float_model, img)

    # Dynamic quantization of the *same* model. quantize_dynamic works on a
    # copy by default, so float_model stays untouched.
    dtype = getattr(torch, q_dtype)
    dynamic_model = torch.quantization.quantize_dynamic(
        float_model, qlayers, dtype=dtype
    ).eval()
    pred_quantized, time_quantized_dynamic = _timed_forward(dynamic_model, img)

    # Static (FX graph mode) quantization, starting from a copy of the float
    # model rather than from the dynamically quantized one.
    model_to_quantize = copy.deepcopy(float_model).eval()
    qconfig_mapping = QConfigMapping().set_global(
        torch.quantization.get_default_qconfig("qnnpack")
    )
    # example_inputs is a tuple of positional arguments.
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (img,))
    # Calibration: let the inserted observers see data before conversion.
    # This pass is deliberately outside the timed region.
    with torch.no_grad():
        model_prepared(img)
    static_model = quantize_fx.convert_fx(model_prepared)
    pred_quantized_graph, time_quantized_static = _timed_forward(static_model, img)

    # Mean absolute difference between the float and quantized predictions.
    mae_quantized_dynamic = torch.abs(pred_original - pred_quantized).mean().item()
    mae_quantized_static = torch.abs(pred_original - pred_quantized_graph).mean().item()

    labels = [
        "model",
        "num_params",
        "n_batch",
        "q_dtype",
        "output_size",
        "time_original",
        "time_quantized_dynamic",
        "time_quantized_static",
        "mae_quantized_dynamic",
        "mae_quantized_static",
    ]
    res = [
        model,
        num_params,
        n_batch,
        q_dtype,
        pred_original.shape.numel(),
        time_original,
        time_quantized_dynamic,
        time_quantized_static,
        mae_quantized_dynamic,
        mae_quantized_static,
    ]
    # Keep these two prints last and in this order: callers parse them.
    print(labels)
    print(res)
    return res


if __name__ == "__main__":
    profile()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment