Last active
July 20, 2018 01:02
-
-
Save kanjirz50/616b3a1c069dc4b0a4d9357457f6a105 to your computer and use it in GitHub Desktop.
森羅の評価用スクリプト(データ読み込み先のパスは適当に指定してください) MIT License
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
import pandas as pd | |
class WikiepediaEvaluation:
    """Evaluate extracted attribute values against the gold annotations.

    Usage:
        import json
        from evaluation import WikiepediaEvaluation

        # Load the training data
        with open('../data/train-20180516T074039Z-001/company_train.json') as f:
            data = json.load(f)

        preds = SomeExtractor.extract(data)
        we = WikiepediaEvaluation(data)
        result = we.evaluate(preds)
    """

    def __init__(self, train_data):
        """Store the gold data.

        Args:
            train_data: Mapping read as ``train_data['entry']``, a sequence of
                entries whose ``'Attributes'`` value maps an attribute name to
                the list of gold values for that entry.
        """
        self.train_data = train_data

    @staticmethod
    def _new_counter():
        """Return a zeroed tally for a single attribute."""
        return Counter(true_positive=0, false_positive=0, false_negative=0, support=0)

    def init_attribute_counts(self):
        """Reset ``self.attribute_counts`` to zeroed tallies.

        The tallies are seeded from the attribute names of the first entry so
        that the report keeps that order. An empty entry list yields an empty
        mapping instead of raising ``IndexError``.
        """
        entries = self.train_data['entry']
        first_attributes = entries[0]['Attributes'] if entries else {}
        self.attribute_counts = {
            attr: self._new_counter() for attr in first_attributes
        }

    def evaluate(self, preds):
        """Score predictions against the gold data, one row per attribute.

        Args:
            preds: Iterable of mappings (attribute name -> list of predicted
                values), paired positionally with ``train_data['entry']`` via
                ``zip``. An attribute missing from a prediction is treated as
                an empty list.

        Returns:
            pandas.DataFrame with the columns ``attribute``, ``precision``,
            ``recall``, ``f1-score`` and ``support``.
        """
        self.init_attribute_counts()
        for pred, test in zip(preds, self.train_data['entry']):
            for k, gold in test['Attributes'].items():
                attribute_count = self.attribute_counts.get(k)
                if attribute_count is None:
                    # Attribute absent from the first entry: start a tally for
                    # it here rather than failing with KeyError.
                    attribute_count = self.attribute_counts[k] = self._new_counter()
                tp, fp, fn = self.count_tp_fp_fn(pred.get(k, []), gold)
                attribute_count['true_positive'] += tp
                attribute_count['false_positive'] += fp
                attribute_count['false_negative'] += fn
                attribute_count['support'] += len(gold)
        return pd.DataFrame(
            [(k, *self.evaluation(v)) for k, v in self.attribute_counts.items()],
            columns=["attribute", "precision", "recall", "f1-score", "support"],
        )

    @staticmethod
    def count_tp_fp_fn(pred, test):
        """Count true positives, false positives and false negatives.

        Every element of ``pred`` is checked for membership in ``test``, so a
        value that is predicted more than once is counted once per occurrence.

        Returns:
            Tuple ``(true_positive, false_positive, false_negative)``.
        """
        results = [(item in test) for item in pred]
        # Extracted and correct.
        true_positive = results.count(True)
        # Extracted but incorrect.
        false_positive = results.count(False)
        # Should have been extracted but was not.
        false_negative = len(test) - true_positive
        return true_positive, false_positive, false_negative

    @staticmethod
    def evaluation(c):
        """Turn tallied counts into precision, recall and F-measure.

        A metric whose denominator is zero is reported as 0.0. This covers the
        case where nothing was predicted, the case where there were no gold
        values, and the case where every count is zero.

        Args:
            c: Mapping with the keys ``true_positive``, ``false_positive``,
                ``false_negative`` and ``support``.

        Returns:
            Tuple ``(precision, recall, f_measure, support)``.
        """
        tp = c['true_positive']
        predicted = tp + c['false_positive']
        actual = tp + c['false_negative']
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        if precision == 0 and recall == 0:
            f_measure = 0.0
        else:
            f_measure = (2 * recall * precision) / (recall + precision)
        return precision, recall, f_measure, c['support']
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
TRAIN_JSON_PATHを変更する。 | |
同じディレクトリにevaluation.pyと本スクリプト(test_evaluation.py)がある状態でpytestを実行 | |
$ ls | |
evaluation.py test_evaluation.py | |
$ pytest | |
""" | |
import json | |
from collections import Counter | |
from evaluation import WikiepediaEvaluation | |
# Location of the gold training JSON; adjust this to your local data directory.
TRAIN_JSON_PATH = '../data/train-20180516T074039Z-001/company_train.json'

# Decode explicitly as UTF-8 so loading the (Japanese) JSON does not depend on
# the platform's default locale encoding.
with open(TRAIN_JSON_PATH, encoding='utf-8') as f:
    data = json.load(f)

# Evaluator instance shared by the tests below.
we = WikiepediaEvaluation(data)
def test_count_tp_fp_fn():
    """Check the per-entry tallies returned by ``count_tp_fp_fn``."""
    # (predicted, gold, (true positive, false positive, false negative))
    cases = [
        (["a"], ["a"], (1, 0, 0)),
        (["a"], ["a", "b"], (1, 0, 1)),
        (["a", "b"], ["a"], (1, 1, 0)),
        (["a", "b"], ["a", "c"], (1, 1, 1)),
        (["a", "b", "c"], ["a", "c"], (2, 1, 0)),
        ([], [], (0, 0, 0)),
        (["a"], [], (0, 1, 0)),
    ]
    for predicted, gold, expected in cases:
        assert we.count_tp_fp_fn(predicted, gold) == expected
def test_evaluation():
    """Check the metrics that ``evaluation`` derives from tallied counts."""
    # (tallied counts, (precision, recall, f_measure, support))
    cases = [
        (
            dict(true_positive=1, false_positive=0, false_negative=0, support=0),
            (1.0, 1.0, 1.0, 0),
        ),
        (
            dict(true_positive=1, false_positive=1, false_negative=0, support=0),
            (0.5, 1.0, 0.6666666666666666, 0),
        ),
        (
            dict(true_positive=1, false_positive=1, false_negative=1, support=0),
            (0.5, 0.5, 0.5, 0),
        ),
    ]
    for counts, expected in cases:
        assert we.evaluation(Counter(**counts)) == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment