Last active
June 16, 2021 11:42
-
-
Save spezold/ab49e8f396853b2ff18ef7623ee6c55b to your computer and use it in GitHub Desktop.
**Update: have a look at torch.package instead** (https://pytorch.org/docs/1.9.0/package.html) -- Original description: Save and load a PyTorch model (both code and weights): minimum working example, based on the inner workings of TorchServe.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import importlib.util | |
import inspect | |
import json | |
from pathlib import Path | |
from typing import Optional, Union | |
import zipfile | |
import torch | |
from torch import nn | |
def _name_from(path: Optional[Union[str, Path]]) -> Optional[str]: return None if path is None else Path(path).name | |
def _pymodule_from(path: Path) -> object: | |
"""Load and return the python module from the given ``*.py`` file path.""" | |
# Following https://stackoverflow.com/questions/67631/ (20201008) | |
spec = importlib.util.spec_from_file_location(name=path.stem, location=path) | |
pymodule = importlib.util.module_from_spec(spec) | |
spec.loader.exec_module(pymodule) | |
return pymodule | |
def _model_class_from(pymodule: object, name: Optional[str]) -> object: | |
""" | |
Load and return the model class (``torch.nn.Module`` subclass) with the given name from the given module (name is | |
not necessary if there is only one ``torch.nn.Module`` subclass in the given module). | |
""" | |
# Following TorchServe: ``ts.utils.util.list_classes_from_module()`` (20201008) | |
predicate = lambda m: (inspect.isclass(m) and m.__module__ == pymodule.__name__ and issubclass(m, nn.Module) and | |
(name is None or m.__name__ == name)) | |
classes = [c[1] for c in inspect.getmembers(pymodule, predicate)] | |
if not classes: | |
raise ValueError("No ``torch.nn.Module`` subclass " + ("" if name is None else f"named '{name}' ") + | |
f"found in given module (module '{pymodule.__name__}' in '{pymodule.__file__}')") | |
elif len(classes) > 1: | |
raise ValueError(f"Multiple subclasses of ``torch.nn.Module`` found in given module " | |
f"(module '{pymodule.__name__}' in '{pymodule.__file__}'; " | |
f"candidates are {', '.join(c.__name__ for c in classes)})") | |
return classes[0] | |
def save(*,
         name: str,
         version: str,
         code_path: Path,
         archive_path: Path,
         class_name: Optional[str] = None,
         weights_path: Optional[Path] = None
         ):
    """
    Write the model's code and, optionally, its weights into a zip archive. The format is similar to, but not the same
    as, TorchServe's ``*.mar`` archive.

    The archive holds ``META-INF/MANIFEST.json`` (the metadata below) plus the code file and, if given, the weights
    file, each stored under its bare file name.

    :param name: name of the model (informative only)
    :param version: version of the model (informative only)
    :param code_path: ``*.py`` file that defines the model class as a ``torch.nn.Module`` subclass
    :param archive_path: destination path of the zip archive (overwritten if it exists)
    :param class_name: optional name of the model class (required if the code file defines several
        ``torch.nn.Module`` subclasses)
    :param weights_path: optional model weights, which must be loadable via
        ``torch.nn.Module.load_state_dict(torch.load(...))``
    """
    code_file = _name_from(code_path)
    weights_file = _name_from(weights_path)
    manifest = {"name": name, "version": version, "code_file": code_file, "weights_file": weights_file,
                "class_name": class_name}
    with zipfile.ZipFile(archive_path, mode="w") as zf:
        zf.writestr("META-INF/MANIFEST.json", data=json.dumps(manifest, indent=2))
        zf.write(code_path, arcname=code_file)
        if weights_path:
            zf.write(weights_path, arcname=weights_file)
def load(*, archive_path: Path, extract_dir: Optional[Path] = None,
         map_location: Optional[Union[str, torch.device]] = None) -> nn.Module:
    """
    Load a model from a zip archive created with ``save()``, instantiate it and, if provided in the archive, load the
    model's weights.

    Warning: this executes the archive's python code (by importing it) and unpickles its weights via
    ``torch.load()``. Only load archives from trusted sources.

    :param archive_path: path of the model's zip archive, as written by ``save()`` (``str`` is accepted as well)
    :param extract_dir: optional directory into which to extract the model archive. If not given, the archive's parent
        directory is used. In either case, the content goes into a subdirectory named after the archive's stem.
    :param map_location: optional ``map_location`` passed on to ``torch.load()`` for the weights, e.g. ``"cpu"`` to
        load GPU-saved weights on a machine without a GPU. The default ``None`` keeps ``torch.load()``'s behavior.
    :return: instance of the loaded model, switched to evaluation mode
    """
    archive_path = Path(archive_path)
    base_dir = archive_path.parent if extract_dir is None else Path(extract_dir)
    pymodule_dir = base_dir / archive_path.stem
    with zipfile.ZipFile(archive_path, mode="r") as archive:
        archive.extractall(pymodule_dir)
    manifest_data = json.loads((pymodule_dir / "META-INF" / "MANIFEST.json").read_text(encoding="utf-8"))
    pymodule = _pymodule_from(pymodule_dir / manifest_data["code_file"])
    model_class = _model_class_from(pymodule, manifest_data.get("class_name"))
    model = model_class()
    weights_file = manifest_data.get("weights_file")
    if weights_file:
        model.load_state_dict(torch.load(pymodule_dir / weights_file, map_location=map_location))
    model.eval()
    return model
if __name__ == "__main__":
    # Demo: archive a DenseNet-161 image classifier (code + pretrained weights), reload it and classify an image
    from hashlib import sha256
    from textwrap import dedent
    import urllib.request  # plain ``import urllib`` does not guarantee that ``urllib.request`` is loaded

    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    try:
        from scipy.misc import face  # older SciPy versions
    except ImportError:
        from scipy.datasets import face  # newer SciPy versions (``face`` moved here; needs the ``pooch`` package)
    from torch.nn.functional import softmax
    from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

    def _check_sha256_prefix(path: Path, expected_prefix: str):
        """Raise ``RuntimeError`` if the SHA-256 hex digest of the file at ``path`` doesn't start with the prefix."""
        # An explicit check rather than ``assert``, so it is not stripped when running with ``python -O``
        digest = sha256(path.read_bytes()).hexdigest()
        if not digest.startswith(expected_prefix):
            raise RuntimeError(f"Checksum mismatch for '{path}': expected {expected_prefix}..., got {digest}")

    base_dir = Path(__file__).parent
    name, version = "mymodel", "0.0.1"
    code = base_dir / "model.py"
    weights = base_dir / "densenet161-8d451a50.pth"
    archive = base_dir / "myarchive.zip"
    index_to_name = base_dir / "index_to_name.json"  # Mapping from class index to human-readable class label

    # Create the model code and save it to disk on the fly (in an actual project we would not do that) -- code from
    # TorchServe: examples\image_classifier\densenet_161\model.py (20201009).
    # Raw string, so that the regex's backslashes reach the written file unchanged without invalid-escape warnings.
    code.write_text(dedent(
        r"""
        import re
        from torchvision.models.densenet import DenseNet


        class ImageClassifier(DenseNet):
            def __init__(self):
                super().__init__(48, (6, 12, 36, 24), 96)

            def load_state_dict(self, state_dict, strict=True):
                pattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")
                for key in list(state_dict.keys()):
                    res = pattern.match(key)
                    if res:
                        new_key = res.group(1) + res.group(2)
                        state_dict[new_key] = state_dict[key]
                        del state_dict[key]
                return super().load_state_dict(state_dict, strict)
        """
    ))

    # Get the trained weights, if necessary
    if not weights.exists():
        print("Download weights ...")
        url = r"https://download.pytorch.org/models/" + weights.name
        urllib.request.urlretrieve(url, filename=weights)
    _check_sha256_prefix(weights, "8d451a50ba")

    # Get the index to label mapping, if necessary
    if not index_to_name.exists():
        print("Download mapping of class indices to class labels ...")
        url = r"https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json"
        urllib.request.urlretrieve(url, filename=index_to_name)
    _check_sha256_prefix(index_to_name, "a1e7a966a1")
    name_for_class_index = json.loads(index_to_name.read_text(encoding="utf-8"))

    print(f"Archive {name} v{version} to {archive.name} ...")
    save(name=name, version=version, code_path=code, archive_path=archive, weights_path=weights)
    print(f"Reload {name} v{version} from {archive.name} ...")
    model = load(archive_path=archive)

    print("\nClassify demo image ``scipy.misc.face()`` ...")
    input_image = Image.fromarray(face())  # Actually a raccoon, but raccoon is not among the trained classes
    prep = Compose([Resize(256), CenterCrop(224), ToTensor(), Normalize(mean=[.49, .46, .41], std=[.23, .22, .23])])
    with torch.no_grad():  # Inference only: no autograd graph needed
        logits = model(prep(input_image).unsqueeze_(0))
    probs = softmax(logits.squeeze(0), dim=0).cpu().numpy()
    class_indices = np.argsort(probs)[::-1]  # Class indices in order of decreasing probability
    labels_and_probs = [(name_for_class_index[f"{idx}"][1], 100 * probs[idx]) for idx in class_indices]
    print("\n".join(f"{label} ({prob:.2f}%)" for label, prob in labels_and_probs[:5]))
    plt.imshow(input_image)
    plt.title(f"{labels_and_probs[0][0]} ({labels_and_probs[0][1]:.2f}%)")
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
What actually happens: the model code (model.py, created on the fly for simplicity) and the trained weights (densenet161-8d451a50.pth, downloaded from pytorch.org, if necessary) are saved to the archive myarchive.zip; the mapping from class indices to class labels (index_to_name.json) is downloaded from the TorchServe repository, if necessary. The model is then reloaded from the archive and used to classify a demo image (scipy.misc.face()), and the five most probable classes are shown. Unfortunately, the raccoon in the demo image is misclassified as "badger" – but this is because raccoon is not among the trained classes. The essential functions here are save() (to archive the model) and load() (to load and instantiate the model).