download-azure-ai-models.py
import asyncio
import aiohttp
import os
import logging
import time
from bs4 import BeautifulSoup
from typing import List, Dict
from dataclasses import dataclass
from urllib.parse import urlparse, parse_qs, quote

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


@dataclass
class SASConfig:
    """Configuration for an Azure Blob Storage SAS token"""
    container_url: str
    object_id: str
    tenant_id: str
    token_start: str
    token_expiry: str
    start_time: str
    end_time: str
    signature: str
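
# The SASConfig fields mirror Azure's short SAS query keys (see
# construct_download_url below): skoid -> object_id, sktid -> tenant_id,
# skt/ske -> token start/expiry, st/se -> validity window, sig -> signature.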


class AzureModelDownloader:
    def __init__(self, model_url: str, sas_config: SASConfig, output_dir: str = "downloads", max_parallel: int = 5):
        self.model_url = model_url
        self.output_dir = output_dir
        self.max_parallel = max_parallel
        self.files_list = []
        self.sas_config = sas_config
        self.semaphore = asyncio.Semaphore(max_parallel)

    @classmethod
    def from_example_url(cls, example_url: str, model_url: str, output_dir: str = "downloads", max_parallel: int = 5):
        """Create an instance by parsing an example download URL"""
        parsed = urlparse(example_url)
        query = parse_qs(parsed.query)
        # Extract the base container URL (two path levels up from the example file)
        container_url = f"{parsed.scheme}://{parsed.netloc}{os.path.dirname(os.path.dirname(parsed.path))}"
        sas_config = SASConfig(
            container_url=container_url,
            object_id=query['skoid'][0],
            tenant_id=query['sktid'][0],
            token_start=query['skt'][0],
            token_expiry=query['ske'][0],
            start_time=query['st'][0],
            end_time=query['se'][0],
            signature=query['sig'][0]
        )
        return cls(model_url, sas_config, output_dir, max_parallel)

    def parse_file_tree(self, html_content: str) -> List[Dict]:
        """Parse the file tree HTML to extract file paths"""
        soup = BeautifulSoup(html_content, 'html.parser')
        files = []
        nav_links = soup.find_all('div', class_='nav-link')
        for link in nav_links:
            automation_id = link.get('data-automation-id', '')
            if not automation_id:
                continue
            is_directory = bool(link.find('i', class_='folder-icon'))
            if is_directory:
                continue
            file_span = link.find('span', attrs={'data-automation-localized': 'false'})
            if file_span:
                file_name = file_span.text
                files.append({
                    'path': automation_id,
                    'name': file_name,
                    'full_path': os.path.join(*automation_id.split('/'))
                })
        return files

    def construct_download_url(self, file_path: str) -> str:
        """Construct the download URL for a file using the SAS config"""
        # Remove any leading slashes and ensure proper formatting
        clean_path = file_path.lstrip('/')
        # Construct the full URL with SAS parameters
        params = {
            'skoid': self.sas_config.object_id,
            'sktid': self.sas_config.tenant_id,
            'skt': self.sas_config.token_start,
            'ske': self.sas_config.token_expiry,
            'sks': 'b',
            'skv': '2021-10-04',
            'sv': '2021-10-04',
            'st': self.sas_config.start_time,
            'se': self.sas_config.end_time,
            'sr': 'c',
            'sp': 'rl',
            'sig': self.sas_config.signature
        }
        # Percent-encode the values: parse_qs decodes them on the way in, and the
        # signature in particular can contain '+', '/' and '=' characters that
        # must be re-encoded or the SAS check will fail
        query_string = '&'.join(f"{k}={quote(v, safe='')}" for k, v in params.items())
        return f"{self.sas_config.container_url}/{clean_path}?{query_string}"

    async def download_file(self, session: aiohttp.ClientSession, file_info: Dict):
        """Download a single file with retries"""
        file_path = os.path.join(self.output_dir, file_info['full_path'])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        max_retries = 3
        retry_delay = 1
        download_url = self.construct_download_url(file_info['full_path'])
        async with self.semaphore:
            for attempt in range(max_retries):
                try:
                    async with session.get(download_url) as response:
                        if response.status != 200:
                            raise aiohttp.ClientError(f"HTTP {response.status}")
                        total_size = int(response.headers.get('content-length', 0))
                        downloaded = 0
                        with open(file_path, 'wb') as f:
                            async for chunk in response.content.iter_chunked(1024 * 1024):
                                f.write(chunk)
                                downloaded += len(chunk)
                                if total_size > 0:
                                    progress = (downloaded / total_size) * 100
                                    print(f"\r{file_info['name']}: {progress:.1f}%", end='', flush=True)
                        print(f"\nCompleted: {file_info['name']}")
                        return
                except Exception as e:
                    if attempt == max_retries - 1:
                        logging.error(f"Failed to download {file_info['name']} after {max_retries} attempts: {str(e)}")
                    else:
                        await asyncio.sleep(retry_delay * (2 ** attempt))
                        logging.info(f"Retrying {file_info['name']} (attempt {attempt + 2}/{max_retries})")

    async def download_all_files(self):
        """Download all files in parallel"""
        async with aiohttp.ClientSession() as session:
            tasks = []
            for file_info in self.files_list:
                task = asyncio.create_task(self.download_file(session, file_info))
                tasks.append(task)
            await asyncio.gather(*tasks)

    def save_file_list(self):
        """Save the list of files to a text file"""
        list_path = os.path.join(self.output_dir, "files_to_download.txt")
        os.makedirs(self.output_dir, exist_ok=True)
        with open(list_path, 'w') as f:
            f.write("# Files to download from Azure AI Model Repository\n")
            f.write(f"# Model URL: {self.model_url}\n")
            f.write(f"# Container URL: {self.sas_config.container_url}\n")
            f.write(f"# Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("File path | Target location | Download URL\n")
            f.write("-" * 40 + "|" + "-" * 40 + "|" + "-" * 60 + "\n")
            for file_info in self.files_list:
                download_url = self.construct_download_url(file_info['full_path'])
                f.write(f"{file_info['path']} | {file_info['full_path']} | {download_url}\n")
        return list_path

    async def run(self, html_content: str):
        """Main execution flow"""
        print("Parsing file tree...")
        self.files_list = self.parse_file_tree(html_content)
        print(f"Found {len(self.files_list)} files")
        list_file = self.save_file_list()
        print(f"File list saved to {list_file}")
        print("\nStarting downloads...")
        await self.download_all_files()
        print("\nAll downloads completed!")


async def main():
    import argparse
    parser = argparse.ArgumentParser(description="Download Azure AI model files with directory structure")
    parser.add_argument("--html", required=True, help="Path to the HTML file containing the file tree")
    parser.add_argument("--url", required=True, help="URL of the model page")
    parser.add_argument("--output-dir", default="downloads", help="Output directory for downloaded files")
    parser.add_argument("--parallel", type=int, default=5, help="Maximum parallel downloads")
    # SAS token parameters
    parser.add_argument("--example-url", help="Example download URL to extract SAS parameters from")
    parser.add_argument("--container-url", default="https://amlwlrt4use01.blob.core.windows.net/azureml-0002c54c-3ae6-5726-aa2b-9823dd1236dc",
                        help="Azure Blob Storage container URL")
    parser.add_argument("--object-id", default="ae2sdd35-a062-42a-961d-aasdad1sd294",
                        help="Storage account object ID (skoid)")
    parser.add_argument("--tenant-id", default="3sss921-ss64-4f8c-a055-5bdasdasda3d",
                        help="Tenant ID (sktid)")
    parser.add_argument("--token-start", default="2024-12-13T00:12:36Z",
                        help="Token start time (skt)")
    parser.add_argument("--token-expiry", default="2024-12-13T16:22:36Z",
                        help="Token expiry time (ske)")
    parser.add_argument("--start-time", default="2024-12-13T03:16:01Z",
                        help="Start time (st)")
    parser.add_argument("--end-time", default="2024-12-13T11:26:01Z",
                        help="End time (se)")
    parser.add_argument("--signature", default="MJlFgasdasdasddwefftXqxNTasdasdd=",
                        help="SAS signature (sig)")
    args = parser.parse_args()

    with open(args.html, 'r', encoding='utf-8') as f:
        html_content = f.read()

    if args.example_url:
        downloader = AzureModelDownloader.from_example_url(
            args.example_url, args.url, args.output_dir, args.parallel
        )
    else:
        sas_config = SASConfig(
            container_url=args.container_url,
            object_id=args.object_id,
            tenant_id=args.tenant_id,
            token_start=args.token_start,
            token_expiry=args.token_expiry,
            start_time=args.start_time,
            end_time=args.end_time,
            signature=args.signature
        )
        downloader = AzureModelDownloader(args.url, sas_config, args.output_dir, args.parallel)

    await downloader.run(html_content)


if __name__ == "__main__":
    asyncio.run(main())
How to:

Save the model page's file tree as filetree.html (e.g. from your browser's developer tools, or with a Selenium-based scraper), then install this script's dependencies and run it:

pip install aiohttp beautifulsoup4

python download-azure-ai-models.py --html filetree.html --url 'https://ai.azure.com/explore/models/Phi-4/version/1/registry/azureml?tid=f07d1415-3879-4431-918c-0346c8f0111b#artifacts' --output-dir . --parallel 8
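
For reference, a minimal sketch of driving the downloader from Python rather than the CLI, assuming the class above is importable and that you have copied one real download URL (with its skoid/sktid/skt/ske/st/se/sig query parameters) from your browser's network tab. The URLs below are placeholders, not working values:

import asyncio

async def example():
    # Placeholder URLs: substitute a real example download URL and model page URL.
    downloader = AzureModelDownloader.from_example_url(
        example_url="https://<account>.blob.core.windows.net/<container>/<dir>/<file>?skoid=...&sig=...",
        model_url="https://ai.azure.com/explore/models/Phi-4/version/1/registry/azureml",
        output_dir="downloads",
        max_parallel=8,
    )
    # Feed it the saved file-tree HTML, exactly as the CLI path does.
    with open("filetree.html", encoding="utf-8") as f:
        await downloader.run(f.read())

asyncio.run(example())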