Skip to content

Instantly share code, notes, and snippets.

@knok
Last active July 4, 2017 03:59
Show Gist options
  • Save knok/c1e82c39de57800b7d3ae4219e0e8c0d to your computer and use it in GitHub Desktop.
Save knok/c1e82c39de57800b7d3ae4219e0e8c0d to your computer and use it in GitHub Desktop.
猫画像から猫部分のみを抽出する(matting/semantic segmentation) ref: http://qiita.com/knok/items/6ad09cc870739dbd921b
$ pip install -r requirements.txt
$ curl -L https://github.com/nicolov/segmentation_keras/releases/download/model/nicolov_segmentation_model.tar.gz \
| tar xvf -
$ python predict.py --weights_path \
conversion/converted/dilation8_pascal_voc.npy \
images/cat.jpg
# -*- coding: utf-8 -*-
import argparse
import os, sys
from PIL import Image
import numpy as np
def get_args():
    """Parse the command-line options for the cat-matting script.

    ``--contents-dir``, ``--segments-dir`` and ``--output-dir`` have no
    usable default: if any of them is missing, the usage text is printed
    and the process exits with status 1. ``--cat-label-vals`` defaults to
    the string "64,0,0".

    Returns:
        argparse.Namespace with ``contents_dir``, ``segments_dir``,
        ``output_dir`` and ``cat_label_vals`` attributes.
    """
    parser = argparse.ArgumentParser()
    directory_options = (
        ("--contents-dir", "-c"),
        ("--segments-dir", "-s"),
        ("--output-dir", "-o"),
    )
    for long_opt, short_opt in directory_options:
        parser.add_argument(long_opt, short_opt, default=None)
    parser.add_argument("--cat-label-vals", "-v", default="64,0,0")
    parsed = parser.parse_args()
    required = (parsed.contents_dir, parsed.segments_dir, parsed.output_dir)
    if any(value is None for value in required):
        parser.print_help()
        sys.exit(1)
    return parsed
def cat_col_array(val_str):
    """Convert a comma-separated colour string into a list of ints.

    Each comma-separated field is passed to ``int()``, so surrounding
    whitespace is tolerated and a non-numeric field raises ValueError.

    >>> cat_col_array("64,0,0")
    [64, 0, 0]
    """
    return list(map(int, val_str.split(',')))
def cmpary(a1, a2, threshold=32):
    """Return True if two colour vectors lie within *threshold* of each other.

    The distance is the Euclidean norm of the element-wise difference.
    Both inputs are converted to float before subtracting. Pixels read
    from an 8-bit image may be uint8, and subtracting two uint8 arrays
    wraps around (e.g. 0 - 10 == 246), which would make nearby colours
    look far apart.

    Args:
        a1: first colour, any array-like (e.g. one pixel of a label image).
        a2: second colour with a matching shape (e.g. the target label).
        threshold: maximum distance (inclusive) still counted as a match.
            Defaults to 32, the value previously hard-coded here.

    Returns:
        bool: True when ``norm(a1 - a2) <= threshold``.
    """
    diff = np.asarray(a1, dtype=float) - np.asarray(a2, dtype=float)
    return bool(np.linalg.norm(diff) <= threshold)
def make_images(cont_dir, seg_dir, out_dir, cat_vals, threshold=32):
    """Whiten every pixel of each content image that is not labelled as cat.

    For each regular file in *cont_dir* that has a same-named file in
    *seg_dir*:
    - both images are loaded as RGB;
    - every content pixel whose segmentation colour is farther than
      *threshold* (Euclidean distance over the RGB channels) from
      *cat_vals* is set to white;
    - the result is saved to *out_dir* under the same file name.

    Pairs whose image sizes differ are skipped with a message, because
    their pixels cannot be matched one-to-one.

    Args:
        cont_dir: directory holding the original images.
        seg_dir: directory holding segmentation images with matching names.
        out_dir: destination directory; created if it does not exist.
        cat_vals: RGB label colour of the cat class, e.g. [64, 0, 0].
        threshold: inclusive match distance; 32 matches the value used by
            cmpary().
    """
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for fname in _target_files(cont_dir, seg_dir):
        print("processing: %s" % fname)
        with Image.open(os.path.join(cont_dir, fname)) as im:
            # Force three channels so the white fill always fits, whatever
            # mode (greyscale, palette, RGBA) the source file uses.
            cont = np.array(im.convert('RGB'))
        with Image.open(os.path.join(seg_dir, fname)) as im:
            # Palette or greyscale label maps are expanded to RGB so they
            # can be compared against the RGB label colour.
            seg = np.asarray(im.convert('RGB'))
        if seg.shape[:2] != cont.shape[:2]:
            print("skipping: %s (segment size %s != content size %s)"
                  % (fname, seg.shape[:2], cont.shape[:2]))
            continue
        _mask_background(cont, seg, cat_vals, threshold)
        Image.fromarray(cont).save(os.path.join(out_dir, fname))


def _target_files(cont_dir, seg_dir):
    """Return sorted names of regular files in cont_dir also present in seg_dir."""
    return sorted(
        fname for fname in os.listdir(cont_dir)
        if os.path.isfile(os.path.join(cont_dir, fname))
        and os.path.exists(os.path.join(seg_dir, fname))
    )


def _mask_background(cont, seg, cat_vals, threshold):
    """In place, set cont pixels whose seg colour is not within threshold of cat_vals to white."""
    # Cast to float so uint8 subtraction cannot wrap around.
    diff = seg.astype(float) - np.asarray(cat_vals, dtype=float)
    is_cat = np.linalg.norm(diff, axis=-1) <= threshold
    cont[~is_cat] = 255
def main():
    """Script entry point: read CLI options and run the batch matting."""
    options = get_args()
    make_images(
        options.contents_dir,
        options.segments_dir,
        options.output_dir,
        cat_col_array(options.cat_label_vals),
    )
# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment