Skip to content

Instantly share code, notes, and snippets.

@daskol
Last active January 24, 2025 15:34
Show Gist options
  • Save daskol/8fa9e1c93e410526ab53d9d649fbaf19 to your computer and use it in GitHub Desktop.
Save daskol/8fa9e1c93e410526ab53d9d649fbaf19 to your computer and use it in GitHub Desktop.
Download a repo from HuggingFace Hub with aria2c.
#!/usr/bin/env python3
"""Little script for generating a download list for fetching model weights and
configuration files of a model from HuggingFace Hub. Once the download list is
ready, you can easily fetch all files, with throttling and suspending or
resuming, using `aria2c`, for example: aria2c -c -i index.txt.
"""
from argparse import ArgumentParser, Namespace
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_url
from huggingface_hub.constants import HF_TOKEN_PATH
from huggingface_hub.hf_api import DatasetInfo, ModelInfo
# Identification of this tool; forwarded as keyword arguments when the
# default `HfApi` client is constructed.
LIBRARY_NAME = 'hf2aria'
LIBRARY_VERSION = '0.0.0-5'
LIBRARY_KWARGS = dict(
    library_name=LIBRARY_NAME,
    library_version=LIBRARY_VERSION,
)
# Command-line interface. The module docstring doubles as the `--help`
# description.
parser = ArgumentParser(description=__doc__)

# Optional flags.
parser.add_argument(
    '-t', '--repo-type',
    choices=('dataset', 'model', 'space'),
    default='model',
    help='repository type (default: model)',
)
parser.add_argument(
    '--output-dir',
    type=Path,
    default=Path(''),
    help='where to store files (default: cwd)',
)

# Positional arguments: the repository, then an optional list file name.
parser.add_argument('repo_id', help='model identifier in huggingface hub')
parser.add_argument(
    'download_list',
    nargs='?',
    type=Path,
    default=Path('index.txt'),
    help='where to write download list (default: index.txt)',
)
def main(args: Namespace):
    """Write the download list and print an `aria2c` command that uses it.

    Args:
        args: Parsed command-line arguments; `repo_id`, `repo_type`,
            `output_dir` and `download_list` are read from it.
    """
    download_list = Path(args.download_list)
    # Repository file names end up in the list, so write it as UTF-8 rather
    # than in whatever the locale encoding happens to be.
    with open(download_list, 'w', encoding='utf-8') as fout:
        make_download_list(fout, args.repo_id, args.repo_type, args.output_dir)

    # Shorten the token path for display. It has to stay anchored at `~`:
    # a bare home-relative path would make the printed `$(cat ...)` work
    # only when the command is run from the home directory.
    token_path = Path(HF_TOKEN_PATH)
    if token_path.is_relative_to(home := Path.home()):
        token_path = Path('~') / token_path.relative_to(home)

    command = ('aria2c', '-c', f'-i {download_list}',
               f'--header="Authorization: Bearer $(cat {token_path})"')
    print('Run the following command to download.\n\n ', ' '.join(command))
    print()
def make_download_list(fout,
repo_id: str,
repo_type: str,
output_dir: Path,
hf_api=None):
if hf_api is None:
hf_api = HfApi(**LIBRARY_KWARGS)
repo_info: DatasetInfo | ModelInfo
match repo_type:
case 'dataset':
repo_info = hf_api.dataset_info(repo_id)
case 'model':
repo_info = hf_api.model_info(repo_id)
case _:
raise ValueError(f'Unexpected repository type {repo_type}.')
repo_dir = '--'.join([repo_type, *repo_id.split('/')]) # repo_folder_name
output_dir = output_dir / repo_dir / repo_info.sha
total_size = 0
for ent in repo_info.siblings:
url = hf_hub_url(repo_id, ent.rfilename, repo_type=repo_type)
path = output_dir / ent.rfilename
fout.write(url)
fout.write(f'\n out={path}\n')
if ent.size:
total_size += ent.size
if total_size:
print('Total download size is', total_size)
# Script entry point: parse the command line, then hand over to `main`.
if __name__ == '__main__':
    cli_args = parser.parse_args()
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment