Skip to content

Instantly share code, notes, and snippets.

@clee704
Last active August 29, 2015 13:58
Show Gist options
  • Save clee704/9996914 to your computer and use it in GitHub Desktop.
Save clee704/9996914 to your computer and use it in GitHub Desktop.
Image Downloader
#! /usr/bin/env python
# Copyright 2014 Choongmin Lee
# Licensed under the MIT License
from __future__ import print_function
import argparse
from datetime import datetime
import os
import re
import sys
from tempfile import NamedTemporaryFile
import threading
import time
import lxml.html
import pytz
import requests
from tzlocal import get_localzone
# Python 2/3 compatibility: names for "any text type" and the line-reading
# prompt function, so the rest of the module can use one spelling.
if sys.version_info >= (3,):
    string_types = (str,)
    raw_input = input
else:
    string_types = (str, unicode)  # noqa: F821 -- only evaluated on Python 2
# Command-line interface.  main() reads opts.url, opts.directory, opts.prefix,
# opts.extensions, opts.min_size, opts.overwrite, opts.do_not_overwrite and
# opts.quiet from the namespace this parser produces.
parser = argparse.ArgumentParser(
    description='Download images from HTML documents.')
parser.add_argument('url', metavar='URL', nargs='+',
                    help='URL to an HTML document')
parser.add_argument('-d', '--directory', metavar='DIR', default='.',
                    help='directory where downloaded files will be saved')
parser.add_argument('-p', '--prefix', default='',
                    help='add prefix when saving downloaded files')
parser.add_argument('-e', '--extensions',
                    help='comma-separated list of file extensions to download')
parser.add_argument('--min-size', metavar='SIZE', type=int, default=0,
                    help='do not download files smaller than SIZE')
# -y and -n contradict each other; reject the combination up front instead
# of letting --do-not-overwrite silently win.
_overwrite_group = parser.add_mutually_exclusive_group()
_overwrite_group.add_argument('-y', '--overwrite', action='store_true',
                              help='overwrite existing files')
_overwrite_group.add_argument('-n', '--do-not-overwrite', action='store_true',
                              help='do not overwrite existing files')
parser.add_argument('-q', '--quiet', action='store_true',
                    help='do not print messages; use --overwrite or '
                         '--do-not-overwrite to also disable prompt')
def main():
    """Run the command-line interface.

    Parses ``sys.argv`` with the module-level ``parser``, creates the output
    directory if it does not exist, and downloads the images linked from
    each URL given on the command line.

    :returns: 0 if every page could be fetched, 1 if any page request failed
        (suitable for passing to ``sys.exit``).
    """
    opts = parser.parse_args()
    if not os.path.exists(opts.directory):
        os.makedirs(opts.directory)
    if opts.extensions is not None:
        # Normalize ' .JPG, png,' to ['jpg', 'png'] so the values agree with
        # the lower-cased URLs they are matched against.  Blank entries must
        # be dropped: an empty extension would match every link.
        cleaned = (ext.strip().lstrip('.').lower()
                   for ext in opts.extensions.split(','))
        opts.extensions = [ext for ext in cleaned if ext]
        if not opts.extensions:
            parser.error('--extensions must name at least one extension')
    # --do-not-overwrite takes precedence; without either flag, ask the user.
    overwrite = ('no' if opts.do_not_overwrite
                 else 'yes' if opts.overwrite
                 else 'ask')
    downloader = Downloader(download_dir=opts.directory,
                            prefix=opts.prefix,
                            extensions=opts.extensions,
                            min_size=opts.min_size,
                            overwrite=overwrite,
                            verbose=not opts.quiet)
    failures = 0
    for url in opts.url:
        try:
            downloader.download_images(url)
        except requests.RequestException as e:
            # One unreachable page should not prevent the remaining URLs
            # from being processed; report it and remember the failure.
            print('Failed to download images from {}: {}'.format(url, e),
                  file=sys.stderr)
            failures += 1
    return 1 if failures else 0
class Downloader(object):
    """Download image files linked from HTML documents.

    An element is treated as an image link when the ``href`` of an ``<a>``
    or the ``src`` of an ``<img>`` ends with ``.<ext>`` for one of
    ``extensions`` (compared case-insensitively).
    """

    DEFAULT_EXTENSIONS = ['jpg', 'jpeg', 'gif', 'png', 'tiff', 'bmp']

    def __init__(self, download_dir='.', prefix='', extensions=None,
                 min_size=0, overwrite='ask', verbose=True,
                 download_chunk_size=4096, progress_bar_interval=0.25,
                 progress_bar_width=50):
        """Configure the downloader.

        :param download_dir: directory where files are saved.
        :param prefix: string prepended to every saved file name.
        :param extensions: iterable of file extensions to download.  Leading
            dots, surrounding whitespace and letter case are ignored, and
            empty entries are dropped.  ``None`` means ``DEFAULT_EXTENSIONS``.
        :param min_size: files smaller than this many bytes are skipped.
        :param overwrite: ``'yes'`` to overwrite existing files, ``'no'`` to
            skip them, anything else (default ``'ask'``) to prompt.
        :param verbose: print progress messages when true.
        :param download_chunk_size: bytes requested per ``iter_content`` chunk.
        :param progress_bar_interval: seconds between progress bar redraws.
        :param progress_bar_width: width of the progress bar in characters.
        """
        if extensions is None:
            extensions = self.DEFAULT_EXTENSIONS
        self.download_dir = download_dir
        self.prefix = prefix
        self.overwrite = overwrite
        # URLs are lower-cased before matching, so the extensions must be
        # too; an empty entry would otherwise match every URL.
        normalized = (ext.strip().lstrip('.').lower() for ext in extensions)
        self.extensions = [ext for ext in normalized if ext]
        self.min_size = min_size
        self.verbose = verbose
        self.download_chunk_size = download_chunk_size
        self.progress_bar_interval = progress_bar_interval
        self.progress_bar_width = progress_bar_width

    def log(self, msg):
        """Print *msg* unless the downloader is quiet."""
        if self.verbose:
            print(msg)

    def download_images(self, url):
        """Download every image linked from the HTML document at *url*.

        Each image URL is requested at most once per call, even if the page
        links it several times.  A failed image request is reported on
        stderr and does not abort the remaining downloads.
        """
        self.log('Downloading images from {}'.format(url))
        resp = requests.get(url)
        doc = lxml.html.fromstring(resp.content)
        doc.make_links_absolute(url)
        seen_urls = set()
        saved_paths = set()
        for element in doc.xpath('//img | //a'):
            attr = 'href' if element.tag == 'a' else 'src'
            image_url = element.attrib.get(attr)
            if not self.is_image_link(image_url) or image_url in seen_urls:
                continue
            # Record the URL before downloading so duplicates of a skipped,
            # declined or failed image are not requested (or prompted) again.
            seen_urls.add(image_url)
            path = self.make_path(image_url, saved_paths)
            if path is None:
                continue
            try:
                saved = self.download_image(image_url, save_as=path,
                                            referer=url)
            except requests.RequestException as e:
                print('Failed to download {}: {}'.format(image_url, e),
                      file=sys.stderr)
                continue
            if saved:
                saved_paths.add(path)
        self.log('Downloaded {} files'.format(len(saved_paths)))

    def is_image_link(self, url):
        """Return True if *url* is a string ending in ``.<ext>`` for a known
        extension (case-insensitive)."""
        if not isinstance(url, string_types):
            return False
        lowered = url.lower()
        return any(lowered.endswith('.' + ext) for ext in self.extensions)

    def make_path(self, url, old_paths):
        """Choose the local path for *url*, or return None to skip it.

        The file name is ``prefix`` + the last path component of *url*.  If
        that path was already used in this run (*old_paths*), a numbered
        variant is chosen instead.  If the path exists on disk,
        ``self.overwrite`` decides: ``'yes'`` overwrites, ``'no'`` skips,
        anything else prompts the user.
        """
        path = os.path.join(self.download_dir,
                            self.prefix + os.path.basename(url))
        if path in old_paths:
            path = avoid_existing_path(path, self.extensions)
        if os.path.exists(path) and self.overwrite != 'yes':
            if self.overwrite == 'no':
                self.log('Skipping {} since the file {} exists'.format(url,
                                                                       path))
                return None
            msg = 'File {} exists; overwrite? (y/N) '.format(path)
            if raw_input(msg).strip().lower() not in ('y', 'yes'):
                return None
        return path

    def download_image(self, url, save_as, referer):
        """Download *url* to *save_as*, sending *referer* as the Referer.

        :returns: True if the file was saved, False if it was skipped (error
            status or smaller than ``min_size``).
        :raises requests.RequestException: if the request itself fails.
        """
        resp = requests.get(url, headers={'Referer': referer}, stream=True)
        try:
            return self._save_response(resp, url, save_as)
        finally:
            # With stream=True the connection stays checked out until the
            # body is consumed or the response is closed; release it on
            # every path, including the early "skip" returns.
            resp.close()

    def _save_response(self, resp, url, save_as):
        """Write the body of *resp* to *save_as*; return True if saved."""
        if not resp.ok:
            self.log('Skipping {} since the server returned HTTP {}'.format(
                url, resp.status_code))
            return False
        total_length = self._content_length(resp)
        if total_length is not None and total_length < self.min_size:
            self._log_too_small(url)
            return False
        self.log('Downloading {}'.format(url))
        tmp_path, size = self._stream_to_temp_file(
            resp, total_length, os.path.dirname(save_as))
        # Without a Content-Length the size is only known after downloading.
        if size < self.min_size:
            os.unlink(tmp_path)
            self._log_too_small(url)
            return False
        try:
            self._move_into_place(tmp_path, save_as)
        except OSError:
            os.unlink(tmp_path)
            raise
        mtime = self._http_date_to_timestamp(
            resp.headers.get('last-modified'))
        if mtime is not None:
            os.utime(save_as, (mtime, mtime))
        self.log('Saved as {}'.format(save_as))
        return True

    def _log_too_small(self, url):
        """Report that *url* was skipped for being below ``min_size``."""
        self.log('Skipping {} since it is smaller than {}'.format(
            url, format_filesize(self.min_size)))

    def _stream_to_temp_file(self, resp, total_length, directory):
        """Stream the body of *resp* into a new temporary file.

        :returns: ``(temp_path, bytes_written)``.  The temporary file is
            removed again if writing fails.
        """
        progress_data = {'current': 0, 'total': total_length}
        progress_bar = None
        # A percentage needs a known, non-zero total (0 would divide by zero).
        if total_length and self.verbose and sys.stdout.isatty():
            progress_bar = ProgressBar(progress_data,
                                       self.progress_bar_interval,
                                       self.progress_bar_width)
            progress_bar.start()
        # Create the temp file next to the destination so the final rename
        # stays on one filesystem (a rename cannot cross devices).
        tmp = NamedTemporaryFile(dir=directory or '.', suffix='.part',
                                 delete=False)
        try:
            with tmp:
                for chunk in resp.iter_content(self.download_chunk_size):
                    progress_data['current'] += len(chunk)
                    tmp.write(chunk)
        except BaseException:
            # Clean up the partial file, then let the error propagate.
            os.unlink(tmp.name)
            raise
        finally:
            if progress_bar is not None:
                progress_bar.stop()
        return tmp.name, progress_data['current']

    @staticmethod
    def _content_length(resp):
        """Return the Content-Length of *resp* as an int, or None if absent
        or malformed."""
        try:
            return int(resp.headers.get('content-length'))
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _http_date_to_timestamp(value):
        """Convert an HTTP date header such as ``Last-Modified`` to a POSIX
        timestamp, or return None if *value* is empty or unparsable.

        ``email.utils`` parses the RFC 2822 form independently of the
        process locale, unlike ``strptime`` with ``%a``/``%b``.
        """
        from email.utils import mktime_tz, parsedate_tz
        if not value:
            return None
        try:
            parsed = parsedate_tz(value)
            if parsed is None:
                return None
            return mktime_tz(parsed)
        except (TypeError, ValueError, IndexError, OverflowError):
            # Setting the mtime is best effort; ignore malformed dates.
            return None

    @staticmethod
    def _move_into_place(src, dst):
        """Rename *src* to *dst*, replacing *dst* if it already exists."""
        replace = getattr(os, 'replace', None)
        if replace is not None:
            replace(src, dst)
        else:
            # Python 2 has no os.replace, and os.rename cannot overwrite an
            # existing file on Windows.
            if os.path.exists(dst):
                os.remove(dst)
            os.rename(src, dst)
class ProgressBar(threading.Thread):
    """Daemon thread that periodically redraws a one-line progress bar.

    ``data`` is a dict with ``'current'`` and ``'total'`` byte counts; the
    owner updates ``'current'`` while this thread only reads it.  Call
    ``start()`` to begin drawing and ``stop()`` to draw the final state.
    """

    def __init__(self, data, interval, width):
        """
        :param data: dict with ``'current'`` and ``'total'`` keys.
        :param interval: seconds between redraws.
        :param width: number of characters between the brackets.
        """
        super(ProgressBar, self).__init__()
        self.daemon = True
        self.data = data
        self.interval = interval
        self.width = width
        self.throbber = 0
        self.throbber_chars = ['|', '/', '-', '\\']
        self.running = False
        # Set by stop() so run() wakes up at once instead of sleeping out
        # the rest of its interval.
        self._halt_event = threading.Event()

    def run(self):
        """Redraw every ``interval`` seconds until complete or stopped."""
        self.running = True
        while self.running and not self._halt_event.is_set():
            self.print_progress()
            if self.data['current'] >= self.data['total']:
                self.running = False
            else:
                self._halt_event.wait(self.interval)

    def print_progress(self, end=False):
        """Draw the bar once, overwriting the current terminal line.

        :param end: also print a newline, leaving the bar on screen.
        """
        current = self.data['current']
        total = self.data['total']
        percent = self._scaled(current, total, 100)
        progress = self._scaled(current, total, self.width)
        progress_line = '[{}{}] {}% {} {} / {}'.format(
            '=' * progress,
            ' ' * (self.width - progress),
            percent,
            self.throbber_chars[self.throbber],
            format_filesize(current),
            format_filesize(total))
        sys.stdout.write('\r{:79}'.format(progress_line))
        sys.stdout.flush()
        self.throbber = (self.throbber + 1) % len(self.throbber_chars)
        if end:
            print()

    @staticmethod
    def _scaled(current, total, scale):
        """Return ``current / total`` scaled to ``0..scale``, floored and
        clamped.

        Integer arithmetic keeps the result exact.  A zero or missing total
        counts as complete instead of dividing by zero.
        """
        if not total:
            return scale
        return max(0, min(scale, scale * current // total))

    def stop(self):
        """Stop redrawing, wait for the thread, then draw the final line."""
        self.running = False
        self._halt_event.set()
        # Join before the final draw so the thread cannot redraw after the
        # terminating newline.  ident is None if start() was never called.
        if self.ident is not None:
            self.join()
        self.print_progress(True)
def format_filesize(size):
    """Render a byte count for humans.

    Sizes below 1024 are shown as whole bytes (``'512 B'``).  Larger sizes
    are divided by 1024 until they drop below 1024 or reach TB, and are shown
    with one decimal (``'1.5 KB'``, ``'3072.0 TB'``).
    """
    if size < 1024:
        return '{} B'.format(size)
    units = ('KB', 'MB', 'GB', 'TB')
    scaled = size / 1024.0
    index = 0
    # TB is the last unit: never divide past it.
    while index < len(units) - 1 and not scaled < 1024.0:
        scaled /= 1024.0
        index += 1
    return '{:3.1f} {}'.format(scaled, units[index])
def avoid_existing_path(path, extensions=None):
    """Return *path* if it does not exist, else a free numbered variant.

    ``dir/name.ext`` becomes ``dir/name (2).ext``, then ``dir/name (3).ext``,
    and so on.  A path that already ends in `` (N)`` continues counting from
    N rather than appending a second suffix.  The extension is split off
    with ``os.path.splitext``, so any extension in any letter case works.

    :param path: candidate file path.
    :param extensions: accepted for backward compatibility; no longer used.
    """
    if not os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    match = re.match(r'(.*) \((\d+)\)\Z', root, re.DOTALL)
    if match:
        root, number = match.group(1), int(match.group(2))
    else:
        number = 1
    while True:
        number += 1
        candidate = '{} ({}){}'.format(root, number, ext)
        if not os.path.exists(candidate):
            return candidate
if __name__ == '__main__':
    # Turn Ctrl-C into a clean newline and exit status 2 instead of a
    # traceback; otherwise exit with whatever main() returns.
    try:
        status = main()
    except KeyboardInterrupt:
        print()
        status = 2
    sys.exit(status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment