Last active
May 8, 2019 04:04
-
-
Save serihiro/5d5c4a177b42a9f80722df18759478e4 to your computer and use it in GitHub Desktop.
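A modified version of `compute_mean.py` from the Chainer ImageNet example (https://github.com/chainer/chainer/blob/6c495789f09ecea44449a60f905bd78e1773ce74/examples/imagenet/compute_mean.py) that computes the mean image over several image-label list files in parallel worker processes.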
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import sys | |
import numpy as np | |
import concurrent.futures | |
import chainer | |
def compute_mean(file):
    """Return the per-pixel sum of the images listed in *file* and their count.

    *file* is an image-label list file read by
    ``chainer.datasets.LabeledImageDataset``; image paths in it are resolved
    against the module-level ``root``. Returning the sum and the count (rather
    than the mean) lets the caller combine results from several list files.

    Returns:
        (sum_image, N): the element-wise sum of all images, in the images'
        own dtype, and the number of images. For an empty list this is (0, 0).
    """
    dataset = chainer.datasets.LabeledImageDataset(file, root)
    print('compute mean image')
    N = len(dataset)
    acc = None
    dtype = None
    for i, (image, _) in enumerate(dataset, start=1):
        if acc is None:
            # Accumulate in float64: float32 represents integers exactly only
            # up to 2**24, which a running sum of 8-bit pixel values (<= 255)
            # exceeds after roughly 66k images, silently losing precision.
            dtype = image.dtype
            acc = image.astype(np.float64)
        else:
            acc += image
        # 1-based so the final progress line reads "N / N".
        sys.stderr.write('{} / {}\r'.format(i, N))
        sys.stderr.flush()
    sys.stderr.write('\n')
    if acc is None:
        # Empty list: same result as the original (0, 0).
        return 0, N
    # Cast back so callers still receive the images' dtype.
    return acc.astype(dtype), N
def _init_worker(worker_root):
    """Set the module-level ``root`` inside a pool worker process.

    compute_mean reads ``root`` as a global. Workers started with the
    'spawn' method (the default on Windows and macOS) re-import this module
    and do not see the value assigned in the parent, so the executor's
    initializer sets it explicitly in every worker.
    """
    global root
    root = worker_root


def main():
    """Compute the mean image over several image-label list files.

    Reads ``dataset_list`` (one list-file path per line; blank lines are
    ignored), runs compute_mean on each list file in a process pool, and
    saves the overall per-pixel mean to ``--output`` with ``np.save``.
    Exits with a usage error if the list files contain no images.
    """
    parser = argparse.ArgumentParser(description='Compute images mean array')
    parser.add_argument('dataset_list',
                        help='Path to a list file of training image-label list files (ssv)')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--output', '-o', default='mean.npy',
                        help='path to output mean array')
    parser.add_argument('--concurrency', type=int, default=10,
                        help='number of worker processes')
    args = parser.parse_args()

    # The positional argument is stored as `dataset_list` (not `dataset`).
    with open(args.dataset_list) as f:
        file_list = [line.strip() for line in f if line.strip()]
    print(file_list)

    # Kept for code that reads the global in the parent process; workers get
    # the value through _init_worker so 'spawn' start methods work too.
    global root
    root = args.root

    with concurrent.futures.ProcessPoolExecutor(
            max_workers=args.concurrency,
            initializer=_init_worker,
            initargs=(args.root,)) as executor:
        results = list(executor.map(compute_mean, file_list))

    mean = 0
    num = 0
    for m, n in results:
        mean += m
        num += n
    if num == 0:
        # Avoid a bare ZeroDivisionError when every list file is empty.
        parser.error('no images found in the listed files')
    mean /= num
    np.save(args.output, mean)


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse


def split_file(path, number):
    """Split the text file at *path* into *number* consecutive parts.

    Part k is written to ``f'{path}.{k}'`` for k in ``range(number)``. Part
    line counts differ by at most one, so exactly *number* files are created.
    Some parts are empty when the input has fewer lines than *number*.

    Returns:
        The list of written file paths, in order.

    Raises:
        ValueError: if *number* is less than 1.
    """
    if number < 1:
        raise ValueError(f'number must be >= 1, got {number}')
    with open(path, mode='r') as src:
        lines = src.readlines()
    total = len(lines)
    outputs = []
    for k in range(number):
        # Integer boundaries spread the remainder evenly and never produce
        # an extra trailing file (the old float threshold did when
        # total % number == 0).
        start = k * total // number
        stop = (k + 1) * total // number
        out_path = f'{path}.{k}'
        with open(out_path, mode='w') as dst:
            dst.writelines(lines[start:stop])
        outputs.append(out_path)
    return outputs


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Split a text file into N line-balanced parts')
    parser.add_argument('--input', '-i', type=str, required=True,
                        help='file to split; parts are written to INPUT.0 .. INPUT.N-1')
    parser.add_argument('--number', '-n', type=int, required=True,
                        help='number of parts (>= 1)')
    args = parser.parse_args()
    if args.number < 1:
        parser.error('--number must be >= 1')
    split_file(args.input, args.number)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
example