Skip to content

Instantly share code, notes, and snippets.

@afrendeiro
Last active March 9, 2023 15:36
Show Gist options
  • Save afrendeiro/23c591606437dfd668673928fe4cc654 to your computer and use it in GitHub Desktop.
Save afrendeiro/23c591606437dfd668673928fe4cc654 to your computer and use it in GitHub Desktop.
Run experiments on different quantization settings and report speedup and error.
"""
Run the experiments for the quantization speedup and plot results.
"""
import ast
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Names of the models to benchmark. Each name is passed to the runner script
# via --model and also serves as the key into the per-model plot colors.
# Presumably these are torchvision.models constructor names -- verify against
# the runner script.
models = [
"alexnet",
"vgg11",
"vgg16",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"densenet121",
"densenet161",
"densenet201",
"inception_v3",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_5",
"convnext_tiny",
"convnext_small",
"convnext_base",
"convnext_large",
"vit_b_32",
"vit_l_32",
"maxvit_t",
]
# Sweep every (quantized dtype, batch size, model) combination. Each
# combination is profiled by the runner script in its own subprocess, and the
# two lines it prints (column labels, then values) become one result row.
_res = []
for qdtype in ["qint8"]:
    for n_batch in 2 ** np.arange(10):  # batch sizes 1, 2, 4, ..., 512
        for model in models:
            print(f"Running {model} with batch size {n_batch}.")
            resp = subprocess.check_output(
                [
                    # Use the interpreter running this script so the child
                    # sees the same environment; a bare "python" may be
                    # missing from PATH or point at another installation.
                    sys.executable,
                    "quantization_experiments.run.py",
                    "--model",
                    model,
                    "--n-batch",
                    str(n_batch),
                    "--q-dtype",
                    qdtype,
                ]
            )
            lines = resp.decode("utf-8").strip().splitlines()
            # The labels and the values are the last two lines printed; taking
            # them from the end tolerates any earlier output on stdout.
            # literal_eval only accepts Python literals, so unexpected output
            # raises instead of being executed as code.
            labels, values = (ast.literal_eval(line) for line in lines[-2:])
            _res.append(pd.Series(values, index=labels))

res = pd.DataFrame(_res).convert_dtypes()
# Speedup > 1 means the quantized model ran faster than the original.
res["speedup_dynamic"] = res["time_original"] / res["time_quantized_dynamic"]
res["speedup_static"] = res["time_original"] / res["time_quantized_static"]
res.to_csv("quantization_speedup.csv", index=False)
# Figure 1: speedup as a function of batch size, one line per model.
# Results are reloaded from the CSV written above; q_dtype is dropped.
# NOTE(review): indentation was lost in this copy of the script. The nesting
# described in the comments below is the apparent structure -- in particular,
# whether the log2 x-scale and the legend apply to both rows or only to row 0
# (inside `if i == 0`) cannot be determined here; confirm against the original.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)
# One color per model from the "tab20" palette; maxvit_t is overridden with
# near-black (presumably because tab20 has only 20 distinct colors for the 21
# models -- confirm).
colors = dict(zip(models, sns.color_palette("tab20", len(models))))
colors["maxvit_t"] = (0.1, 0.1, 0.1)
# 2x2 grid of axes; iterating over axes.T pairs each column of axes with one
# speedup variable (static, then dynamic).
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
for axs, var in zip(axes.T, ["speedup_static", "speedup_dynamic"]):
for i, ax in enumerate(axs):
# Reference line at speedup == 1 (quantized as fast as the original).
ax.axhline(1, color="black", linestyle="--")
for model in models:
# Rows for this model ordered by batch size so the line is drawn left to right.
p = res.query(f"model == '{model}'").sort_values("n_batch")
ax.plot(p["n_batch"], p[var], "-o", label=model, color=colors[model])
ax.set(
xlabel="Batch size",
ylabel=var,
# xscale="log",
# yscale="symlog",
)
# Column title ("static quantization" / "dynamic quantization") on the first row only.
if i == 0:
ax.set(title=var.replace("speedup_", "") + " quantization")
ax.set_xscale("log", base=2)
# place legend on the right side of plot
if var == "speedup_dynamic":
ax.legend(loc="upper left", bbox_to_anchor=(1.1, 1.1))
fig.suptitle("Quantization speedup")
fig.savefig("quantization_speedup.svg", dpi=300, bbox_inches="tight")
# Figure 2: mean speedup of each model against its parameter count.
# Only runs with a batch size above 4 contribute to the per-model means.
res = pd.read_csv("quantization_speedup.csv")
res = res.drop("q_dtype", axis=1)
res = res.query("n_batch > 4")
g = res.groupby("model").mean()

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for var, ax in zip(["speedup_static", "speedup_dynamic"], axes.T):
    ax.scatter(g["num_params"], g[var])
    # Label every point with the name of the model it represents.
    for name, x_pos, y_pos in zip(g.index, g["num_params"], g[var]):
        ax.text(x_pos, y_pos, name, fontsize=8)
    ax.set_xlabel("Model size")
    ax.set_ylabel("Mean speedup")
    ax.set_title(var.replace("speedup_", "") + " quantization")
fig.savefig("quantization_speedup_vs_complexity.svg", dpi=300, bbox_inches="tight")
# Figure 3: per-model comparison of dynamic against static speedup.
# This section plots from `g`, the per-model means computed for the previous
# figure (batch sizes above 4 only). The CSV used to be re-read into `res`
# here, but that frame was never used before `res` is reloaded for the next
# figure, so the redundant load was removed.
fig, ax = plt.subplots(figsize=(7, 4))
# Log ratio of the two speedups: positive values mean dynamic quantization
# gave the larger speedup, negative values favor static quantization.
g["fold_diff"] = np.log(g["speedup_dynamic"] / g["speedup_static"])
g = g.sort_values("fold_diff", ascending=False)
ax.scatter(g["fold_diff"], g.index)
# Zero marks equal dynamic and static speedups.
ax.axvline(0, linestyle="--", color="black")
ax.set(xlabel="log(dynamic speedup / static speedup)", ylabel="Model")
fig.savefig("quantization_dynamic_vs_static.svg", dpi=300, bbox_inches="tight")
# Figure 4: mean absolute error against speedup, per model.
res = pd.read_csv("quantization_speedup.csv").drop("q_dtype", axis=1)
# Scale the output size and both error columns by the batch size.
for column in ["output_size", "mae_quantized_static", "mae_quantized_dynamic"]:
    res[column] = res[column] / res["n_batch"]
g = res.groupby("model").mean()

# NOTE(review): with sharey=True the log scale set on the first row may also
# apply to the second row -- confirm this is the intended appearance.
fig, axes = plt.subplots(2, 2, figsize=(14, 8), sharex="col", sharey=True)
panels = [
    ("speedup_static", "mae_quantized_static"),
    ("speedup_dynamic", "mae_quantized_dynamic"),
]
for i, axs in enumerate(axes):
    if i != 1:
        # First row: all models, on a logarithmic error axis.
        for ax in axs:
            ax.set_yscale("log")
    else:
        # Second row: only models whose static-quantization error is below 3.
        g = g.loc[g["mae_quantized_static"] < 3]
    for ax, (x, y) in zip(axs.T, panels):
        # Replace exact zeros with a tiny positive value before plotting.
        g[y] = g[y].replace(0, 1e-12)
        ax.scatter(g[x], g[y])
        # Label every point with the name of the model it represents.
        for name in g.index:
            ax.text(g.loc[name, x], g.loc[name, y], name, fontsize=8)
        ax.set_xlabel("Speedup")
        ax.set_ylabel("Mean absolute error")
        ax.set_title(x.replace("speedup_", "") + " quantization")
fig.savefig("quantization_speedup_vs_error.svg", dpi=300, bbox_inches="tight")
"""
Run a model with different quantization settings and report timing and accuracy.
"""
from timeit import default_timer as timer
import click
import torch
import torch.nn as nn
import torchvision
from torch.ao.quantization import QConfigMapping
import torch.quantization.quantize_fx as quantize_fx
import copy
# Module types that dynamic quantization replaces with quantized versions.
qlayers = {nn.Linear}


@click.command()
@click.option("--model", type=str)
@click.option("--n-batch", type=int)
@click.option("--q-dtype", type=str)
def profile(model, n_batch, q_dtype):
    """
    Time one model before and after quantization and report the output error.

    A single float model is built and run once on a random batch. Dynamically
    and statically quantized versions are then derived from that same model,
    so the reported errors compare networks with identical weights.

    Parameters
    ----------
    model
        Name of a constructor in ``torchvision.models`` (called without
        arguments).
    n_batch
        Number of random 3x224x224 images in the input batch.
    q_dtype
        Name of a ``torch`` dtype attribute (e.g. ``"qint8"``) used for
        dynamic quantization.

    Returns
    -------
    list
        The result values, in the same order as the printed labels.

    Two lines are printed to stdout: the list of column labels, then the list
    of values. Each timing covers a single forward pass with no warm-up.
    """
    img = torch.rand(n_batch, 3, 224, 224)

    # Load the float model once; both quantized variants derive from it.
    m = getattr(torchvision.models, model)().eval()

    # Number of trainable parameters.
    num_params = sum(p.numel() for p in m.parameters() if p.requires_grad)

    # Vanilla inference
    start = timer()
    with torch.no_grad():
        pred_original = m(img)
    time_original = timer() - start

    # Dynamic quantization of the SAME float model. Building a second model
    # would give different random weights and make the error meaningless.
    # quantize_dynamic is not in-place by default, so `m` is left untouched.
    dtype = getattr(torch, q_dtype)
    m_dynamic = torch.quantization.quantize_dynamic(m, qlayers, dtype=dtype).eval()

    # Quantized inference (dynamic)
    start = timer()
    with torch.no_grad():
        pred_quantized = m_dynamic(img)
    time_quantized_dynamic = timer() - start

    # Static quantization, starting from a copy of the float model rather
    # than from the dynamically quantized one.
    # NOTE(review): the "qnnpack" qconfig is used without selecting a
    # quantized engine -- confirm this matches the target backend.
    model_to_quantize = copy.deepcopy(m).eval()
    qconfig_mapping = QConfigMapping().set_global(
        torch.quantization.get_default_qconfig("qnnpack")
    )
    # example_inputs is documented as a tuple of positional inputs.
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (img,))
    # Calibration: run the observers on data before converting, so that the
    # activation quantization parameters are derived from observed values.
    with torch.no_grad():
        model_prepared(img)
    m_static = quantize_fx.convert_fx(model_prepared)

    # Quantized inference (static)
    start = timer()
    with torch.no_grad():
        pred_quantized_graph = m_static(img)
    time_quantized_static = timer() - start

    # Mean absolute difference between float and quantized outputs.
    mae_quantized_dynamic = torch.abs(pred_original - pred_quantized).mean().item()
    mae_quantized_static = torch.abs(pred_original - pred_quantized_graph).mean().item()

    labels = [
        "model",
        "num_params",
        "n_batch",
        "q_dtype",
        "output_size",
        "time_original",
        "time_quantized_dynamic",
        "time_quantized_static",
        "mae_quantized_dynamic",
        "mae_quantized_static",
    ]
    res = [
        model,
        num_params,
        n_batch,
        q_dtype,
        pred_original.shape.numel(),
        time_original,
        time_quantized_dynamic,
        time_quantized_static,
        mae_quantized_dynamic,
        mae_quantized_static,
    ]
    # Labels first, values second: consumers parse these two lines.
    print(labels)
    print(res)
    return res


if __name__ == "__main__":
    profile()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment