@sammcj
Created December 13, 2024 03:55
download-azure-ai-models.py
import asyncio
import aiohttp
import os
from pathlib import Path
import logging
from bs4 import BeautifulSoup
from typing import List, Dict
from dataclasses import dataclass
from datetime import datetime
import time
from urllib.parse import urlparse, parse_qs
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


@dataclass
class SASConfig:
    """Configuration for Azure Blob Storage SAS token"""
    container_url: str
    object_id: str
    tenant_id: str
    token_start: str
    token_expiry: str
    start_time: str
    end_time: str
    signature: str


class AzureModelDownloader:
    def __init__(self, model_url: str, sas_config: SASConfig, output_dir: str = "downloads", max_parallel: int = 5):
        self.model_url = model_url
        self.output_dir = output_dir
        self.max_parallel = max_parallel
        self.files_list = []
        self.sas_config = sas_config
        # Limit the number of concurrent downloads
        self.semaphore = asyncio.Semaphore(max_parallel)

    @classmethod
    def from_example_url(cls, example_url: str, model_url: str, output_dir: str = "downloads", max_parallel: int = 5):
        """Create instance by parsing an example download URL"""
        parsed = urlparse(example_url)
        query = parse_qs(parsed.query)
        # Extract base container URL
        container_url = f"{parsed.scheme}://{parsed.netloc}{os.path.dirname(os.path.dirname(parsed.path))}"
        sas_config = SASConfig(
            container_url=container_url,
            object_id=query['skoid'][0],
            tenant_id=query['sktid'][0],
            token_start=query['skt'][0],
            token_expiry=query['ske'][0],
            start_time=query['st'][0],
            end_time=query['se'][0],
            signature=query['sig'][0]
        )
        return cls(model_url, sas_config, output_dir, max_parallel)

    def parse_file_tree(self, html_content: str) -> List[Dict]:
        """Parse the file tree HTML to extract file paths"""
        soup = BeautifulSoup(html_content, 'html.parser')
        files = []
        nav_links = soup.find_all('div', class_='nav-link')
        for link in nav_links:
            automation_id = link.get('data-automation-id', '')
            if not automation_id:
                continue
            # Skip directory entries; only leaf files are downloaded
            is_directory = bool(link.find('i', class_='folder-icon'))
            if is_directory:
                continue
            file_span = link.find('span', attrs={'data-automation-localized': 'false'})
            if file_span:
                file_name = file_span.text
                files.append({
                    'path': automation_id,
                    'name': file_name,
                    'full_path': os.path.join(*automation_id.split('/'))
                })
        return files

    def construct_download_url(self, file_path: str) -> str:
        """Construct download URL for a file using SAS config"""
        # Remove any leading slashes and ensure proper formatting
        clean_path = file_path.lstrip('/')
        # Construct the full URL with SAS parameters
        params = {
            'skoid': self.sas_config.object_id,
            'sktid': self.sas_config.tenant_id,
            'skt': self.sas_config.token_start,
            'ske': self.sas_config.token_expiry,
            'sks': 'b',
            'skv': '2021-10-04',
            'sv': '2021-10-04',
            'st': self.sas_config.start_time,
            'se': self.sas_config.end_time,
            'sr': 'c',
            'sp': 'rl',
            'sig': self.sas_config.signature
        }
        query_string = '&'.join([f"{k}={v}" for k, v in params.items()])
        return f"{self.sas_config.container_url}/{clean_path}?{query_string}"

    async def download_file(self, session: aiohttp.ClientSession, file_info: Dict):
        """Download a single file with retries"""
        file_path = os.path.join(self.output_dir, file_info['full_path'])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        max_retries = 3
        retry_delay = 1
        download_url = self.construct_download_url(file_info['full_path'])
        async with self.semaphore:
            for attempt in range(max_retries):
                try:
                    async with session.get(download_url) as response:
                        if response.status != 200:
                            raise aiohttp.ClientError(f"HTTP {response.status}")
                        total_size = int(response.headers.get('content-length', 0))
                        downloaded = 0
                        with open(file_path, 'wb') as f:
                            async for chunk in response.content.iter_chunked(1024 * 1024):
                                f.write(chunk)
                                downloaded += len(chunk)
                                if total_size > 0:
                                    progress = (downloaded / total_size) * 100
                                    print(f"\r{file_info['name']}: {progress:.1f}%", end='', flush=True)
                        print(f"\nCompleted: {file_info['name']}")
                        return
                except Exception as e:
                    if attempt == max_retries - 1:
                        logging.error(f"Failed to download {file_info['name']} after {max_retries} attempts: {str(e)}")
                    else:
                        # Exponential backoff before retrying
                        await asyncio.sleep(retry_delay * (2 ** attempt))
                        logging.info(f"Retrying {file_info['name']} (attempt {attempt + 2}/{max_retries})")

    async def download_all_files(self):
        """Download all files in parallel"""
        async with aiohttp.ClientSession() as session:
            tasks = []
            for file_info in self.files_list:
                task = asyncio.create_task(self.download_file(session, file_info))
                tasks.append(task)
            await asyncio.gather(*tasks)

    def save_file_list(self):
        """Save the list of files to a text file"""
        list_path = os.path.join(self.output_dir, "files_to_download.txt")
        os.makedirs(self.output_dir, exist_ok=True)
        with open(list_path, 'w') as f:
            f.write("# Files to download from Azure AI Model Repository\n")
            f.write(f"# Model URL: {self.model_url}\n")
            f.write(f"# Container URL: {self.sas_config.container_url}\n")
            f.write(f"# Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("File path | Target location | Download URL\n")
            f.write("-" * 40 + "|" + "-" * 40 + "|" + "-" * 60 + "\n")
            for file_info in self.files_list:
                download_url = self.construct_download_url(file_info['full_path'])
                f.write(f"{file_info['path']} | {file_info['full_path']} | {download_url}\n")
        return list_path

    async def run(self, html_content: str):
        """Main execution flow"""
        print("Parsing file tree...")
        self.files_list = self.parse_file_tree(html_content)
        print(f"Found {len(self.files_list)} files")
        list_file = self.save_file_list()
        print(f"File list saved to {list_file}")
        print("\nStarting downloads...")
        await self.download_all_files()
        print("\nAll downloads completed!")


async def main():
    import argparse
    parser = argparse.ArgumentParser(description="Download Azure AI model files with directory structure")
    parser.add_argument("--html", required=True, help="Path to the HTML file containing the file tree")
    parser.add_argument("--url", required=True, help="URL of the model page")
    parser.add_argument("--output-dir", default="downloads", help="Output directory for downloaded files")
    parser.add_argument("--parallel", type=int, default=5, help="Maximum parallel downloads")
    # SAS token parameters
    parser.add_argument("--example-url", help="Example download URL to extract SAS parameters from")
    parser.add_argument("--container-url", default="https://amlwlrt4use01.blob.core.windows.net/azureml-0002c54c-3ae6-5726-aa2b-9823dd1236dc",
                        help="Azure Blob Storage container URL")
    parser.add_argument("--object-id", default="ae2sdd35-a062-42a-961d-aasdad1sd294",
                        help="Storage account object ID (skoid)")
    parser.add_argument("--tenant-id", default="3sss921-ss64-4f8c-a055-5bdasdasda3d",
                        help="Tenant ID (sktid)")
    parser.add_argument("--token-start", default="2024-12-13T00:12:36Z",
                        help="Token start time (skt)")
    parser.add_argument("--token-expiry", default="2024-12-13T16:22:36Z",
                        help="Token expiry time (ske)")
    parser.add_argument("--start-time", default="2024-12-13T03:16:01Z",
                        help="Start time (st)")
    parser.add_argument("--end-time", default="2024-12-13T11:26:01Z",
                        help="End time (se)")
    parser.add_argument("--signature", default="MJlFgasdasdasddwefftXqxNTasdasdd=",
                        help="SAS signature (sig)")
    args = parser.parse_args()

    with open(args.html, 'r', encoding='utf-8') as f:
        html_content = f.read()

    if args.example_url:
        downloader = AzureModelDownloader.from_example_url(
            args.example_url, args.url, args.output_dir, args.parallel
        )
    else:
        sas_config = SASConfig(
            container_url=args.container_url,
            object_id=args.object_id,
            tenant_id=args.tenant_id,
            token_start=args.token_start,
            token_expiry=args.token_expiry,
            start_time=args.start_time,
            end_time=args.end_time,
            signature=args.signature
        )
        downloader = AzureModelDownloader(args.url, sas_config, args.output_dir, args.parallel)

    await downloader.run(html_content)


if __name__ == "__main__":
    asyncio.run(main())
sammcj commented Dec 13, 2024

How to:

  1. Go to the Azure AI model catalogue (ai.azure.com) and download one small file from the model's public repo (you might have to log in, but it should be free)
  2. Right-click the download and copy the download URL
  3. Pass the SAS parameters from that URL to the script, either as individual arguments or via --example-url (see the sketch after this list)
  4. Save the model's artifacts page (the file tree) to an HTML file, e.g. filetree.html
  5. pip install aiohttp beautifulsoup4
  6. python download-azure-ai-models.py --html filetree.html --url 'https://ai.azure.com/explore/models/Phi-4/version/1/registry/azureml?tid=f07d1415-3879-4431-918c-0346c8f0111b#artifacts' --output-dir . --parallel 8
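
If it helps with step 3: the snippet below is a rough sketch of how the script's --example-url mode pulls the SAS parameters out of a copied download URL (it mirrors from_example_url above). The example_url value here is a placeholder you'd swap for the URL you copied in step 2.

from urllib.parse import urlparse, parse_qs

# Placeholder: paste the download URL you copied in step 2 here
example_url = "https://<account>.blob.core.windows.net/<container>/<path>?skoid=...&sktid=...&skt=...&ske=...&st=...&se=...&sig=..."

query = parse_qs(urlparse(example_url).query)

# These keys map onto the script's CLI flags:
# skoid -> --object-id, sktid -> --tenant-id, skt -> --token-start,
# ske -> --token-expiry, st -> --start-time, se -> --end-time, sig -> --signature
for key in ("skoid", "sktid", "skt", "ske", "st", "se", "sig"):
    print(key, "=", query.get(key, ["<missing>"])[0])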

azzzzurreerrr
