Skip to content

Instantly share code, notes, and snippets.

Created August 31, 2024 14:35
Show Gist options
  • Save cuuupid/2f7b9614ac23aa6f49368d80cc1cec38 to your computer and use it in GitHub Desktop.
Save cuuupid/2f7b9614ac23aa6f49368d80cc1cec38 to your computer and use it in GitHub Desktop. v0.3
import os
import subprocess
import time
import yaml
from tqdm import tqdm
from urllib.parse import urlparse
import fnmatch
CACHE_URI = ""  # either s3://bucket/path/ or gs://bucket/path
# Optional CDN base URL (e.g. a CloudFront distribution) in front of CACHE_URI.
# When set, manifest URLs are built from it instead of the bucket endpoint.
CDN = None
# Files strictly larger than this many megabytes are cached and fetched with pget.
SIZE_THRESHOLD = 50
def parse_dockerignore(fileobj):
    """Return the usable patterns from a ``.dockerignore`` file object.

    Follows Docker's parsing rules closely enough for ``should_ignore``:

    * a line whose first character is ``#`` is a comment;
    * surrounding whitespace is trimmed and blank lines are skipped;
    * each pattern is path-normalised (``./a/`` -> ``a``, ``/a`` -> ``a``) so it
      can be compared with the normalised relative paths ``should_ignore`` uses;
    * a leading ``!`` (exception pattern) is preserved in front of the
      normalised pattern.

    Args:
        fileobj: Any iterable of text lines, e.g. an open file.

    Returns:
        list[str]: The normalised patterns, in file order.
    """
    patterns = []
    for raw in fileobj:
        # Docker only treats '#' in column 1 as a comment, so test before stripping.
        if raw.startswith('#'):
            continue
        pattern = raw.strip()
        if not pattern:
            continue
        negate = pattern.startswith('!')
        if negate:
            pattern = pattern[1:].strip()
        if not pattern:
            continue
        # Drop './', trailing '/', and a leading '/' (Docker treats '/x' as 'x').
        pattern = os.path.normpath(pattern).lstrip(os.path.sep)
        if not pattern:
            continue
        patterns.append(f"!{pattern}" if negate else pattern)
    return patterns
def should_ignore(file_path, dockerignore_patterns):
    """Return True if *file_path* is excluded by the given .dockerignore patterns.

    The path is normalised relative to the current directory. It counts as
    matched by a pattern when the path itself, or any of its ancestor
    directories, matches under ``fnmatch``, because ignoring a directory
    ignores everything beneath it.

    Patterns are evaluated in order and the last matching one wins. A
    pattern prefixed with ``!`` re-includes what earlier patterns excluded,
    as in Docker.

    Args:
        file_path: File or directory path, e.g. ``./weights/model.bin``.
        dockerignore_patterns: Patterns as returned by ``parse_dockerignore``.

    Returns:
        bool: True if the path should be ignored.
    """
    # Make the path relative: drop './' components and any leading separators.
    rel_path = os.path.normpath(file_path).lstrip(os.path.sep)
    parts = rel_path.split(os.path.sep)
    # The path plus each of its ancestor directories, shortest first.
    candidates = [os.path.sep.join(parts[:i]) for i in range(1, len(parts) + 1)]

    ignored = False
    for pattern in dockerignore_patterns:
        negate = pattern.startswith('!')
        if negate:
            pattern = pattern[1:]
        if any(fnmatch.fnmatch(candidate, pattern) for candidate in candidates):
            ignored = not negate
    return ignored
def add_to_dockerignore(files):
    """Append *files* to ``.dockerignore`` in the current directory.

    Each path is written as a normalised relative path using forward slashes
    (``./weights/a.bin`` -> ``weights/a.bin``).

    Entries already present verbatim are skipped, so re-running does not
    duplicate lines. If the existing file does not end with a newline, one is
    inserted first so the new entry does not merge into the last line.

    The file is only created or modified when there is something to add.

    Args:
        files: Iterable of file paths to exclude from the Docker build context.
    """
    path = '.dockerignore'
    existing = ''
    if os.path.exists(path):
        with open(path, 'r') as f:
            existing = f.read()
    present = {line.strip() for line in existing.splitlines()}

    new_entries = []
    for file in files:
        entry = os.path.normpath(file).replace(os.sep, '/')
        if entry not in present:
            present.add(entry)
            new_entries.append(entry)
    if not new_entries:
        return

    with open(path, 'a') as f:
        if existing and not existing.endswith('\n'):
            f.write('\n')
        f.write(''.join(f"{entry}\n" for entry in new_entries))
# Return the CLI command prefix that uploads a local file to CACHE_URI.
def _cache_copy_command():
    if CACHE_URI.startswith('s3://'):
        return ['aws', 's3', 'cp']
    if CACHE_URI.startswith('gs://'):
        return ['gcloud', 'storage', 'cp']
    raise ValueError("Invalid CACHE_URI. Must start with 's3://' or 'gs://'")


# Return the object key for a local path, e.g. "./.cache/w.bin" -> ".cache/w.bin".
def _cache_key(filepath):
    # relpath (not lstrip('./')) keeps the leading dot of hidden directories.
    return os.path.relpath(filepath, '.').replace(os.sep, '/')


# Walk the working tree (honouring .dockerignore) and return (path, bytes) for large files.
def _find_large_files():
    dockerignore_patterns = []
    if os.path.exists('.dockerignore'):
        with open('.dockerignore', 'r') as f:
            dockerignore_patterns = parse_dockerignore(f)

    threshold_bytes = SIZE_THRESHOLD * 1024 * 1024
    large_files = []
    for root, dirs, files in os.walk('.', topdown=True):
        # Prune ignored directories in place so os.walk never descends into them.
        dirs[:] = [
            d for d in dirs
            if not should_ignore(os.path.relpath(os.path.join(root, d), '.'), dockerignore_patterns)
        ]
        for file in files:
            filepath = os.path.join(root, file)
            if should_ignore(os.path.relpath(filepath, '.'), dockerignore_patterns):
                continue
            try:
                size = os.path.getsize(filepath)
            except OSError as e:
                print(f"Error accessing {filepath}: {e}")
                continue
            if size > threshold_bytes:
                large_files.append((filepath, size))
    return large_files


# Upload each file to CACHE_URI, mirroring its relative path under the URI's prefix.
def _upload_to_cache(cp_command, paths):
    base = CACHE_URI.rstrip('/')
    for filepath in tqdm(paths, desc="Copying files to cache"):
        dest_path = f"{base}/{_cache_key(filepath)}"
        subprocess.run(cp_command + [filepath, dest_path], check=True)


# Return the HTTPS URL pget should download a cached file from.
def _download_url(filepath):
    parsed = urlparse(CACHE_URI)
    bucket = parsed.netloc
    prefix = parsed.path.strip('/')
    key = _cache_key(filepath)
    # Avoid a '//' when CACHE_URI has no path or ends with '/'.
    object_path = f"{prefix}/{key}" if prefix else key
    if CDN:
        return f"{CDN.rstrip('/')}/{object_path}"
    if CACHE_URI.startswith('s3://'):
        return f"https://{bucket}.s3.amazonaws.com/{object_path}"
    return f"https://storage.googleapis.com/{bucket}/{object_path}"


# Ensure cog.yaml's build.run installs the pget binary; return the parsed cog.yaml.
def _install_pget_in_cog_yaml():
    with open('cog.yaml', 'r') as f:
        cog_config = yaml.safe_load(f) or {}
    # `or` also covers keys that are present but empty (parsed as None).
    build_config = cog_config.get('build') or {}
    run_commands = build_config.get('run') or []
    pget_commands = [
        'curl -o /usr/local/bin/pget -L '
        '"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)"',
        'chmod +x /usr/local/bin/pget',
    ]
    missing = [cmd for cmd in pget_commands if cmd not in run_commands]
    if missing:
        build_config['run'] = list(run_commands) + missing
        cog_config['build'] = build_config
        with open('cog.yaml', 'w') as f:
            # sort_keys=False keeps the user's key order in cog.yaml.
            yaml.dump(cog_config, f, sort_keys=False)
        print("Updated cog.yaml to install pget.")
    return cog_config


# Make the predictor's setup() call pget_manifest(manifest_filename) first.
def _patch_predictor(cog_config, manifest_filename):
    import re

    predict_config = cog_config.get('predict', '')
    if not predict_config:
        return
    predictor_file, _predictor_class = predict_config.rsplit(':', 1)
    with open(predictor_file, 'r') as f:
        predictor_content = f.read()

    call = f"pget_manifest('{manifest_filename}')"
    if call in predictor_content:
        return  # already patched by a previous run

    if 'from pget import pget_manifest' not in predictor_content:
        predictor_content = f"from pget import pget_manifest\n{predictor_content}"

    # Match single-line `def setup(self...)` headers, with or without a return
    # annotation, and insert the call one indent level below the def.
    setup_header = re.compile(r"^([ \t]*)def setup\(self\b[^)]*\)[^:\n]*:[ \t]*$", re.MULTILINE)
    patched, count = setup_header.subn(
        lambda m: f"{m.group(0)}\n{m.group(1)}    {call}", predictor_content
    )
    if count:
        predictor_content = patched
    elif 'def setup(' in predictor_content:
        # Appending a second setup() would silently override the existing one.
        print(f"Could not patch setup() in {predictor_file}; add {call} to it manually.")
        return
    else:
        # NOTE(review): assumes the predictor class is the last definition in the file.
        predictor_content += f"\n    def setup(self):\n        {call}\n"

    with open(predictor_file, 'w') as f:
        f.write(predictor_content)
    print(f"Updated {predictor_file} to include pget_manifest in setup method.")


def make_manifest(manifest_filename: str = 'manifest.pget'):
    """Cache large project files in a bucket and wire up pget downloads.

    Interactive workflow, run from the root of a Cog project:

    1. Walk the current directory, skipping ``.dockerignore`` matches, and
       collect files larger than ``SIZE_THRESHOLD`` megabytes.
    2. List them and ask for confirmation (answering ``n`` aborts).
    3. Upload each file under ``CACHE_URI`` using the AWS or gcloud CLI.
    4. Write *manifest_filename* with one ``<url> <local path>`` line per
       file, and append the files to ``.dockerignore``.
    5. Add the pget install commands to ``build.run`` in ``cog.yaml``.
    6. Insert a ``pget_manifest(...)`` call into the predictor's ``setup``.

    Args:
        manifest_filename: Path of the manifest file to write.

    Raises:
        ValueError: If ``CACHE_URI`` does not start with ``s3://`` or ``gs://``.
        subprocess.CalledProcessError: If an upload command fails.
    """
    # Validate the destination before scanning or prompting.
    cp_command = _cache_copy_command()

    large_files = _find_large_files()
    if not large_files:
        print("No large files found; nothing to cache.")
        return

    print("Large files found:")
    for filepath, size in large_files:
        print(f"{filepath}: {size / (1024 * 1024):.2f} MB")

    user_input = input("Please confirm you would like to cache these [Y/n]: ").strip().lower()
    if user_input == 'n':
        print("Ok, I won't generate a manifest at this time.")
        return

    paths = [filepath for filepath, _ in large_files]
    _upload_to_cache(cp_command, paths)

    with open(manifest_filename, 'w') as f:
        for filepath in paths:
            f.write(f"{_download_url(filepath)} {filepath}\n")

    add_to_dockerignore(paths)
    print("Added cached files to .dockerignore")

    cog_config = _install_pget_in_cog_yaml()
    _patch_predictor(cog_config, manifest_filename)
def pget_manifest(manifest_filename: str = 'manifest.pget'):
    """Download every file in a pget manifest that is not already on disk.

    Each non-blank manifest line has the form ``<url> <local path>``.
    Parent directories are created as needed, and entries whose local path
    already exists are skipped. The remaining entries are fetched with a
    single ``pget multifile`` call, using a temporary manifest that is removed
    afterwards.

    Args:
        manifest_filename: Path of the manifest written by ``make_manifest``.

    Raises:
        ValueError: If a non-blank line has no space separating URL and path.
        subprocess.CalledProcessError: If ``pget`` exits with an error.
    """
    import tempfile

    start = time.time()
    with open(manifest_filename, 'r') as f:
        manifest = f.read()

    to_dl = []
    for line in manifest.splitlines():
        line = line.strip()
        if not line:
            continue  # tolerate blank and trailing lines
        # Split on the first space only; everything after it is the local path.
        url, path = line.split(" ", 1)
        path = path.strip()
        directory = os.path.dirname(path)
        # os.makedirs('') raises, so only create a directory when there is one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        if not os.path.exists(path):
            to_dl.append(f"{url} {path}")

    if not to_dl:
        print("All weights already present; nothing to download.")
        return

    # Hand pget a private temp manifest containing only the missing entries.
    fd, tmp_manifest = tempfile.mkstemp(suffix=".pget")
    try:
        with os.fdopen(fd, 'w') as f:
            f.write("\n".join(to_dl) + "\n")
        subprocess.check_call(["pget", "multifile", tmp_manifest])
    finally:
        os.remove(tmp_manifest)

    timing = time.time() - start
    print(f"Downloaded weights in {timing} seconds")
Copy link

cuuupid commented Oct 10, 2024

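To use this, save the file as `pget.py` in the root of your Cog project (the directory containing `cog.yaml`), then: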

  • Ensure you are authenticated to your bucket (via the AWS CLI or the gcloud CLI).
  • Set `CACHE_URI` to your bucket, either `s3://bucket/path/` or `gs://bucket/path`. Optionally adjust `SIZE_THRESHOLD` (in MB) to control which files are cached.
  • If your bucket is served through a CDN such as CloudFront, set `CDN` to that CDN's base URL.


python3 -c "import pget; pget.make_manifest()"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment