Created
August 31, 2024 00:04
-
-
Save cuuupid/44ed1d2aa223e661b7a106e304e5a04a to your computer and use it in GitHub Desktop.
PGet.py v0.2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import subprocess | |
import time | |
import yaml | |
from tqdm import tqdm | |
from urllib.parse import urlparse | |
from docker import utils as docker_utils | |
SIZE_THRESHOLD = 100  # MB — files larger than this are uploaded to the cache
CACHE_URI = "s3://..."  # either s3://bucket/path/ or gs://bucket/path
CDN = None  # optional CDN base URL fronting the cache bucket; used for manifest URLs when set
def should_ignore(file_path, dockerignore_patterns):
    """Return a truthy value when *file_path* should be skipped per .dockerignore.

    NOTE(review): `docker.utils.exclude_paths(root, patterns, ...)` conventionally
    takes the ignore *patterns* as its second argument and returns the set of
    paths to *keep* — here the single file path is passed where patterns go and
    the patterns where the dockerfile name goes. Confirm this argument order
    (and the truthiness semantics) against the installed docker SDK version.
    """
    return docker_utils.exclude_paths('.', [file_path], dockerignore_patterns)
def make_manifest(manifest_filename: str = 'manifest.pget'):
    """Upload large files to CACHE_URI and write a pget manifest for them.

    Workflow:
      1. Walk the working tree for files over ``SIZE_THRESHOLD`` MB,
         honouring ``.dockerignore``.
      2. Print the candidates and ask the user to confirm (default yes;
         only an explicit ``n`` aborts).
      3. Copy each file to the s3:// or gs:// cache with the cloud CLI.
      4. Write ``<public-url> <local-path>`` lines to *manifest_filename*.
      5. Append pget install commands to ``cog.yaml``.
      6. Inject a ``pget_manifest(...)`` call into the predictor's
         ``setup()`` method.

    Args:
        manifest_filename: Where to write the manifest (default 'manifest.pget').

    Raises:
        ValueError: if CACHE_URI is neither ``s3://`` nor ``gs://``.
        subprocess.CalledProcessError: if an upload command fails.
    """
    def _rel(path: str) -> str:
        # Normalize './models/x.bin' -> 'models/x.bin' with forward slashes.
        # (lstrip('./') strips *characters* and would also eat the leading
        # dot of hidden files, e.g. './.weights' -> 'weights'.)
        return os.path.relpath(path, '.').replace(os.sep, '/')

    def _bucket_and_prefix(uri: str):
        # 's3://bucket/some/prefix/' -> ('bucket', 'some/prefix');
        # partition() tolerates a bucket-only URI where split('/', 1) raises.
        bucket, _, prefix = uri[5:].partition('/')
        return bucket, prefix.strip('/')

    large_files = []

    # Load .dockerignore patterns, if any.
    dockerignore_patterns = []
    if os.path.exists('.dockerignore'):
        with open('.dockerignore', 'r') as f:
            dockerignore_patterns = docker_utils.parse_dockerignore(f)

    # Step 1: find all files larger than SIZE_THRESHOLD.
    for root, _, files in os.walk('.'):
        for file in files:
            filepath = os.path.join(root, file)
            rel_filepath = os.path.relpath(filepath, '.')
            if not should_ignore(rel_filepath, dockerignore_patterns):
                size = os.path.getsize(filepath)  # stat once, not twice
                if size > SIZE_THRESHOLD * 1024 * 1024:
                    large_files.append((filepath, size))

    # Step 2: list relative filepaths and their sizes.
    print("Large files found:")
    for filepath, size in large_files:
        print(f"{filepath}: {size / (1024 * 1024):.2f} MB")

    # Step 3: confirm with user (anything except an explicit 'n' proceeds).
    user_input = input("Please confirm you would like to cache these [Y/n]: ").strip().lower()
    if user_input == 'n':
        print("Ok, I won't generate a manifest at this time.")
        return

    # Step 4: copy files to cache with the matching cloud CLI.
    if CACHE_URI.startswith('s3://'):
        cp_command = ['aws', 's3', 'cp']
    elif CACHE_URI.startswith('gs://'):
        cp_command = ['gcloud', 'storage', 'cp']
    else:
        raise ValueError("Invalid CACHE_URI. Must start with 's3://' or 'gs://'")
    for filepath, _ in tqdm(large_files, desc="Copying files to cache"):
        # Build the destination URI by hand: os.path.join would use '\' on
        # Windows, and rstrip/_rel avoid doubled slashes.
        dest_path = f"{CACHE_URI.rstrip('/')}/{_rel(filepath)}"
        subprocess.run(cp_command + [filepath, dest_path], check=True)

    # Step 5: generate the manifest file ("<url> <local-path>" per line).
    with open(manifest_filename, 'w') as f:
        for filepath, _ in large_files:
            rel = _rel(filepath)
            if CDN:
                parsed_uri = urlparse(CACHE_URI)
                prefix = parsed_uri.path.strip('/')
                url = f"{CDN.rstrip('/')}/{prefix}/{rel}"
            elif CACHE_URI.startswith('s3://'):
                bucket, prefix = _bucket_and_prefix(CACHE_URI)
                url = f"https://{bucket}.s3.amazonaws.com/{prefix}/{rel}"
            else:  # gs://
                bucket, prefix = _bucket_and_prefix(CACHE_URI)
                url = f"https://storage.googleapis.com/{bucket}/{prefix}/{rel}"
            f.write(f"{url} {filepath}\n")

    # Step 6: update cog.yaml so the image installs pget.
    with open('cog.yaml', 'r') as f:
        cog_config = yaml.safe_load(f)
    build_config = cog_config.get('build', {})
    run_commands = build_config.get('run', [])
    pget_commands = [
        'curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)"',
        'chmod +x /usr/local/bin/pget'
    ]
    if not all(cmd in run_commands for cmd in pget_commands):
        run_commands.extend(pget_commands)
        build_config['run'] = run_commands
        cog_config['build'] = build_config
        with open('cog.yaml', 'w') as f:
            yaml.dump(cog_config, f)
        print("Updated cog.yaml to install pget.")

    # Step 7: patch the predictor file to call pget_manifest() in setup().
    predict_config = cog_config.get('predict', '')
    if predict_config:
        predictor_file, predictor_class = predict_config.split(':')
        with open(predictor_file, 'r') as f:
            predictor_content = f.read()
        if 'from pget import pget_manifest' not in predictor_content:
            predictor_content = f"from pget import pget_manifest\n{predictor_content}"
            if 'def setup(self):' in predictor_content:
                predictor_content = predictor_content.replace(
                    'def setup(self):',
                    f"def setup(self):\n        pget_manifest('{manifest_filename}')"
                )
            else:
                predictor_content += f"\n    def setup(self):\n        pget_manifest('{manifest_filename}')\n"
            with open(predictor_file, 'w') as f:
                f.write(predictor_content)
            print(f"Updated {predictor_file} to include pget_manifest in setup method.")
def pget_manifest(manifest_filename: str = 'manifest.pget'):
    """Download any missing files listed in a pget manifest.

    Each non-empty manifest line is ``<url> <local-path>``. Parent
    directories are created, already-present files are skipped, and the
    remaining lines are written to ``tmp.pget`` and fetched in a single
    ``pget multifile`` call.

    Args:
        manifest_filename: Manifest to read (default 'manifest.pget').

    Raises:
        subprocess.CalledProcessError: if the pget download fails.
    """
    start = time.time()
    with open(manifest_filename, 'r') as f:
        manifest = f.read()
    to_dl = []
    for line in manifest.splitlines():
        if not line.strip():
            continue  # tolerate blank/trailing lines
        # Split only on the first space: local paths may contain spaces.
        _, path = line.split(" ", 1)
        parent = os.path.dirname(path)
        if parent:  # makedirs('') raises FileNotFoundError for top-level files
            os.makedirs(parent, exist_ok=True)
        if not os.path.exists(path):
            to_dl.append(line)
    # Write the reduced manifest of files still needed.
    with open("tmp.pget", 'w') as f:
        f.write("\n".join(to_dl))
    # Only spawn pget when there is actually something to fetch.
    if to_dl:
        subprocess.check_call(["pget", "multifile", "tmp.pget"])
    # Log metrics.
    timing = time.time() - start
    print(f"Downloaded weights in {timing} seconds")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment