Last active
January 19, 2024 18:20
-
-
Save PortNumber53/28beca927b1cd5c036efe1fd31e32202 to your computer and use it in GitHub Desktop.
Downloads Hugging Face models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
from dotenv import load_dotenv | |
from huggingface_hub import snapshot_download | |
# Import-time side effect: read a local .env file via python-dotenv so that
# MODELS_BASE_FOLDER / MODELS_CACHE_FOLDER (read in download_model) can be set
# there. override=False (the python-dotenv default) keeps any value already
# present in the real environment.
load_dotenv(override=False)
def download_model(repo_id, use_auth_token=True, max_workers=1):
    """Download a snapshot of a Hugging Face Hub repository to a local folder.

    The snapshot is written to ``$MODELS_BASE_FOLDER/<namespace>_<name>`` and
    ``$MODELS_CACHE_FOLDER`` is used as the Hub cache directory. The local path
    reported by ``snapshot_download`` is printed and also returned.

    Args:
        repo_id: Hub repository ID, e.g. ``"namespace/name"``. A bare
            ``"name"`` (no namespace) is also accepted.
        use_auth_token: Forwarded to ``snapshot_download`` as its ``token``
            argument (``True`` means "use the locally saved token", a string
            is used as the token itself).
        max_workers: Number of concurrent download workers passed to
            ``snapshot_download``. Defaults to 1, matching the previous
            hard-coded value.

    Returns:
        The local path returned by ``snapshot_download``.

    Raises:
        ValueError: If MODELS_BASE_FOLDER or MODELS_CACHE_FOLDER is unset or
            empty, or if ``repo_id`` is empty.
    """
    models_base_folder = os.getenv("MODELS_BASE_FOLDER")
    models_cache_folder = os.getenv("MODELS_CACHE_FOLDER")
    if not models_base_folder or not models_cache_folder:
        raise ValueError("MODELS_BASE_FOLDER or MODELS_CACHE_FOLDER environment variable is not set.")
    if not repo_id:
        # An empty ID would otherwise resolve local_dir to the base folder itself.
        raise ValueError("repo_id must be a non-empty Hugging Face repository ID.")

    # Flatten the ID into one directory name. Joining at most the first two
    # "/"-separated segments maps "namespace/name" -> "namespace_name" (as
    # before) while a bare "name" maps to "name" instead of raising IndexError.
    model_folder = "_".join(repo_id.split("/")[:2])
    local_path = os.path.join(models_base_folder, model_folder)

    downloaded_model_path = snapshot_download(
        repo_id=repo_id,
        # ``token`` supersedes the deprecated ``use_auth_token`` keyword.
        token=use_auth_token,
        cache_dir=models_cache_folder,
        local_dir=local_path,
        max_workers=max_workers,
    )
    print(downloaded_model_path)
    return downloaded_model_path
if __name__ == "__main__":
    # Command-line entry point: --repo_id is mandatory and is handed straight
    # to download_model (authentication left at its default).
    cli = argparse.ArgumentParser(
        description="Download model from Hugging Face Hub."
    )
    cli.add_argument(
        "--repo_id",
        required=True,
        help="Hugging Face repository ID",
    )
    cli_args = cli.parse_args()
    download_model(repo_id=cli_args.repo_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment