Created
October 12, 2011 14:09
-
-
Save lqc/1281315 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import time | |
import queue | |
import locale | |
import os.path | |
import urllib.parse | |
import netutils | |
import textprogressbar | |
class download_manager (object):
    """Sequential command-line download queue.

    URLs are validated, queued and then fetched one at a time with
    ``netutils.http_get_file``.  The manager registers itself as an observer
    of each task and draws a text progress bar while the task reports
    progress; when a task stops (finished, failed, paused or cancelled) the
    next queued URL is started.
    """

    # Use the user's locale so that byte counts are grouped with the
    # locale's thousands separator in the progress message.
    locale.setlocale(locale.LC_NUMERIC, "")

    def __init__(self, ):
        self.queue = queue.Queue()                      # verified URLs waiting to be fetched
        self.pbar = textprogressbar.textprogressbar()   # shared progress display

    def add_URL(self, url_list, ):
        """Queue one URL (``str``) or a collection of URLs, then start downloading.

        Inputs that are neither a string nor a list/tuple/set/frozenset are
        ignored.
        """
        if isinstance(url_list, (list, tuple, set, frozenset)):
            for item in url_list:
                self._append(item)
        elif isinstance(url_list, str):
            self._append(url_list)
        self.start_next_download()

    def _append(self, URL):
        """Queue *URL* if it passes ``verify_URL``; silently drop it otherwise."""
        verified_URL = self.verify_URL(URL)
        if not verified_URL:
            return
        self.queue.put(verified_URL)

    def verify_URL(self, URL):
        """Return a normalised ``http`` URL, or ``None`` if *URL* is unusable.

        Blank strings and non-``http`` schemes are rejected.  A URL without
        any scheme is retried with ``http://`` prepended.
        """
        if not URL.strip():
            return None
        parsed = urllib.parse.urlparse(URL)
        if not parsed.scheme:
            return self.verify_URL("http://" + URL)
        if parsed.scheme != "http":
            return None
        return urllib.parse.urlunparse(parsed)

    @staticmethod
    def save_file_name(url):
        """Return a local file name for *url* that does not exist yet.

        The name is the last path component of the URL (``index.htm`` when
        the path ends with ``/`` or is empty).  If a file with that name
        already exists in the current directory, ``.1``, ``.2``, ... is
        appended until an unused name is found.
        """
        url = urllib.parse.urlparse(url)
        basename = url.path.rsplit('/', 1)[-1]
        if not basename:
            basename = "index.htm"
        i = 1
        exp = ""
        # if already exists, don't overwrite
        while os.path.isfile(basename + exp):
            exp = "." + str(i)
            i += 1
        return basename + exp

    def start_next_download(self):
        """Take the next URL off the queue and start fetching it (no-op if empty)."""
        if self.queue.empty():
            return
        url = self.queue.get()
        filename = self.save_file_name(url)
        task = netutils.http_get_file(url, filename)
        task.add_observer(self)
        print("Downloading %s..." % url[:65])
        task.resume()

    def update(self, task, *args, **kw):
        """Observer callback: redraw progress and advance the queue when *task* stops."""
        self.update_download_progress(task)
        status = task.get_status()
        if status == netutils.COMPLETE:
            self.pbar.clear()
            print("Succeeded.")
        elif status == netutils.ERROR:
            self.pbar.clear()
            print("Failed", task.get_url())
        elif status in (netutils.DOWNLOADING, netutils.CONNECTING):
            # Still in progress: nothing more to do until the next notification.
            return
        self.start_next_download()

    def update_download_progress(self, task):
        """Refresh the progress bar with *task*'s byte count and average speed."""
        self.pbar.set_status(task.received, task.size)
        elapsed = time.time() - task.timestamp
        # A coarse clock can report zero elapsed time on the first chunk;
        # -1 is the "rate unknown" value accepted by retr_rate_to_human.
        speed = task.received / elapsed if elapsed > 0 else -1
        speed = netutils.retr_rate_to_human(speed)
        # locale.format() was removed in Python 3.12; format_string() is the
        # documented replacement and accepts the same arguments here.
        received = locale.format_string("%d", int(task.received), grouping=True)
        self.pbar.set_message("%s %s" % (received, speed))
        self.pbar.update()
def main(argv):
    """Queue every URL in *argv* on a fresh download manager and start fetching."""
    download_manager().add_URL(argv)
if __name__ == '__main__':
    import sys
    # Everything after the program name is treated as a URL to download.
    cli_urls = sys.argv[1:]
    if len(cli_urls) == 0:
        # Nothing to do: exit with status 2.
        sys.exit(2)
    main(cli_urls)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import time | |
import traceback | |
import threading | |
import http.client | |
from socket import timeout, gaierror | |
from tempfile import mkstemp | |
from urllib.parse import urlparse | |
def retr_eta_to_human(secs):
    """Format an estimated time of arrival given in seconds.

    ``-1`` means "unknown" and yields ``"aeternum"``.  Otherwise the two most
    significant units are shown: seconds below 100 s, minutes and seconds
    below 100 min, hours and minutes below 48 h, days and hours below
    100 days, and whole days beyond that.
    """
    if secs == -1:
        return "aeternum"
    if secs < 100:
        return "%d s" % secs
    # From here on secs >= 100, so truncating and floor division agree.
    total = int(secs)
    minutes, seconds = divmod(total, 60)
    hours, minutes_left = divmod(minutes, 60)
    days, hours_left = divmod(hours, 24)   # hours wrap at 24, not 60
    if secs < 100 * 60:
        return "%dm %ds" % (minutes, seconds)
    if secs < 48 * 3600:
        return "%dh %dm" % (hours, minutes_left)
    if secs < 100 * 86400:
        return "%dd %dh" % (days, hours_left)
    return "%dd" % days
def retr_size_to_human(bytes):
    """Render a byte count with a binary unit suffix (B, KB, MB, GB, TB).

    ``-1`` means "unknown" and yields ``'--,-B'``.  The scaled value is shown
    with two decimals below 9.995, one decimal below 99.95 and none above.
    """
    if bytes == -1:
        return '--,-B'
    suffixes = ('B', 'KB', 'MB', 'GB', 'TB')
    # Climb to the largest unit whose 1024**k threshold the value has reached.
    exponent = 0
    while exponent < 4 and not bytes < 1024 ** (exponent + 1):
        exponent += 1
    amount = bytes / 1024 ** exponent if exponent else bytes
    if amount >= 99.95:
        decimals = 0
    elif amount >= 9.995:
        decimals = 1
    else:
        decimals = 2
    return '{0:.{1}f} {2}'.format(amount, decimals, suffixes[exponent])
def retr_rate_to_human(dlrate):
    """Render a transfer rate in bytes per second with a unit suffix.

    ``-1`` means "unknown" and yields ``'--,-KB/s'``.  Units go from B/s up
    to GB/s in steps of 1024; precision shrinks as the scaled value grows
    (two decimals below 9.995, one below 99.95, none above).
    """
    if dlrate == -1:
        return '--,-KB/s'
    labels = ("B/s", "kB/s", "MB/s", "GB/s")
    # Find the first unit whose upper bound the rate stays under; anything
    # at or above 1024**3 falls through to the last unit.
    for power in range(3):
        if dlrate < 1024.0 ** (power + 1):
            break
    else:
        power = 3
    scaled = dlrate / 1024.0 ** power if power else dlrate
    decimals = 0 if scaled >= 99.95 else (1 if scaled >= 9.995 else 2)
    return '{0:.{1}f} {2}'.format(scaled, decimals, labels[power])
def retr_progress_to_human(progress):
    """Render a percentage; ``-1`` means "unknown" and yields ``'--,-%'``.

    Two decimals are shown below 9.995, one below 99.95 and none above.
    """
    if progress == -1:
        return '--,-%'
    if progress >= 99.95:
        decimals = 0
    elif progress >= 9.995:
        decimals = 1
    else:
        decimals = 2
    return '{0:.{1}f}%'.format(progress, decimals)
def c_strpbrk(str, char_list):
    """Return the tail of *str* starting at the first character found in
    *char_list*, or ``None`` when no such character occurs (like C strpbrk).
    """
    for offset, symbol in enumerate(str):
        if symbol in char_list:
            return str[offset:]
    return None
def c_strspn(str, char_list):
    """Return the length of the leading run of *str* made up only of
    characters from *char_list* (like C strspn).
    """
    span = 0
    for symbol in str:
        if symbol not in char_list:
            break
        span += 1
    return span
def rewrite_shorthand_url(url):
    '''Expand "shorthand" URL forms originally popularized by Netscape and
    NcFTP into full URLs.

    HTTP shorthands look like this:
        www.foo.com[:port]/dir/file -> http://www.foo.com[:port]/dir/file
        www.foo.com[:port] -> http://www.foo.com[:port]
    FTP shorthands look like this:
        foo.bar.com:dir/file -> ftp://foo.bar.com/dir/file
        foo.bar.com:/absdir/file -> ftp://foo.bar.com//absdir/file

    If the URL need not or cannot be rewritten, the original URL is returned.
    '''
    # Look for the first ':' or '/'.  The former signifies NcFTP syntax, the
    # latter Netscape.
    positions = [i for i in (url.find(':'), url.find('/')) if i != -1]
    if not positions:
        # Bare host name: just prepend "http://".
        return "http://" + url
    start = min(positions)
    if start == 0:
        # Starts with ':' or '/': nothing sensible to rewrite.
        return url
    tail = url[start:]
    # "://" means the URL already carries a scheme; don't bogusly rewrite it.
    if tail.startswith('://'):
        return url
    if tail[0] == ':':
        # Colon indicates ftp, as in foo.bar.com:path.  Check for the
        # special case of an http port number ("localhost:10000").
        after_colon = tail[1:]
        rest = after_colon.lstrip("0123456789")
        has_port = len(rest) < len(after_colon)
        # Test for emptiness before indexing so "host:port" with nothing
        # after the port does not raise IndexError.
        if has_port and (not rest or rest[0] == '/'):
            return "http://" + url
        # Turn "foo.bar.com:path" into "ftp://foo.bar.com/path".
        return "ftp://" + url[:start] + '/' + after_colon
    # A '/' came first: just prepend "http://".
    return "http://" + url
class Observable():
    """Minimal observer-pattern mixin.

    Observers are stored lazily in ``self._observers`` as
    ``(observer, args, kw)`` triples.  On notification each observer's
    ``update(observable, *args, **kw)`` is called with the notification's
    arguments followed by the extra arguments that observer registered with.
    """

    def add_observer(self, observer, *args, **kw):
        """Register *observer*; *args*/*kw* are appended to every notification it gets."""
        try:
            observers = self._observers
        except AttributeError:  # first time use
            observers = self._observers = []
        observers.append((observer, args, kw))

    def del_observer(self, observer):
        """Remove every registration of *observer* (no-op if it was never added)."""
        observers = getattr(self, '_observers', None)
        if observers:
            # Rebuild in place instead of deleting while iterating, which
            # skipped the entry following each deleted one.
            observers[:] = [entry for entry in observers
                            if not entry[0] == observer]

    def notify_observers(self, *args, **kw):
        """Call ``update(self, ...)`` on every registered observer.

        Exceptions raised by an observer propagate to the caller; a missing
        observer list simply means there is nobody to notify.
        """
        # Iterate over a snapshot so observers may register or unregister
        # while being notified.
        for observer, extra_args, extra_kw in tuple(getattr(self, '_observers', ())):
            # Build per-observer arguments; one observer's extras must not
            # leak into the call made to the next observer.
            call_kw = dict(kw)
            call_kw.update(extra_kw)
            observer.update(self, *(args + extra_args), **call_kw)
# HTTP/1.0 success (2xx) status codes from RFC 1945, provided for reference.
HTTP_STATUS_OK = 200
HTTP_STATUS_CREATED = 201
HTTP_STATUS_ACCEPTED = 202
HTTP_STATUS_NO_CONTENT = 204
HTTP_STATUS_PARTIAL_CONTENTS = 206

# Number of bytes requested from the server per read.
BUFFER_SIZE = 1024

# Maximum number of retries after a non-fatal error.
MAX_REPEATS = 20

# Maximum number of redirections followed for one download.
MAX_REDIRECTIONS = 10

# Download status codes; each value is also the index of its name in STATUSES.
CONNECTING, DOWNLOADING, COMPLETE, PAUSED, CANCELLED, ERROR = range(6)

# Human-readable status names, indexed by status code.
STATUSES = (
    "Connecting",
    "Downloading",
    "Complete",
    "Paused",
    "Cancelled",
    "Error",
)

# Resume incomplete downloads with a Range request.
RESUME = True

# http.client debug level: 0 disables debug output, > 0 enables it.
DEBUGLEVEL = 0

# Seconds to sleep before retrying after a non-fatal error.
RETRY_DELAY_SAMPLE = 0.5
class MAXrdr(Exception):
    """Raised when a download exceeds MAX_REDIRECTIONS redirections."""


class MISilo(Exception):
    """Raised when a redirect response carries neither a Location nor a URI header."""


class INread(Exception):
    """Raised when fewer bytes were received than Content-Length announced."""


class ICStat(Exception):
    """Raised when the server answers with a status outside the 2xx/3xx ranges."""
class http_get_file(Observable):
    """Observable download of one HTTP URL into a file.

    The transfer runs in a background thread started by ``resume()``.
    Observers are notified (via ``stats_changed``) whenever the status
    changes and after every chunk received.  A paused or interrupted
    transfer is continued with an HTTP ``Range`` request when ``RESUME``
    is enabled.
    """

    def __init__ (self, url, file=None, limitrate=None, headers=None):
        """
        url       -- absolute http URL to fetch
        file      -- target path (str) or writable, seekable file-like
                     object; a temporary file is created when omitted
        limitrate -- maximum rate in bytes per second, falsy for unlimited
        headers   -- optional mapping (or iterable of pairs) of request headers
        """
        self.url = urlparse(url)
        self.status = CONNECTING
        self.limitrate = limitrate     # bytes per second, falsy = unlimited
        self.file = file               # file to write data
        self.errstr = None             # error (if any) description
        # Copy the headers: a shared mutable default (or the caller's dict)
        # must not be modified when the Range header is added later.
        self.req_hdrs = dict(headers) if headers else {}   # request headers
        self.res_hdrs = {}             # response headers
        self.size = -1                 # downloading file size, -1 = unknown
        self.redirectno = 0            # redirection counter
        self.timestamp = 0             # time.time() when the transfer loop started
        self.received = 0              # bytes written so far
        self.triesno = 0               # retries used after non-fatal errors
        self._recent_start = 0         # time.time() before the latest read
        self._recent_bytes = 0         # size of the latest chunk

    def suspend(self):
        """Pause this download."""
        self.status = PAUSED
        self.stats_changed()

    def resume(self):
        """Start or resume this download in a new background thread."""
        self.status = CONNECTING
        self.stats_changed()
        self._download()

    def cancel(self):
        """Cancel this download."""
        self.status = CANCELLED
        self.stats_changed()

    def get_url(self,):
        """Return the current URL (after any redirection) as a string."""
        return self.url.geturl()

    def get_info(self):
        """Return the meta-information of the page - the response headers."""
        return self.res_hdrs

    def get_size(self,):
        """Return the announced size in bytes, or -1 when unknown."""
        return self.size

    def get_status(self,):
        """Return the current status code (CONNECTING, DOWNLOADING, ...)."""
        return self.status

    def get_progress(self,):
        """Get this download's progress in percent, or -1 when unknown."""
        if self.size > 0:
            return 100 * self.received / self.size
        elif self.status == COMPLETE:
            return 100
        else:
            return -1

    def add_headers(self, headers):
        """Add headers (mapping or iterable of pairs) for this request."""
        if not isinstance(headers, dict):
            headers = dict(headers)
        self.req_hdrs.update(headers)

    def set_limitrate(self, limitrate):
        """Set the download rate limit in bytes per second."""
        self.limitrate = float(limitrate)

    def error(self, exc=None):
        """Dispatch an exception to the fatal or non-fatal handler.

        *exc* defaults to the exception currently being handled.  Timeouts
        and short reads are retried; everything else is fatal.
        """
        if exc is None:
            exc = sys.exc_info()[1]
        if isinstance(exc, (timeout, INread)):
            self.non_fatal_error_handler(exc)
        elif isinstance(exc, gaierror):
            # Name resolution failed: report the resolver's message.
            self.fatal_error_handler(exc.strerror)
        else:
            # MAXrdr, MISilo, ICStat and unknown errors are all fatal.
            self.fatal_error_handler(exc)

    def fatal_error_handler(self, value):
        """Mark the download as failed, notify observers, report and exit.

        ``sys.exit`` raises SystemExit in the calling thread, which ends the
        download thread when called from ``run``.
        """
        self.status = ERROR
        self.errstr = value
        self.stats_changed()
        #-------------------
        exc_type, exc_value, tb = sys.exc_info()
        frames = traceback.extract_tb(tb) if tb is not None else []
        if frames:
            filename, lineno, function, _text = frames[-1]  # last line only
            errmsg = (
                '\nfilename: "%s" at line: %d'
                '\ntype : "%s"'
                '\nvalue : "%s"'
                '\n(in function "%s")'
            )
            print(errmsg % (filename, lineno, exc_type.__name__,
                            str(exc_value), function))
        else:
            # No traceback available (handler called outside an except block).
            print('\nvalue : "%s"' % (value,))
        sys.exit(1)

    def non_fatal_error_handler(self, errstr):
        """Retry after a short delay, up to MAX_REPEATS times, then give up."""
        if self.triesno < MAX_REPEATS:
            self.triesno += 1
            time.sleep(RETRY_DELAY_SAMPLE)
            self.resume()
        else:
            self.fatal_error_handler(errstr)

    def _download(self, ):
        """Start or resume downloading in a background thread."""
        thread = threading.Thread(target=self.run)
        thread.start()

    def _redirect(self, ):
        """Follow the redirection announced by the last response.

        Raises MAXrdr after MAX_REDIRECTIONS hops and MISilo when the
        response names no target.
        """
        from urllib.parse import urljoin
        if self.redirectno >= MAX_REDIRECTIONS:
            raise MAXrdr
        self.redirectno += 1
        self.triesno = 0
        headers = {k.lower(): v for k, v in dict(self.get_info()).items()}
        if "location" in headers:
            newurl = headers["location"]
        elif "uri" in headers:
            newurl = headers["uri"]
        else:
            raise MISilo
        # Location may be relative; resolve it against the current URL.
        self.url = urlparse(urljoin(self.get_url(), newurl))
        self.resume()

    def _limit_bandwidth(self,):
        """Limit the bandwidth by pausing the download for an amount of time."""
        if not self.limitrate:
            return
        # Calculate the amount of time we expect downloading the chunk should
        # take. If in reality it took less time, sleep to compensate for the
        # difference.
        expected = self._recent_bytes / self.limitrate
        recent_age = time.time() - self._recent_start
        if expected <= recent_age:
            return
        time.sleep(expected - recent_age)

    def run(self, ):
        """Thread body: perform the request and stream the body to the file."""
        conn = None
        file = None
        try:
            # Open connection to URL.
            conn = http.client.HTTPConnection(self.url.netloc)
            # Set debuglevel
            conn.debuglevel = DEBUGLEVEL
            # Specify what portion of file to download.
            if RESUME and self.received:
                self.req_hdrs.update({"Range": "bytes=%d-" % self.received})
            # Request target: never empty, and keep the query string.
            target = self.url.path or "/"
            if self.url.query:
                target += "?" + self.url.query
            conn.request("GET", target, headers=self.req_hdrs)
            response = conn.getresponse()
            # Get connection headers
            self.res_hdrs = response.getheaders()
            # Make redirection if needed
            if response.status // 100 == 3:
                return self._redirect()
            # Make sure response code is in the 200 range.
            elif response.status // 100 != 2:
                raise ICStat(response.status, response.reason)
            # The server ignored the Range request: start over from byte 0.
            if response.status != HTTP_STATUS_PARTIAL_CONTENTS:
                self.received = 0
            # Set the size for this download if it hasn't been already set.
            # getheader() matches the header name case-insensitively.
            length = response.getheader("Content-Length")
            if length is not None and self.size == -1:
                self.size = int(length)
            # Open the target and position it after the bytes already saved.
            if not self.file:
                # Adopt the descriptor returned by mkstemp instead of
                # leaking it and reopening the path.
                fd, self.file = mkstemp()
                file = open(fd, 'ab')
                owned = True
            elif isinstance(self.file, str):  # if path to file
                file = open(self.file, 'ab')
                owned = True
            else:  # if file like object
                file = self.file
                owned = False
            file.seek(self.received)
            if owned:
                # Drop stale data past the resume point; append mode would
                # otherwise keep it when the download restarts from zero.
                file.truncate()
            # Time stamp download time start, then announce DOWNLOADING so
            # observers already see a valid timestamp.
            self.timestamp = time.time()
            self.status = DOWNLOADING
            self.stats_changed()
            while self.status == DOWNLOADING:
                # Start of the current read, used for bandwidth limiting.
                self._recent_start = time.time()
                # Read from server into buffer.
                chunk = response.read(BUFFER_SIZE)
                if not chunk:
                    break
                # Write buffer to file.
                file.write(chunk)
                self._recent_bytes = len(chunk)
                self.received += self._recent_bytes
                # Limit the bandwidth by pausing the download for an amount
                # of time.
                self._limit_bandwidth()
                # Notify observers.
                self.stats_changed()
            if self.status == DOWNLOADING:
                # The stream ended without pause/cancel: check that the
                # actual size matches the "content-length" header.
                if self.size >= 0 and self.received < self.size:
                    raise INread('Error: %d expected %d get' % (self.received,
                                                                 self.size))
                self.status = COMPLETE
                self.stats_changed()
        except Exception as e:
            self.error(e)
        finally:
            # Close file.
            if file is not None:
                try:
                    file.close()
                except Exception as e:
                    self.error(e)
            # Close connection to server.
            if conn is not None:
                try:
                    conn.close()
                except Exception as e:
                    self.error(e)

    def stats_changed(self,):
        """Notify observers that this download's stats have changed."""
        self.notify_observers()
class DLTaskObserver(Observable):
    """Observes one download and derives rate and ETA figures from it.

    Recent throughput samples are kept in a fixed-size ring buffer.  The
    observer is itself observable: it forwards every notification from the
    download to its own observers, and exposes the download's figures
    through ``observer["progress"]``, ``observer["rate"]`` and so on.
    """

    DL_SPEED_HISTORY_SIZE = 20     # number of samples in the speed ring
    DL_SPEED_SAMPLE_MIN = 0.15     # minimum seconds covered by one sample

    def __init__ (self, download):
        self._history = {
            'times': [0] * self.DL_SPEED_HISTORY_SIZE,
            'bytes': [0] * self.DL_SPEED_HISTORY_SIZE,
            'pos' : 0,
        }
        self._recent_chunk = 0     # bytes accumulated for the current sample
        self._recent_start = 0     # download time at which the sample began
        self._dltime = 0           # seconds since the download's timestamp
        self.download = download
        download.add_observer(self)

    def calc_eta(self, download):
        """Return the estimated seconds remaining, or -1 when unknown."""
        # Without a known size or any received data there is nothing to
        # extrapolate from (size -1 would give a negative ETA).
        if download.size < 0 or not download.received:
            return -1
        bytes_remaining = max(download.size - download.received, 0)
        elapsed = time.time() - download.timestamp
        return elapsed * bytes_remaining / download.received

    def calc_rate(self):
        """Return the average bytes per second over the ring, or -1 if empty."""
        try:
            return sum(self._history['bytes']) / sum(self._history['times'])
        except ZeroDivisionError:
            return -1

    def update_speed_ring(self, download):
        """Fold the download's latest chunk into the speed ring buffer."""
        recent_age = self._dltime - self._recent_start
        self._recent_chunk += self.download._recent_bytes
        # Keep accumulating until the sample covers enough time to be
        # meaningful, and never store an empty sample.
        if recent_age < self.DL_SPEED_SAMPLE_MIN:
            return
        if not self._recent_chunk:
            return
        pos = self._history['pos']
        self._history['bytes'][pos] = self._recent_chunk
        self._history['times'][pos] = recent_age
        self._recent_start = self._dltime
        self._recent_chunk = 0
        # Advance the write position, wrapping at the end of the ring.
        self._history['pos'] = (pos + 1) % self.DL_SPEED_HISTORY_SIZE

    def update(self, download, *args, **kw):
        """Observer callback: refresh the speed ring and notify own observers.

        Extra positional/keyword arguments supplied by the observable are
        accepted and ignored.
        """
        self._dltime = time.time() - download.timestamp
        self.update_speed_ring(download)
        self.notify_observers()

    def __getitem__(self, name):
        """Return the figure called *name*; raises KeyError for unknown names."""
        if name == "progress":
            return self.download.get_progress()
        elif name == "position":
            return self.download.received
        elif name == "status":
            return self.download.status
        elif name == "size":
            return self.download.size
        elif name == "rate":
            return self.calc_rate()
        elif name == "file":
            return self.download.file
        elif name == "url":
            return self.download.get_url()
        elif name == "eta":
            return self.calc_eta(self.download)
        else:
            raise KeyError(name)
class TableModelBase(Observable):
    """Observable table whose rows are dicts keyed by column name.

    Subclasses set ``cols_`` and implement ``get_value_at``.
    """

    def __init__ (self, ):
        self.cols_ = () # table columns headers
        self.rows_ = [] # table rows

    def stats_changed(self):
        """Notify observers that the table has changed.

        The row methods below call this; it was previously missing from the
        class, so those calls raised AttributeError.
        """
        self.notify_observers()

    def table_cell_update(self, row, col, val):
        """Store *val* in row index *row* under column name *col*."""
        self.rows_[row][col] = val

    def get_col_by_name(self, name):
        """Return the index of the column called *name*."""
        return self.cols_.index(name)

    def get_col_by_idx(self, index):
        """Return the name of the column at *index*."""
        return self.cols_[ index ]

    def table_row_add(self, observer=None):
        """Append a row with every column set to None (*observer* is unused)."""
        self.rows_.append( { col: None for col in self.cols_ } )

    def table_row_del(self, row):
        """Delete row index *row* and notify observers."""
        del self.rows_[row]
        self.stats_changed() # notify observers

    def table_row_update(self, row):
        """Refresh every cell of *row* from ``get_value_at`` and notify observers."""
        for col in self.cols_:
            self.table_cell_update(row, col, self.get_value_at(row, col))
        self.stats_changed() # notify observers

    def get_value_at(self, row, col):
        """Return the current value for (*row*, *col*); overridden by subclasses."""
        pass
class DownloadTableModel(TableModelBase):
    """Table model with one row per download.

    ``download_list`` (the DLTaskObserver wrappers), ``download_urls`` and
    the table rows are parallel lists: index *i* refers to the same download
    in all three.
    """

    def __init__ (self):
        TableModelBase.__init__ (self, )
        self.cols_ = ( "progress", "position", "status", "size", "rate",
                       "file", "url", "eta", )
        # The table's list of downloads.
        self.download_list = []
        # The table's list of downloads URLs.
        self.download_urls = []

    def add_download(self, url, **kwargs):
        "Add a new download to the table (ignored if the URL is already listed)."
        # Make sure there is only one download per URL.
        if url in self.download_urls:
            return
        # Register to be notified when the download changes.
        download = DLTaskObserver(http_get_file(url, **kwargs))
        download.add_observer(self)
        self.download_list.append(download)
        self.download_urls.append(url)
        self.table_row_add()

    def get_download(self, row):
        "Get the download (http_get_file) for the specified row."
        return self.download_list[row].download

    def __getitem__(self, row):
        "Return the row dict (column name -> value) at index *row*."
        return self.rows_[ row ]

    def del_download(self, row):
        "Remove a download from the list."
        # Stop listening first so a late notification from the removed
        # download cannot reach update() any more.
        self.download_list[row].del_observer(self)
        del self.download_list[row]
        # Keep the URL list in step with the download list; otherwise the
        # indexes drift apart and the URL could never be added again.
        del self.download_urls[row]
        self.table_row_del(row)

    def get_download_by_status(self, status):
        "Return the downloads whose status equals *status*."
        return [d.download for d in self.download_list
                if d.download.status == status]

    def get_value_at(self, row, col):
        "Get value for a specific row and column combination."
        return self.download_list[ row ][ col ]

    def update(self, observable, *args, **kw):
        '''Update is called when a Download notifies its
        observers of any change'''
        try:
            index = self.download_list.index(observable)
        except ValueError:
            # Notification from a download that is no longer in the table.
            return
        # Table row update notification to table.
        self.table_row_update(index)
class Singleton(object):
    '''Implement Pattern: SINGLETON

    Each class in the hierarchy gets its own single instance.  Note that
    ``__init__`` still runs on every construction call.
    '''
    __single = None # the one, true Singleton

    def __new__(classtype, *args, **kwargs):
        # Check to see if a __single exists already for this class
        # Compare class types instead of just looking for None so
        # that subclasses will create their own __single objects
        if type(classtype.__single) is not classtype:
            # object.__new__ accepts no extra arguments in Python 3;
            # forwarding them raised TypeError for e.g. Singleton("name").
            classtype.__single = object.__new__(classtype)
        return classtype.__single

    def __init__(self,name=None):
        self.name = name

    def display(self):
        '''Print the instance's name, identity and type.'''
        print (self.name,id(self),type(self))
class DownloadManager(Singleton):
    """Singleton facade over a DownloadTableModel.

    Adds, removes, pauses and resumes downloads by table row, and caps the
    number of simultaneously active downloads at ``queue_size``.
    """

    DEFAULT_QUEUE_SIZE = 20

    def __init__ (self):
        # Singleton re-runs __init__ on every DownloadManager() call; keep
        # the existing table instead of discarding all known downloads.
        if getattr(self, 'TableModel', None) is not None:
            return
        self.TableModel = DownloadTableModel()
        self.TableModel.add_observer(self)
        self.queue_size = self.DEFAULT_QUEUE_SIZE
        self.selected_download = 0

    def download_add(self, url, **kwargs):
        """Verify *url* and add it to the table; raises ValueError if invalid."""
        self.TableModel.add_download(self.verify_url(url), **kwargs)

    @staticmethod
    def verify_url(url):
        "Verify download URL; returns the expanded URL or raises ValueError."
        # Do not lowercase the URL itself: paths and queries are
        # case-sensitive.  Only the scheme check is case-insensitive.
        verified_url = rewrite_shorthand_url(url)
        # Only allow HTTP URLs.
        if not verified_url.lower().startswith('http://'):
            raise ValueError('Only allow HTTP URLs.')
        # Verify format of URL.
        if not urlparse(verified_url).netloc:
            raise ValueError('Mall formed URL')
        return verified_url

    def download_del(self, row):
        """Cancel the download at *row* and remove it from the table."""
        # TableModel[row] is the row dict; the download object itself comes
        # from get_download().
        self.TableModel.get_download(row).cancel()
        self.TableModel.del_download(row)

    def download_select(self, row):
        """Make the download at *row* the default target of later commands."""
        self.selected_download = self.TableModel.get_download(row)

    def _get_download(self, row=None):
        """Return the download at *row*, or the selected one when row is None."""
        if row is not None:
            download = self.TableModel.get_download(row)
        else:
            download = self.selected_download
        return download

    def download_cancel(self, row=None):
        """Cancel the download at *row* (default: the selected download)."""
        self._get_download(row).cancel()

    def download_resume(self, row=None):
        """Resume the download at *row* (default: the selected download)."""
        self._get_download(row).resume()

    def download_pause(self, row=None):
        """Pause the download at *row* (default: the selected download)."""
        self._get_download(row).suspend()

    def download_set_limit(self, limitrate, row=None):
        """Set the rate limit of the download at *row* (default: selected)."""
        self._get_download(row).set_limitrate( float(limitrate) )

    def queue_on(self):
        """Download one file at a time."""
        self.queue_size = 1
        self.queue_check()

    def queue_off(self):
        """Restore the default limit on simultaneous downloads."""
        self.queue_size = self.DEFAULT_QUEUE_SIZE
        self.queue_check()

    def queue_check(self):
        """Pause surplus active downloads, or resume paused ones into free slots."""
        active_downloads = self.TableModel.get_download_by_status(DOWNLOADING)
        if len(active_downloads) > self.queue_size:
            for download in active_downloads[self.queue_size:]:
                download.suspend()
        else:
            # Only fill the slots that are actually free.
            free_slots = self.queue_size - len(active_downloads)
            paused_downloads = self.TableModel.get_download_by_status(PAUSED)
            for download in paused_downloads[:free_slots]:
                download.resume()

    def update(self, *args, **kw):
        """Observer callback from the table model: re-balance the queue.

        The observable and any extra arguments are accepted and ignored.
        """
        self.queue_check()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from sys import stdout as _stdout | |
class textprogressbar:
    """One-line text progress bar redrawn in place with a carriage return.

    With a positive maximum a percentage bar is drawn; otherwise a marker
    travels across the bar to show activity.
    """

    stdout = _stdout

    def __init__(self, value=0, maximum=-1, message=""):
        self.itr, self.val = 0, value
        self.max, self.msg = maximum, message

    def clear(self):
        """End the current line and reset the bar's state."""
        out = self.stdout
        out.write('\n')
        out.flush()
        self.itr, self.val, self.max, self.msg = 0, -1, 0, ""

    def update(self):
        """Redraw the bar from the current value, maximum and message."""
        if self.max > 0:
            # we know the maximum; draw a progress bar
            percent = 100 * self.val / self.max
            filled = "=" * int(38 * percent / 100)
            if percent % 2:
                filled += ">"
            line = "\r%3.f%% [%-38s] %s" % (percent, filled, self.msg)
        else:
            # we don't know the maximum, so we can't draw a progress bar
            # 40 spaces minus 2 for brackets "[" and "]"
            center = self.itr % 36 + 1
            left_pad = " " * (center - 1)
            right_pad = " " * (36 - center)
            line = "\r??%% [%s===%s] %s" % (left_pad, right_pad, self.msg)
            self.itr += 1
        out = self.stdout
        out.write(line)
        out.flush()

    def set_status(self, value, maximum):
        """Set the current value and the maximum used by the next update()."""
        self.val, self.max = value, maximum

    def set_message(self, message):
        """Set the text shown to the right of the bar."""
        self.msg = message
def test_1(n=386, msg="", delay=0.05):
    """Demo: draw a determinate bar counting from 0 to *n*.

    *msg* is unused; *delay* is the pause in seconds between redraws.
    """
    # The module imports time only under its __main__ guard, so import it
    # here to keep the demo callable when the module is imported.
    import time
    pbar = textprogressbar()
    for i in range(n+1):
        pbar.set_status(i, n)
        pbar.set_message("%d of %d" % (i, n))
        pbar.update()
        time.sleep(delay)
    pbar.clear()
def test_2(n=386, msg="", delay=0.05):
    """Demo: draw an indeterminate bar (maximum unknown) for *n* + 1 steps.

    *msg* is unused; *delay* is the pause in seconds between redraws.
    """
    # The module imports time only under its __main__ guard, so import it
    # here to keep the demo callable when the module is imported.
    import time
    pbar = textprogressbar()
    for i in range(n+1):
        pbar.set_status(i, -1)
        pbar.set_message("%d of %d" % (i, n))
        pbar.update()
        time.sleep(delay)
    pbar.clear()
if __name__ == '__main__':
    # The demos below use time.sleep, so make the module available first.
    import time
    # Run the determinate demo, then the indeterminate one.
    for demo in (test_1, test_2):
        demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment