Last active
January 23, 2022 03:45
-
-
Save mihow/c22adf52fcb9c07ce67dcf4d2495cedd to your computer and use it in GitHub Desktop.
PyTorch model to JIT / TorchScript
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
import pathlib | |
import tarfile | |
import tempfile | |
import torch | |
# Root of the project. Taken from the PROJECT_PATH environment variable when
# set, otherwise the current working directory ("."). Used as the base for the
# default model export directory.
PROJECT_PATH = pathlib.Path(os.getenv("PROJECT_PATH", "."))
def export_model(model, classes, export_dir=None, input_shape=(3, 256, 256)):
    """
    Export a model as a TorchScript (JIT-traced) tarball bundled with its class list.

    The model is traced with a single dummy input of ones and saved with
    ``torch.jit.save``. The result is packed into
    ``<export_dir>/model-<unix seconds>.tar.gz`` with two members:

    * ``model.pkl``: the archive written by ``torch.jit.save``. Despite the
      extension it is not a plain pickle; load it with ``torch.jit.load``.
      The member name is kept for compatibility with existing consumers.
    * ``classes.txt``: one class name per line (UTF-8), in index order.

    Note: ``model.float().eval()`` converts and switches the caller's model
    in place (both ``nn.Module`` methods modify the module and return it).

    Args:
        model: The ``torch.nn.Module`` to export.
        classes: Iterable of class names ordered by output index. Each item is
            written using its ``str()`` form.
        export_dir: Target directory (``str`` or ``pathlib.Path``). Falsy
            values fall back to ``PROJECT_PATH / "models"``. Missing
            directories are created.
        input_shape: Per-sample shape of the dummy tracing input, without the
            batch dimension. A batch dimension of 1 is prepended.

    Returns:
        ``pathlib.Path`` of the written tarball. The name has one-second
        resolution, so two exports in the same second overwrite each other.

    Example::

        tarball = export_model(model, ["cat", "dog"], export_dir="exports")

    TODO: add a metadata file to the tarball (input shape, who trained it,
    friendly name, link to the W&B / CometML run and our model zoo).
    """
    timestamp = str(int(time.time()))
    # Accept str or Path, keeping the original "falsy -> default" fallback.
    export_dir = pathlib.Path(export_dir) if export_dir else PROJECT_PATH / "models"
    export_dir.mkdir(exist_ok=True, parents=True)
    tarball_path = export_dir / f"model-{timestamp}.tar.gz"

    # eval() is needed to trace with a batch of size 1 (the dummy input below).
    # https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274
    model_raw = model.float().eval()

    print("Example item shape:", *input_shape)
    # Put the dummy input on the same device as the model's weights. Tracing
    # fails when the input and the weights live on different devices.
    # Parameter-less models keep the previous "CUDA if available" rule.
    first_param = next(model_raw.parameters(), None)
    if first_param is not None:
        device = first_param.device
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    example_input = torch.ones(1, *input_shape, device=device)

    # Create jit model
    print("Exporting model")
    model_jit = torch.jit.trace(model_raw, example_input)

    # Stage the artifacts in a temporary directory that is removed afterwards.
    # This avoids leaking open NamedTemporaryFile handles and orphaned files.
    with tempfile.TemporaryDirectory() as tmp:
        model_path = pathlib.Path(tmp) / "model.pkl"
        classes_path = pathlib.Path(tmp) / "classes.txt"
        torch.jit.save(model_jit, str(model_path))
        # List of classes, ordered by index. Explicit UTF-8 so non-ASCII names
        # do not depend on the platform's default encoding.
        with open(classes_path, "w", encoding="utf-8") as f:
            for c in classes:
                f.write(f"{c}\n")
        with tarfile.open(tarball_path, "w:gz") as tar:
            # arcname keeps the temporary directory out of the archive paths.
            tar.add(model_path, arcname="model.pkl")
            tar.add(classes_path, arcname="classes.txt")

    print("Model and classes saved to", tarball_path)
    return tarball_path
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment