Skip to content

Instantly share code, notes, and snippets.

@PortNumber53
Last active January 19, 2024 18:20
Show Gist options
  • Save PortNumber53/28beca927b1cd5c036efe1fd31e32202 to your computer and use it in GitHub Desktop.
Save PortNumber53/28beca927b1cd5c036efe1fd31e32202 to your computer and use it in GitHub Desktop.
Downloads HuggingFace models.
import argparse
import os
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
# Load settings from a local .env file (python-dotenv) at import time so they
# can be read through os.getenv() by the code below.
load_dotenv()
def download_model(repo_id: str, use_auth_token=True) -> str:
    """Download a snapshot of a Hugging Face Hub repository.

    The files are placed in ``<MODELS_BASE_FOLDER>/<owner>_<name>`` (or
    ``<MODELS_BASE_FOLDER>/<name>`` when the repository ID has no owner
    prefix), and ``MODELS_CACHE_FOLDER`` is handed to ``snapshot_download``
    as its cache directory.

    Args:
        repo_id: Repository ID, e.g. ``"owner/name"`` or ``"name"``.
        use_auth_token: Forwarded unchanged to ``snapshot_download``.

    Returns:
        The path returned by ``snapshot_download``; it is also printed.

    Raises:
        ValueError: If ``repo_id`` is empty, or if either the
            ``MODELS_BASE_FOLDER`` or ``MODELS_CACHE_FOLDER`` environment
            variable is unset or empty.
    """
    models_base_folder = os.getenv("MODELS_BASE_FOLDER")
    models_cache_folder = os.getenv("MODELS_CACHE_FOLDER")
    if not models_base_folder or not models_cache_folder:
        raise ValueError("MODELS_BASE_FOLDER or MODELS_CACHE_FOLDER environment variable is not set.")

    # Fail early with a clear message instead of an IndexError, and avoid
    # downloading straight into the base folder for an empty name.
    if not repo_id or not repo_id.strip():
        raise ValueError("repo_id must be a non-empty string.")

    # At most the first two "/"-separated parts name the folder, so
    # "owner/name" -> "owner_name" and an owner-less "name" -> "name".
    model_folder = "_".join(repo_id.split("/")[:2])
    local_path = os.path.join(models_base_folder, model_folder)

    # NOTE(review): newer huggingface_hub releases appear to prefer `token=`
    # over `use_auth_token=` -- confirm against the installed version.
    downloaded_model_path = snapshot_download(
        repo_id=repo_id,
        use_auth_token=use_auth_token,
        cache_dir=models_cache_folder,
        local_dir=local_path,
        max_workers=1,
    )
    print(downloaded_model_path)
    return downloaded_model_path
if __name__ == "__main__":
    # Command-line entry point: read the required --repo_id option and
    # download that repository.
    cli = argparse.ArgumentParser(
        description="Download model from Hugging Face Hub.",
    )
    cli.add_argument(
        "--repo_id",
        required=True,
        help="Hugging Face repository ID",
    )
    download_model(repo_id=cli.parse_args().repo_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment