Last active
February 21, 2021 19:46
-
-
Save adriangb/996999e708b0ec38343268e5c4fef908 to your computer and use it in GitHub Desktop.
SaveModel vs. Pickle (via SaveModel)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tempfile | |
from timeit import default_timer, timeit | |
from typing import List, Tuple | |
from matplotlib import pyplot as plt | |
import numpy as np | |
from tensorflow import keras | |
from tensorflow.keras.models import load_model | |
from scikeras._saving_utils import pack_keras_model | |
def get_model(n_hidden: int, width: int = 64) -> keras.Model:
    """Build and compile a stack of ``n_hidden`` Dense layers.

    Args:
        n_hidden: Number of Dense layers stacked after the input.
        width: Input feature dimension and unit count of every Dense layer.
            Defaults to 64, the original arbitrary choice.

    Returns:
        A compiled functional ``keras.Model``. It is built with
        ``keras.Model(inputs, outputs)``, so it is not a ``keras.Sequential``.
    """
    inp = x = keras.Input((width,))
    for _ in range(n_hidden):
        x = keras.layers.Dense(width)(x)
    model = keras.Model(inp, x)
    # Arbitrary loss choice; it only serves to compile the model.
    model.compile(loss="mse")
    return model
def roundtrip_pickle(model: keras.Model) -> keras.Model:
    """Pack ``model`` with scikeras' ``pack_keras_model`` and rebuild it.

    ``pack_keras_model`` returns a pair that is used here as
    ``packed[0](*packed[1])``. Presumably this is the ``(callable, args)``
    reduction of the pickle ``__reduce__`` protocol, so calling it mirrors
    what unpickling would do. This function does not call ``pickle.dumps`` or
    ``pickle.loads`` itself.

    Args:
        model: The Keras model to roundtrip.

    Returns:
        The reconstructed model. Benchmark callers may ignore it.
    """
    packed = pack_keras_model(model)
    # Invoke the reconstruction callable with its packed arguments.
    return packed[0](*packed[1])
def roundtrip_savemodel(model: keras.Model) -> keras.Model:
    """Save ``model`` with ``model.save`` and load it back with ``load_model``.

    The model is written under TensorFlow's ``ram://`` in-memory filesystem,
    so the path only needs to be a unique name.

    Args:
        model: The Keras model to roundtrip.

    Returns:
        The model returned by ``load_model``.
    """
    import uuid  # local import keeps this block self-contained

    # A uuid gives a unique ram:// name without touching the real disk.
    # tempfile.mkdtemp() creates an on-disk directory on every call, which
    # was never removed; it also put a filesystem syscall in the timed region.
    # NOTE(review): files written to ram:// are not deleted here; presumably
    # they remain in process memory until exit — confirm if memory matters.
    path = f"ram://{uuid.uuid4().hex}"
    model.save(path)
    return load_model(path)
def bench_layers(
    n_hidden: int,
    repeats: int = 10,
) -> Tuple[float, float, float, float]:
    """Time pickle-style vs. SaveModel roundtrips of a dense model.

    Builds ``get_model(n_hidden)`` and runs each roundtrip once untimed as a
    warm-up. It then times ``repeats`` roundtrips with each approach.

    Args:
        n_hidden: Number of hidden Dense layers passed to ``get_model``.
        repeats: Number of timed roundtrips per approach. The default of 10
            matches the original hard-coded value.

    Returns:
        ``(pickle_mean, pickle_std, savemodel_mean, savemodel_std)`` of the
        wall-clock durations, in seconds.

    Raises:
        ValueError: If ``repeats`` is less than 1. Mean and std of an empty
            sample are undefined.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    model = get_model(n_hidden)
    # Warm up BOTH paths. Otherwise any one-time cost is charged only to
    # whichever approach happens to be timed first (here: pickle).
    roundtrip_pickle(model)
    roundtrip_savemodel(model)

    def _time(roundtrip) -> List[float]:
        # Wall-clock duration of each of `repeats` calls to `roundtrip(model)`.
        times = []
        for _ in range(repeats):
            start = default_timer()
            roundtrip(model)
            times.append(default_timer() - start)
        return times

    pickle_times = _time(roundtrip_pickle)
    savemodel_times = _time(roundtrip_savemodel)
    return (
        np.mean(pickle_times),
        np.std(pickle_times),
        np.mean(savemodel_times),
        np.std(savemodel_times),
    )
def get_mobilenet() -> keras.Model:
    """Return a compiled, minimalistic ``MobileNetV3Small`` for benchmarking."""
    mobilenet = keras.applications.MobileNetV3Small(minimalistic=True)
    # The loss choice is arbitrary; the model just needs to be compiled.
    mobilenet.compile(loss="sparse_categorical_crossentropy")
    return mobilenet
def bench_mobilenet(repeats: int) -> Tuple[List[float], List[float]]:
    """Time pickle-style vs. SaveModel roundtrips of ``MobileNetV3Small``.

    Runs each roundtrip once untimed as a warm-up, then times ``repeats``
    roundtrips per approach.

    Args:
        repeats: Number of timed roundtrips per approach.

    Returns:
        ``(pickle_times, savemodel_times)``: the raw per-roundtrip wall-clock
        durations in seconds, each of length ``repeats``.
    """
    model = get_mobilenet()
    # Warm up BOTH paths so one-time costs are not charged only to the
    # approach that happens to be timed first (here: pickle).
    roundtrip_pickle(model)
    roundtrip_savemodel(model)

    pickle_times = []
    for _ in range(repeats):
        start = default_timer()
        roundtrip_pickle(model)
        pickle_times.append(default_timer() - start)

    savemodel_times = []
    for _ in range(repeats):
        start = default_timer()
        roundtrip_savemodel(model)
        savemodel_times.append(default_timer() - start)

    return (pickle_times, savemodel_times)
# Benchmark 1: dense models of increasing depth. For each depth, record the
# mean and std of roundtrip times returned by bench_layers.
n_hiddens = [1, 5, 10, 25, 50, 100]
pickle_times = []
pickle_std = []
savemodel_times = []
savemodel_std = []
for n_hidden in n_hiddens:
    pt, ps, st, ss = bench_layers(n_hidden)
    pickle_times.append(pt)
    pickle_std.append(ps)
    savemodel_times.append(st)
    savemodel_std.append(ss)

# Benchmark 2: MobileNetV3Small, keeping the raw per-roundtrip times for a
# boxplot.
time_pickle, time_savemodel = bench_mobilenet(30)

fig, (ax1, ax2) = plt.subplots(ncols=2)

# Nudge the two series apart horizontally (by 2% of the largest depth) so
# their markers and error bars do not sit on top of each other.
x_offset = max(n_hiddens) / 50
ax1.errorbar(
    np.array(n_hiddens) - x_offset,
    pickle_times,
    yerr=pickle_std,
    label="Pickle",
    fmt="o",
    color="blue",
    ecolor="cornflowerblue",
    elinewidth=3,
    capsize=0,
)
ax1.errorbar(
    np.array(n_hiddens) + x_offset,
    savemodel_times,
    yerr=savemodel_std,
    label="SaveModel",
    fmt="o",
    color="red",
    ecolor="lightcoral",
    elinewidth=3,
    capsize=0,
)
ax1.legend()
ax1.set_title("Dense models")
ax1.set_xlabel("Number of hidden layers")
ax1.set_ylabel("Roundtrip time (s)")
ax2.boxplot([time_pickle, time_savemodel], labels=["Pickle", "SaveModel"])
ax2.set_title("MobileNetV3Small")
ax2.set_ylabel("Roundtrip time (s)")
plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment