Created
October 29, 2020 14:36
-
-
Save yzhangcs/48661e5dcc06a54cc7d1f8a198913c89 to your computer and use it in GitHub Desktop.
Script for evaluating a CoNLL file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import argparse | |
from collections import Counter | |
def factorize(tags):
    """Convert a sequence of chunk tags into labelled spans.

    A tag starting with ``B`` or ``S`` opens a new span, a tag starting with
    ``O`` lies outside any span, and every other tag continues the span that
    is currently open.

    Ill-formed sequences are tolerated rather than crashing or producing
    spans with wrong boundaries: a continuation tag with no open span (at the
    start of the sequence or right after an ``O``) opens a new span, and a
    tag without a ``-`` separator yields an empty label.

    Args:
        tags: Iterable of tag strings such as ``'B-PER'``, ``'I-PER'``, ``'O'``.

    Returns:
        A list of ``(start, end, label)`` tuples where ``start`` and ``end``
        are inclusive token indices and ``label`` is the text after the first
        ``-`` in the tag that opened the span.
    """
    spans = []
    # True while the most recent span can still be extended.
    is_open = False
    for i, tag in enumerate(tags):
        if tag.startswith('O'):
            is_open = False
        elif tag.startswith(('B', 'S')) or not is_open:
            # partition never raises, unlike split('-', 1)[1] on a bare tag.
            spans.append([i, i, tag.partition('-')[2]])
            is_open = True
        else:
            spans[-1][1] += 1
    return [tuple(span) for span in spans]
class Metric:
    """Base class for metrics that are ordered by their scalar ``score``.

    The rich comparison operators compare ``self.score`` with the right-hand
    operand, so a metric can be compared directly against a number.
    """

    @property
    def score(self):
        """Scalar value used for comparisons; the base class returns 0."""
        return 0.0

    def __lt__(self, other):
        mine = self.score
        return mine < other

    def __le__(self, other):
        mine = self.score
        return mine <= other

    def __ge__(self, other):
        mine = self.score
        return mine >= other

    def __gt__(self, other):
        mine = self.score
        return mine > other
class SpanMetric(Metric):
    """Accumulate labelled and unlabelled span precision, recall and F-score.

    Calling the instance with parallel sequences of predicted and gold span
    lists updates the running counts; the properties derive the ratios from
    those counts. ``eps`` is added to every denominator so that empty counts
    do not divide by zero.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        # tp/utp: labelled / unlabelled matches; pred/gold: total span counts.
        self.tp = self.utp = self.pred = self.gold = 0.0
        self.eps = eps

    def __call__(self, preds, golds):
        """Add the counts of each (predicted, gold) sentence pair."""
        for predicted, reference in zip(preds, golds):
            # Multiset intersection counts each matching span at most as
            # often as it occurs on both sides.
            labelled_hits = Counter(predicted) & Counter(reference)
            bare_hits = (Counter((i, j) for i, j, _ in predicted)
                         & Counter((i, j) for i, j, _ in reference))
            self.tp += sum(labelled_hits.values())
            self.utp += sum(bare_hits.values())
            self.pred += len(predicted)
            self.gold += len(reference)

    def __repr__(self):
        return (f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} "
                f"P: {self.p:6.2%} R: {self.r:6.2%} F: {self.f:6.2%}")

    @property
    def score(self):
        """The labelled F-score is the value used for comparisons."""
        return self.f

    @property
    def up(self):
        """Unlabelled precision."""
        denominator = self.pred + self.eps
        return self.utp / denominator

    @property
    def ur(self):
        """Unlabelled recall."""
        denominator = self.gold + self.eps
        return self.utp / denominator

    @property
    def uf(self):
        """Unlabelled F-score."""
        denominator = self.pred + self.gold + self.eps
        return 2 * self.utp / denominator

    @property
    def p(self):
        """Labelled precision."""
        denominator = self.pred + self.eps
        return self.tp / denominator

    @property
    def r(self):
        """Labelled recall."""
        denominator = self.gold + self.eps
        return self.tp / denominator

    @property
    def f(self):
        """Labelled F-score."""
        denominator = self.pred + self.gold + self.eps
        return 2 * self.tp / denominator
def read(path):
    """Read a CoNLL-style file and return the spans of every sentence.

    Sentences are separated by blank lines, and the tag of each token is the
    last whitespace-separated column of its line. Every blank line closes a
    sentence (possibly an empty one). The final sentence is also returned
    when the file does not end with a blank line.

    Args:
        path: Path of the file to read.

    Returns:
        A list with one entry per sentence, each being the list of spans
        produced by ``factorize`` for that sentence's tags.
    """
    sentences, tags = [], []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                tags.append(line.split()[-1])
            else:
                sentences.append(factorize(tags))
                tags = []
    # Flush the last sentence if the file lacks a trailing blank line;
    # otherwise it would be silently dropped from the evaluation.
    if tags:
        sentences.append(factorize(tags))
    return sentences
def evaluate(fpred, fgold):
    """Score a predicted file against a gold file and print the metrics.

    Args:
        fpred: Path to the file holding the predicted tags.
        fgold: Path to the file holding the gold tags.
    """
    predicted = read(fpred)
    reference = read(fgold)
    scorer = SpanMetric()
    scorer(predicted, reference)
    print(scorer)
if __name__ == "__main__":
    # Command-line entry point: compare a predicted file with a gold file.
    parser = argparse.ArgumentParser(
        description='Output some metrics w.r.t. the predicted results.'
    )
    # Both paths are mandatory; without them evaluate() would try to open
    # None and fail with an unhelpful TypeError instead of a usage message.
    parser.add_argument('--fpred', '-s', required=True,
                        help='path to predicted result')
    parser.add_argument('--fgold', '-g', required=True,
                        help='path to gold dataset')
    args = parser.parse_args()
    evaluate(args.fpred, args.fgold)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment