Last active
November 3, 2021 03:20
-
-
Save nateraw/07d5e87fa33b0e1b6085306a27e351a7 to your computer and use it in GitHub Desktop.
upload_all_timm_models.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "upload_all_timm_models.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyPmU7IabvlPCzS76HwOsf+4", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/nateraw/07d5e87fa33b0e1b6085306a27e351a7/upload_all_timm_models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "UHCTzrWJBGNT" | |
}, | |
"source": [ | |
"%%capture\n", | |
"# %pip installs into the environment of the running kernel (a bare `!pip` may target another interpreter)\n", | |
"%pip install git+https://github.com/nateraw/pytorch-image-models.git@hf-save-and-push --upgrade\n", | |
"%pip install huggingface_hub\n", | |
"# apt-get has a stable scripting CLI (apt does not); -y keeps it from waiting on a confirmation prompt\n", | |
"! apt-get install -y git-lfs\n", | |
"# Presumably lets git reuse the Hub token saved by `huggingface-cli login` when pushing model repos\n", | |
"! git config --global credential.helper store" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RarQEE-jBRoQ" | |
}, | |
"source": [ | |
"# Interactive prompt for a Hugging Face access token; later cells push model repos to the Hub, so presumably the token needs write access.\n", | |
"! huggingface-cli login" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YvZL0rO-BBEN" | |
}, | |
"source": [ | |
"import logging\n", | |
"import requests\n", | |
"import tempfile\n", | |
"from typing import Optional, List\n", | |
"from pathlib import Path\n", | |
"\n", | |
"import timm\n", | |
"import torch\n", | |
"\n", | |
"logging.basicConfig()\n", | |
"logger = logging.getLogger(__name__)\n", | |
"logger.setLevel(logging.INFO)\n", | |
"\n", | |
"IMAGENET_21k_URL = 'https://storage.googleapis.com/bit_models/imagenet21k_wordnet_lemmas.txt'\n", | |
"IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'\n", | |
"IMAGENET_11k_URL = 'http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/models/imagenet-11k/synset.txt'\n", | |
"\n", | |
"\n", | |
"def _fetch_lines(url, timeout=30):\n", | |
"    \"\"\"Download the text file at `url` and return its lines (outer whitespace stripped first).\n", | |
"\n", | |
"    Inner blank lines are kept so list positions are not shifted.\n", | |
"    Raises requests.HTTPError on an HTTP 4xx/5xx response instead of parsing an error page as labels;\n", | |
"    `timeout` (seconds) keeps a dead host from hanging the cell forever.\n", | |
"    \"\"\"\n", | |
"    response = requests.get(url, timeout=timeout)\n", | |
"    response.raise_for_status()\n", | |
"    return response.text.strip().splitlines()\n", | |
"\n", | |
"\n", | |
"IMAGENET_21k_LABELS = _fetch_lines(IMAGENET_21k_URL)\n", | |
"IMAGENET_1k_LABELS = _fetch_lines(IMAGENET_1k_URL)\n", | |
"# Skip the first token of each line (presumably a WordNet synset id) but keep the whole name after it,\n", | |
"# rather than only the first whitespace-separated word as split()[1] did.\n", | |
"IMAGENET_11k_LABELS = [line.split(maxsplit=1)[1] for line in _fetch_lines(IMAGENET_11k_URL)]\n", | |
"\n", | |
"\n", | |
"def model_repo_exists(model_id, timeout=10):\n", | |
"    \"\"\"Return True if the Hub page `https://huggingface.co/{model_id}` answers with HTTP 200.\n", | |
"\n", | |
"    Any other status code is treated as \"repo does not exist\". `timeout` (seconds) bounds the\n", | |
"    request so a stalled connection cannot hang the upload loop indefinitely.\n", | |
"    \"\"\"\n", | |
"    url = f'https://huggingface.co/{model_id}'\n", | |
"    response = requests.get(url, timeout=timeout)\n", | |
"    return response.status_code == 200\n", | |
"\n", | |
"def _resolve_labels(num_classes):\n", | |
"    \"\"\"Return the known label-name list for a classifier with `num_classes` outputs, or None if unrecognized.\"\"\"\n", | |
"    known_label_sets = {\n", | |
"        1000: IMAGENET_1k_LABELS,\n", | |
"        21843: IMAGENET_21k_LABELS,\n", | |
"        11221: IMAGENET_11k_LABELS,\n", | |
"    }\n", | |
"    return known_label_sets.get(num_classes)\n", | |
"\n", | |
"\n", | |
"def main(repo_owner, models_to_push: Optional[List[str]] = None, limit=None):\n", | |
"    \"\"\"Upload pretrained timm models to the Hugging Face Hub as `{repo_owner}/{model_name}`.\n", | |
"\n", | |
"    Args:\n", | |
"        repo_owner: Hub user or organization that owns the created repos.\n", | |
"        models_to_push: timm model names to upload. None (the default) means every name\n", | |
"            returned by `timm.list_models(pretrained=True)`; an empty list uploads nothing.\n", | |
"        limit: if given, only the first `limit` names are processed.\n", | |
"\n", | |
"    Models whose Hub page already answers HTTP 200 are skipped, so re-running resumes\n", | |
"    an interrupted batch.\n", | |
"    \"\"\"\n", | |
"    # Test for None explicitly: an empty (e.g. filtered-down) list must not fall back to *all* models.\n", | |
"    if models_to_push is None:\n", | |
"        models_to_push = timm.list_models(pretrained=True)\n", | |
"    models_to_push = list(models_to_push)[:limit]\n", | |
"    total = len(models_to_push)\n", | |
"\n", | |
"    # torch.hub.set_dir is process-global; restore the caller's cache dir once we are done,\n", | |
"    # otherwise it would keep pointing at a deleted temp directory.\n", | |
"    original_hub_dir = torch.hub.get_dir()\n", | |
"    try:\n", | |
"        for i, model_name in enumerate(models_to_push, start=1):\n", | |
"            model_id = f'{repo_owner}/{model_name}'\n", | |
"            if model_repo_exists(model_id):\n", | |
"                logger.info(f\"👍 Skipping model {i}/{total}: {model_name}, as it already exists\")\n", | |
"                continue\n", | |
"\n", | |
"            logger.info(f\"⤴️ Uploading model {i}/{total}: {model_name}\")\n", | |
"            with tempfile.TemporaryDirectory() as tempdir:\n", | |
"                root = Path(tempdir)\n", | |
"\n", | |
"                # Set torchhub dir to tempdir so downloaded weights get cleaned up automatically\n", | |
"                torch.hub.set_dir(root / 'torchhub_cachedir/')\n", | |
"\n", | |
"                # Load pretrained model\n", | |
"                model = timm.create_model(model_name, pretrained=True)\n", | |
"\n", | |
"                # Resolve labels fresh for every model so an unrecognized head never inherits\n", | |
"                # the previous model's label list.\n", | |
"                labels = _resolve_labels(model.num_classes)\n", | |
"                if labels is None:\n", | |
"                    logger.warning(f\"🚨 Unable to link labels to known list of label names for {model_name} ({model.num_classes} classes).\")\n", | |
"\n", | |
"                # Push to 🤗 hub\n", | |
"                # NOTE(review): assumes push_to_hf_hub accepts labels=None when no names are known - confirm\n", | |
"                # against the hf-save-and-push branch installed in the first cell.\n", | |
"                timm.models.hub.push_to_hf_hub(\n", | |
"                    model,\n", | |
"                    root / model_name,\n", | |
"                    repo_namespace_or_url=model_id,\n", | |
"                    labels=labels,\n", | |
"                )\n", | |
"    finally:\n", | |
"        torch.hub.set_dir(original_hub_dir)\n", | |
"\n", | |
"main(repo_owner='nates-test-org')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment