Last active: June 26, 2019 13:02
-
-
Save yueyericardo/ad4e4a5ed5f4a95403f0c414523335fc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Progbar(object):
    """Displays a progress bar.

    Arguments:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over time. Metrics in this list
            will be displayed as-is. All others will be averaged
            by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
        unit_name: Display name for step counts (usually "step" or "sample").
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05,
                 stateful_metrics=None, unit_name='step'):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        self.unit_name = unit_name
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # Redraw the current line in place ('\b' / '\r') when stdout is a
        # TTY, inside a Jupyter kernel, or when the `posix` module is loaded
        # (POSIX platforms); otherwise each redraw starts on a new line.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules or
                                 'posix' in sys.modules)
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0

    @staticmethod
    def _num_digits(n):
        """Return the field width used to right-align a step counter up to `n`.

        Falls back to 1 for non-positive `n`, where `np.log10` would give
        -inf/nan and `int()` would raise.
        """
        if n <= 0:
            return 1
        return int(np.log10(n)) + 1

    @staticmethod
    def _format_value(avg):
        """Format a metric value: fixed-point unless its magnitude is <= 1e-3."""
        if abs(avg) > 1e-3:
            return ' %.4f' % avg
        return ' %.4e' % avg

    def update(self, current, values=None):
        """Updates the progress bar.

        Arguments:
            current: Index of current step.
            values: List of tuples:
                `(name, value_for_last_step)`.
                If `name` is in `stateful_metrics`,
                `value_for_last_step` will be displayed as-is.
                Else, an average of the metric over time will be displayed.
        """
        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # Weight each value by the number of steps it covers, so the
                # displayed number is a per-step running average.
                steps = current - self._seen_so_far
                if k not in self._values:
                    self._values[k] = [v * steps, steps]
                else:
                    self._values[k][0] += v * steps
                    self._values[k][1] += steps
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            # Throttle redraws to `interval`, but always draw the final step.
            if (now - self._last_update < self.interval and
                    self.target is not None and current < self.target):
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                numdigits = self._num_digits(self.target)
                bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
                # A non-positive target cannot be divided by; draw it as done.
                if self.target > 0:
                    prog = float(current) / self.target
                else:
                    prog = 1.0
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta
                info = ' - ETA: %s' % eta_format
            else:
                if time_per_unit >= 1 or time_per_unit == 0:
                    info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
                else:
                    info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)

            for k in self._values_order:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    info += self._format_value(avg)
                else:
                    info += ' %s' % self._values[k]

            self._total_width += len(info)
            # Pad with spaces to erase leftovers of a longer previous line.
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))
            if self.target is not None and current >= self.target:
                info += '\n'
            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            # Semi-verbose: print a single summary line once the target is hit.
            if self.target is not None and current >= self.target:
                numdigits = self._num_digits(self.target)
                count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
                info = count + info
                for k in self._values_order:
                    info += ' - %s:' % k
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    # Same abs()-based rule as verbose=1, so negative values
                    # are not forced into scientific notation.
                    info += self._format_value(avg)
                info += '\n'
                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        """Advance the progress bar by `n` steps from the last seen step.

        Arguments:
            n: Number of steps to add to the current step count.
            values: Same format as for `update`.
        """
        self.update(self._seen_so_far + n, values)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_bar(i, iter_per_epoch, width=30):
    """Render a Keras-style text progress bar for step `i` of `iter_per_epoch`.

    Arguments:
        i: Step count to display as completed.
        iter_per_epoch: Total number of steps. Must be positive: it is passed
            to `np.log10` and used as a divisor.
        width: Number of characters between the brackets (default 30).

    Returns:
        A string such as ' 5/10 [==============>...............]'.
    """
    # Right-align the counter to the width of the total so the bar doesn't jitter.
    numdigits = int(np.log10(iter_per_epoch)) + 1
    bar = ('%' + str(numdigits) + 'd/%d [') % (i, iter_per_epoch)
    prog = float(i) / iter_per_epoch
    prog_width = int(width * prog)
    if prog_width > 0:
        bar += '=' * (prog_width - 1)
        # Arrow head while in progress; a solid '=' once finished.
        if i < iter_per_epoch:
            bar += '>'
        else:
            bar += '='
    bar += '.' * (width - prog_width)
    bar += ']'
    return bar
def get_eta(eta):
    """Format a duration in seconds as 'H:MM:SS', 'MM:SS' or 'Ns'."""
    if eta > 3600:
        return '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
    elif eta > 60:
        return '%02d:%02d' % (eta // 60, eta % 60)
    return '%ds' % eta


def format_loss(loss_value):
    """Format a loss for the progress line: fixed-point unless |loss| <= 1e-2."""
    if abs(loss_value) > 1e-2:
        return ' - loss: %.4f' % loss_value
    return ' - loss: %.4e' % loss_value


for epoch in range(num_epochs):
    print('Epoch: %d/%d' % (epoch + 1, num_epochs))
    start_time_epoch = time.time()
    train_iterator = tf.compat.v1.data.make_one_shot_iterator(train_dataset)
    train_iterator_handle = sess.run(train_iterator.string_handle())
    loss_value = None  # stays None if the epoch runs zero iterations
    for i in range(iter_per_epoch):
        _, loss_value = sess.run([optimizer, loss],
                                 feed_dict={handle: train_iterator_handle})
        # Redraw only every 10 iterations to limit terminal output.
        if i % 10 == 0:
            bar = get_bar(i, iter_per_epoch)
            done = i + 1  # iterations completed so far in this epoch
            time_per_unit = (time.time() - start_time_epoch) / done
            eta = time_per_unit * (iter_per_epoch - done)
            info_eta = ' - ETA: %s' % get_eta(eta)
            sys.stdout.write(bar + info_eta + format_loss(loss_value) + ' ' * 10)
            # Return to the start of the line so the next redraw overwrites it.
            sys.stdout.write('\r')
            sys.stdout.flush()
    model.save(directory=model_dir, filename=model_name)
    bar = get_bar(iter_per_epoch, iter_per_epoch)
    time_elapsed_epoch = time.time() - start_time_epoch
    info_eta = ' - Time: %s' % get_eta(time_elapsed_epoch)
    # Report the loss of the final iteration, not the last redrawn one.
    info_loss = format_loss(loss_value) if loss_value is not None else ''
    # End the line so the next epoch header starts on a fresh line.
    sys.stdout.write(bar + info_eta + info_loss + ' ' * 10 + '\n')
    sys.stdout.flush()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, time | |
class ShowProcess():
    """
    Display processing progress as a text progress bar.

    Call `show_process` to redraw the bar; `close` is called automatically
    once the progress count reaches `max_steps`.
    """
    i = 0  # current progress count
    max_steps = 0  # total number of steps to process
    max_arrow = 50  # length of the progress bar, in characters
    infoDone = 'done'

    def __init__(self, max_steps, infoDone='Done', max_arrow=50):
        """Create a bar for `max_steps` steps.

        Arguments:
            max_steps: Total number of steps to process.
            infoDone: Message printed by `close` when processing finishes.
            max_arrow: Length of the bar in characters (default 50).
        """
        self.max_steps = max_steps
        self.max_arrow = max_arrow
        self.i = 0
        self.infoDone = infoDone

    def show_process(self, i=None):
        """Redraw the bar for the current progress.

        Arguments:
            i: If given, set the progress count to `i`; otherwise increment
                the count by one.

        Output looks like:
        [>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100.00%
        """
        if i is not None:
            self.i = i
        else:
            self.i += 1
        if self.max_steps > 0:
            num_arrow = int(self.i * self.max_arrow / self.max_steps)  # number of '>'
            percent = self.i * 100.0 / self.max_steps  # completion, shown as xx.xx%
        else:
            # Nothing to process: show as complete instead of dividing by zero.
            num_arrow = self.max_arrow
            percent = 100.0
        # Clamp so out-of-range progress cannot overflow the bar width.
        num_arrow = max(0, min(num_arrow, self.max_arrow))
        num_line = self.max_arrow - num_arrow  # number of '-'
        # The trailing '\r' returns to column 0 without a newline, so the
        # next call overwrites this line.
        process_bar = ('[' + '>' * num_arrow + '-' * num_line + ']'
                       + '%.2f' % percent + '%' + '\r')
        sys.stdout.write(process_bar)
        sys.stdout.flush()
        if self.i >= self.max_steps:
            self.close()

    def close(self):
        """End the bar's line, print the done message, and reset the count."""
        print('')
        print(self.infoDone)
        self.i = 0
if __name__ == '__main__':
    # Demo: drive a 100-step bar, pausing briefly between steps.
    total = 100
    demo_bar = ShowProcess(total, 'OK')
    for _ in range(total):
        demo_bar.show_process()
        time.sleep(0.01)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.