Skip to content

Instantly share code, notes, and snippets.

@ternaus
Last active August 22, 2020 22:48
Show Gist options
  • Save ternaus/8c4bdc5b3695e420db76874261092c1a to your computer and use it in GitHub Desktop.
Save ternaus/8c4bdc5b3695e420db76874261092c1a to your computer and use it in GitHub Desktop.
# https://github.com/ternaus/retinaface/blob/master/retinaface/pre_trained_models.py
from collections import namedtuple
from torch.utils import model_zoo
from retinaface.predict_single import Model
# Record type for one registry entry: the download URL of the weights and the
# class that get_model instantiates for them.
model = namedtuple("model", "url model")

# Registry of available pre-trained checkpoints, keyed by model name.
models = {
    "resnet50_2020-07-20": model(
        "https://github.com/ternaus/retinaface/releases/download/0.01/retinaface_resnet50_2020-07-20-f168fae3c.zip",  # noqa: E501
        Model,
    ),
}
def get_model(model_name: str, max_size: int, device: str = "cpu") -> Model:
    """Build a registered pre-trained model and load its released weights.

    Args:
        model_name: Key into the module-level ``models`` registry.
        max_size: Forwarded unchanged to the model constructor as ``max_size``.
        device: Forwarded unchanged to the model constructor as ``device``.

    Returns:
        The constructed model after ``load_state_dict`` has been called with
        the downloaded state dict.

    Raises:
        KeyError: If ``model_name`` is not present in ``models``.
    """
    try:
        entry = models[model_name]
    except KeyError:
        available = ", ".join(sorted(models))
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None

    # Named ``net`` (not ``model``) so the module-level ``model`` namedtuple
    # is not shadowed inside this function.
    net = entry.model(max_size=max_size, device=device)

    # The checkpoint is always deserialized onto the CPU, independent of
    # ``device``. NOTE(review): presumably the model class moves the weights
    # to ``device`` itself inside ``load_state_dict`` -- confirm against Model.
    state_dict = model_zoo.load_url(entry.url, progress=True, map_location="cpu")
    net.load_state_dict(state_dict)
    return net
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment