Researching various methods of downloading URLs: sequential, threaded, async, and multiprocessing combined with async.
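"""Research various methods of downloading URLs.

Approaches included: sequential (requests), threaded (ThreadPoolExecutor +
requests), async (asyncio + aiohttp + aiofiles), and multiprocessing combined
with async (ProcessPoolExecutor, one asyncio event loop per worker).

Third-party dependencies: aiofiles, aiohttp, certifi, requests.

Usage: python <this script> URL [URL ...]
"""
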
import asyncio
import concurrent.futures
import http
import itertools
import logging
import math
import os
import pathlib
import platform
import ssl
import sys
import time
import urllib.parse
import urllib.request

import aiofiles
import aiohttp
import certifi
import requests
DEFAULT_REQUEST_TIMEOUT = 60

# Per ProcessPoolExecutor documentation.
MAX_WINDOWS_WORKERS = 61

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def error(message) -> None:
    logger.error(message)
    sys.exit(1)


def parse_args():
    urls = sys.argv[1:]

    if len(urls) < 1:
        error(f"Usage: {sys.argv[0]} URL [URL ...]")

    return urls


def build_filename(url):
    # Use the last component of the URL path (unquoted) as the local filename.
    url_path = urllib.parse.urlparse(url).path
    unquoted_url_path = urllib.parse.unquote(url_path)
    return pathlib.PurePosixPath(unquoted_url_path).name
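

# Async approach: download and save each URL concurrently in one process,
# using aiohttp for HTTP requests and aiofiles for non-blocking file writes.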
async def save_one_a(url, response) -> None:
    filename = build_filename(url)
    logger.info("Saving data to file - %s", filename)

    async with aiofiles.open(filename, mode="wb") as file:
        async for chunk in response.content.iter_chunked(1024):
            await file.write(chunk)

    logger.info("Done saving data to file - %s", filename)


async def download_one_a(session, sslcontext, timeout, url):
    logger.info("Downloading URL - %s", url)

    async with session.get(url, ssl=sslcontext, timeout=timeout) as response:
        if response.status == http.HTTPStatus.OK:
            await save_one_a(url, response)
        else:
            error(f"URL download failed (HTTP status {response.status}): {url}")

    logger.info("Done downloading URL - %s", url)


async def download_all_a(urls, timeout):
    # Used to prevent SSL certificate verification failures.
    sslcontext = ssl.create_default_context(cafile=certifi.where())

    async with aiohttp.ClientSession() as session:
        tasks = []

        for url in urls:
            coroutine = download_one_a(
                session=session, sslcontext=sslcontext, timeout=timeout, url=url
            )
            task = asyncio.create_task(coroutine)
            tasks.append(task)

        await asyncio.gather(*tasks)
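

# Threading approach: download URLs in parallel with a thread pool; responses
# are written to disk on the main thread as their downloads complete.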
def save_one_t(url, response) -> None:
    filename = build_filename(url)
    logger.info("Saving data to file - %s", filename)

    with pathlib.Path(filename).open(mode="wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            file.write(chunk)

    logger.info("Done saving data to file - %s", filename)


def download_one_t(url, timeout):
    logger.info("Downloading URL - %s", url)
    response = requests.get(url, timeout=timeout)
    logger.info("Done downloading URL - %s", url)
    return response


def download_all_t(urls, timeout):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_url = {
            executor.submit(download_one_t, url, timeout): url for url in urls
        }

        # Save each response as its download completes.
        for future in concurrent.futures.as_completed(future_to_url):
            url = future_to_url[future]
            response = future.result()

            if response.status_code == http.HTTPStatus.OK:
                save_one_t(url, response)
            else:
                error(f"HTTP error ({response.status_code}) - {url}")
def start_async_run(urls, timeout):
    asyncio.run(download_all_a(urls, timeout))


def calc_num_workers():
    # os.cpu_count() may return None if the count can't be determined.
    cpu_count = os.cpu_count() or 1
    system = platform.system()

    if system == "Windows" and cpu_count > MAX_WINDOWS_WORKERS:
        return MAX_WINDOWS_WORKERS

    return cpu_count


def calc_num_urls_per_worker(urls, num_workers):
    return int(math.ceil(len(urls) / num_workers))


def download_all_ma(urls, timeout):
    num_workers = calc_num_workers()
    logger.info("Number of workers set to %s", num_workers)

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        num_urls_per_worker = calc_num_urls_per_worker(urls, num_workers)
        logger.info("URLs per worker set to %s", num_urls_per_worker)

        # Each worker process runs its own event loop over a batch of URLs.
        # Note: itertools.batched requires Python 3.12+.
        for url_batch in itertools.batched(urls, num_urls_per_worker):
            future = executor.submit(start_async_run, url_batch, timeout)
            futures.append(future)

        concurrent.futures.wait(futures)
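

# Linear approach: download and save URLs one at a time as a baseline.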
def download_all(urls, timeout):
    for url in urls:
        logger.info("Downloading URL - %s", url)
        response = requests.get(url, stream=True, timeout=timeout)

        if response.status_code == http.HTTPStatus.OK:
            filename = build_filename(url)
            logger.info("Saving data to file - %s", filename)

            with pathlib.Path(filename).open(mode="wb") as file:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)

            logger.info("Done saving data to file - %s", filename)
        else:
            error(f"HTTP error ({response.status_code}) - {url}")

        logger.info("Done downloading URL - %s", url)


def ns_to_s(t):
    return t / 1_000_000_000
def main():
    urls = parse_args()
    timeout = DEFAULT_REQUEST_TIMEOUT

    logger.info("Starting downloads")
    start_time = time.perf_counter_ns()

    # Async
    # start_async_run(urls, timeout)

    # Threading
    # download_all_t(urls, timeout)

    # Multiprocessing w/ Async
    download_all_ma(urls, timeout)

    # Linear
    # download_all(urls, timeout)

    stop_time = time.perf_counter_ns()
    run_time = ns_to_s(stop_time - start_time)
    logger.info("Downloads completed in %s seconds", run_time)

    return 0
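

if __name__ == "__main__":
    # Entry point (assumed): the guard is also what lets ProcessPoolExecutor
    # workers re-import this module safely when the "spawn" start method is used.
    sys.exit(main())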