Skip to content

Instantly share code, notes, and snippets.

@yuki-yano
Last active March 22, 2018 11:32
Show Gist options
  • Select an option

  • Save yuki-yano/920415ac8d2087c62c294f05cabfe8ba to your computer and use it in GitHub Desktop.

Select an option

Save yuki-yano/920415ac8d2087c62c294f05cabfe8ba to your computer and use it in GitHub Desktop.
imperial
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Imperial
# ======
# Copyright (C) 2015 FTG-Reversal
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import os
import cv2
import numpy as np
import re
import json
from flask import Flask
from flask import abort
sys.path.append('.')
import imperial.constants
from imperial import CharaDetector
# Supported characters as (display name, class name) pairs.
# The class name doubles as the training sub-directory name, and a
# character's position in this list is the numeric ID used by id2name().
CHARAS = [
    ("Axl", "ax"),
    ("Bedman", "be"),
    ("Chipp", "ch"),
    ("Dizzy", "di"),
    ("Elphelt", "el"),
    ("Faust", "fa"),
    ("Ino", "in"),
    ("Jam", "ja"),
    ("Jack-O", "jc"),
    ("Johnny", "jo"),
    ("Kum", "ku"),
    ("Ky", "ky"),
    ("Leo", "le"),
    ("May", "ma"),
    ("Millia", "mi"),
    ("Potemkin", "po"),
    ("Ramlethal", "ra"),
    ("Raven", "rv"),
    ("Sin", "si"),
    ("Slayer", "sl"),
    ("Sol", "so"),
    ("Venom", "ve"),
    ("Zato", "za"),
]
def fetch_data(video_id, fetch_dir):
    """Download a video and dump one portrait crop per second for each player.

    When ``fetch_dir`` already exists nothing is done. Otherwise the
    directory is created together with its ``1p`` and ``2p`` sub-directories
    and the output of ``nicovideo-dump`` is piped into ``ffmpeg``, which
    samples one frame per second, scales it to a width of 640 and writes a
    40x40 crop taken at x=20 to ``1p/%04d.png`` and one taken at x=580 to
    ``2p/%04d.png``.

    The two programs are started from argument lists rather than through a
    shell command string, so neither ``video_id`` nor ``fetch_dir`` can be
    interpreted as shell syntax. Exit statuses are ignored, as before.

    Args:
        video_id: Video identifier appended to the watch URL.
        fetch_dir: Directory that receives the extracted frames.
    """
    import subprocess

    # NOTE(review): the indentation of the original was lost; the download is
    # assumed to run only when the directory did not exist yet (ffmpeg would
    # otherwise stop to ask before overwriting existing frames) - confirm.
    if os.path.isdir(fetch_dir):
        return

    # makedirs also creates missing parents such as ./fetched_data.
    os.makedirs(fetch_dir + "/1p")
    os.makedirs(fetch_dir + "/2p")

    filter_graph = (
        "[0:v]fps=1, scale=640:-1, split[tmp1][tmp2]; "
        "[tmp1]crop=40:40:20:0[out1]; "
        "[tmp2]crop=40:40:580:0[out2]"
    )
    dump = subprocess.Popen(
        ["nicovideo-dump", "http://www.nicovideo.jp/watch/" + video_id],
        stdout=subprocess.PIPE,
    )
    encode = subprocess.Popen(
        [
            "ffmpeg", "-i", "pipe:0",
            "-filter_complex", filter_graph,
            "-map", "[out1]", fetch_dir + "/1p/%04d.png",
            "-map", "[out2]", fetch_dir + "/2p/%04d.png",
        ],
        stdin=dump.stdout,
    )
    # Close our copy of the pipe so the downloader receives SIGPIPE if ffmpeg
    # exits early.
    dump.stdout.close()
    encode.wait()
    dump.wait()
def read_fetched_data(fetch_dir, player, detector):
    """Classify every stored frame of one player.

    The files in ``fetch_dir/player`` are processed in sorted name order.
    Each image is enlarged by a factor of five in both dimensions and passed
    to ``detector.classify_proba``.

    Args:
        fetch_dir: Directory holding the per-player frame directories.
        player: Name of the sub-directory to read ('1p' or '2p' in this file).
        detector: Object providing ``classify_proba(img)``.

    Returns:
        Array with one row per entry of ``imperial.constants.charas`` and one
        column per frame.
    """
    frame_dir = fetch_dir + '/' + player
    names = sorted(os.listdir(frame_dir))
    probas = np.empty((len(names), len(imperial.constants.charas)))
    for row, name in enumerate(names):
        img = cv2.imread(frame_dir + '/' + name)
        assert img is not None
        height, width = img.shape[:2]
        enlarged = cv2.resize(img, (width * 5, height * 5))
        probas[row] = detector.classify_proba(enlarged)
    # Transpose so that callers can iterate per class.
    return probas.T
def convolve(probas, size):
    """Smooth each probability series with a two-sided windowed minimum.

    For every row two moving sums scaled by ``1 / size`` are computed: one
    over the ``size`` samples ending at each position ("last") and one over
    the ``size`` samples starting at it ("first"). The result is the
    element-wise minimum of the two.

    Args:
        probas: 2-D sequence; each row is one series to smooth.
        size: Window length in samples.

    Returns:
        Array shaped like ``np.array(probas)`` holding the smoothed rows.
    """
    smoothed = np.array(probas)
    # Kernel covering the current sample and the size - 1 samples before it.
    g_last = np.array([0.] * (size - 1) + [1 / size] * size)
    # Kernel covering the current sample and the size - 1 samples after it.
    g_first = np.array([1 / size] * size + [0.] * (size - 1))
    for i, proba in enumerate(probas):
        last = np.convolve(proba, g_last, 'same')
        first = np.convolve(proba, g_first, 'same')
        # The windows are truncated at the head of `last` and at the tail of
        # `first`, which yields abnormally small values there; rescale them.
        # NOTE(review): the factor (size - n) / n is kept from the original,
        # whose TODO already doubted it; a true truncated-window mean would
        # use size / n - confirm the intent before changing it.
        for j in range(size - 1):
            last[j] = last[j] / (j + 1) * (size - (j + 1))
            # Mirror position counted from the end. The original indexed
            # with len(first - j) - 1, which is always the last element.
            tail = len(first) - 1 - j
            first[tail] = first[tail] / (j + 1) * (size - (j + 1))
        smoothed[i] = np.min([last, first], axis=0)
    return smoothed
def probas_to_scores(probas):
    """Pick the most probable class for every frame.

    Args:
        probas: 2-D array-like with one row per class and one column per
            frame.

    Returns:
        float64 array of shape ``(frames, 2)``; column 0 holds the index of
        the row with the highest value in that frame, column 1 the value.
    """
    probas = np.asarray(probas)
    # Vectorised over frames instead of growing the result with np.append,
    # which copies the whole array on every iteration.
    scores = np.empty((probas.shape[1], 2), dtype=np.float64)
    scores[:, 0] = np.argmax(probas, axis=0)
    scores[:, 1] = np.amax(probas, axis=0)
    return scores
def reduce_pair_func(reducer, pair):
    """Append ``pair`` to ``reducer`` unless it repeats the previous IDs.

    ``pair`` and the entries of ``reducer`` are ``(sec, score)`` tuples whose
    ``score[0]`` is an array of two IDs. The pair is skipped when both IDs
    equal those of the last entry in ``reducer``.

    Args:
        reducer: Non-empty list of accepted pairs; modified in place.
        pair: Candidate ``(sec, score)`` tuple.
    """
    previous_ids = reducer[-1][1][0]
    current_ids = pair[1][0]
    # Element-wise comparison; the result is a NumPy boolean array.
    same = previous_ids == current_ids
    if not same[0] or not same[1]:
        reducer.append(pair)
def id2name(id):
    """Return the display name stored at index ``id`` of CHARAS."""
    name, _classname = CHARAS[id]
    return name
# Application initialization
def _train_detector(player):
    """Build and train a CharaDetector from the images under train/<player>.

    For every entry of CHARAS the directory ``train/<player>/<classname>`` is
    walked; each ``.png`` / ``.jpg`` file is enlarged five times in both
    dimensions and fed to ``update`` under its class name. ``train`` is
    called once all images have been added.

    Args:
        player: Sub-directory of ``train`` to read ('1p' or '2p').

    Returns:
        The trained detector.

    Raises:
        IOError: If an image file cannot be decoded.
    """
    detector = CharaDetector(word_num=50, surf_thresh=50)
    for charaname, classname in CHARAS:
        train_dir = "%s/%s/%s" % ('train', player, classname)
        print(charaname)
        for root, _dirs, files in os.walk(train_dir):
            for filename in sorted(files):
                if not filename.endswith((".png", ".jpg")):
                    continue
                path = os.path.join(root, filename)
                print(path)
                img = cv2.imread(path)
                # imread signals failure by returning None; fail with the
                # offending path instead of an AttributeError on .shape.
                if img is None:
                    raise IOError("failed to read training image: " + path)
                height, width = img.shape[:2]
                img = cv2.resize(img, (width * 5, height * 5))
                detector.update(img, classname)
    detector.train()
    return detector


chara_detector_1p = _train_detector('1p')
classnames = []
chara_classnames = []
chara_detector_2p = _train_detector('2p')
server = Flask(__name__)
def _drop_short_segments(segments, min_frames, total_frames):
    """Remove segments lasting fewer than ``min_frames`` frames, in place.

    ``segments`` is a list of ``(sec, score)`` tuples ordered by ``sec``; a
    segment lasts until the next one starts, the final one until
    ``total_frames``. The head element is always removed and is never
    length-checked, matching the original behaviour.
    """
    # Walk backwards so that deleting a predecessor keeps `index` valid.
    index = len(segments) - 1
    while index > 1:
        if segments[index][0] - segments[index - 1][0] < min_frames:
            del segments[index - 1]
        index -= 1
    # Remove the head element.
    if segments:
        segments.pop(0)
    # Drop trailing segments that are too short to count.
    while segments and total_frames - segments[-1][0] < min_frames:
        segments.pop()
    return segments


@server.route('/<score_threshold>/<frame_threshold>/<video_id>', methods=['GET'])
def predict(score_threshold, frame_threshold, video_id):
    """Detect the sequence of character matchups in a video.

    Args:
        score_threshold: Decimal string; a frame is kept when the sum of
            both players' best scores times 100 exceeds it.
        frame_threshold: Decimal string; matchups lasting fewer frames than
            this are discarded.
        video_id: Video identifier of the form ``sm<digits>``.

    Returns:
        JSON text of a list of objects with the keys ``sec``, ``1p`` and
        ``2p``; the list is empty when no matchup survives the filtering.
    """
    # \Z (unlike $) does not accept a trailing newline. The ID ends up in a
    # URL and a file-system path, so only ASCII digits are allowed.
    if not re.match(r'^sm[0-9]+\Z', video_id):
        return abort(500)
    # Parse once, before the expensive download.
    score_threshold = int(score_threshold)
    frame_threshold = int(frame_threshold)

    fetch_dir = './fetched_data/' + video_id
    fetch_data(video_id, fetch_dir)

    # Read the 1P and 2P data, each as per-character probability series.
    probas_1p = read_fetched_data(fetch_dir, '1p', chara_detector_1p)
    probas_2p = read_fetched_data(fetch_dir, '2p', chara_detector_2p)

    # Convolve the 1P and 2P probabilities per character.
    convolved_probas_1p = convolve(probas_1p, len(CHARAS))
    convolved_probas_2p = convolve(probas_2p, len(CHARAS))

    # Treat the most probable character of every frame as its score.
    scores_1p = probas_to_scores(convolved_probas_1p)
    scores_2p = probas_to_scores(convolved_probas_2p)

    # Combine both players per frame: entry[0] holds the two character IDs,
    # entry[1] the two scores.
    zip_scores = np.dstack((scores_1p, scores_2p))

    # Keep the frames whose combined score exceeds the threshold.
    pairs = []
    for sec, score in enumerate(zip_scores):
        if (score[1][0] + score[1][1]) * 100 > score_threshold:
            pairs.append((sec, score))

    # Collapse runs of consecutive frames that show the same matchup.
    # Slicing keeps this safe when no frame passed the threshold.
    reducer = pairs[:1]
    for pair in pairs:
        reduce_pair_func(reducer, pair)

    # Remove matchups that last fewer frames than the threshold.
    total_frames = len(os.listdir(fetch_dir + '/1p'))
    _drop_short_segments(reducer, frame_threshold, total_frames)

    result = []
    for sec, score in reducer:
        # TODO: the pair object is nested one level deeper than necessary.
        name_1p = id2name(int(score[0][0]))
        name_2p = id2name(int(score[0][1]))
        result.append({'sec': sec, '1p': name_1p, '2p': name_2p})
        print('sec: ' + str(sec) + ', ' + name_1p + ' vs ' + name_2p)
    return json.dumps(result, ensure_ascii=False)
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # arbitrary Python, so it must not be on by default for a server bound
    # to all interfaces. Set FLASK_DEBUG=1 to opt in during development.
    server.run(host='0.0.0.0', debug=os.environ.get('FLASK_DEBUG') == '1', port=80)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment