Skip to content

Instantly share code, notes, and snippets.

@yzhangcs
Created October 29, 2020 14:36
Show Gist options
  • Save yzhangcs/48661e5dcc06a54cc7d1f8a198913c89 to your computer and use it in GitHub Desktop.
Save yzhangcs/48661e5dcc06a54cc7d1f8a198913c89 to your computer and use it in GitHub Desktop.
Script for evaluating a CoNLL file
# -*- coding: utf-8 -*-
import argparse
from collections import Counter
def factorize(tags):
    """Convert a per-token tag sequence into a list of labelled spans.

    A tag starting with ``B`` or ``S`` opens a new span, a tag starting with
    ``O`` closes the current span, and any other tag continues the span that
    is currently open.  The label is whatever follows the first ``-`` in the
    opening tag (the empty string when the tag has no ``-``).

    Malformed sequences are tolerated rather than crashing or producing
    misplaced spans: a continuation tag that appears while no span is open
    (at the very beginning, or right after an ``O``) starts a new span of its
    own.

    Args:
        tags: iterable of tag strings, one per token.

    Returns:
        A list of ``(start, end, label)`` tuples with inclusive token indices.
    """
    spans = []
    inside = False  # True while the last span in `spans` can still be extended
    for i, tag in enumerate(tags):
        if tag.startswith('O'):
            inside = False
            continue
        if tag.startswith(('B', 'S')) or not inside:
            # partition() yields '' for label-less tags such as a bare 'B',
            # where split('-', 1)[1] would raise IndexError.
            spans.append([i, i, tag.partition('-')[2]])
            inside = True
        else:
            spans[-1][1] = i
    return [tuple(span) for span in spans]
class Metric:
    """Base class for metrics that can be ordered by their ``score``.

    Each rich comparison forwards to the same comparison between
    ``self.score`` and ``other``.  Subclasses override ``score``.
    """

    @property
    def score(self):
        """Scalar used for comparisons; the base implementation returns 0."""
        return 0.0

    def __lt__(self, other):
        return self.score < other

    def __le__(self, other):
        return self.score <= other

    def __gt__(self, other):
        return self.score > other

    def __ge__(self, other):
        return self.score >= other
class SpanMetric(Metric):
    """Accumulates labelled and unlabelled span precision / recall / F-score.

    Call the instance with parallel sequences of predicted and gold span
    lists; counts accumulate across calls.  ``eps`` is added to every
    denominator so that the ratios are defined when nothing has been counted.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.tp = 0.0    # spans matching in (start, end, label)
        self.utp = 0.0   # spans matching in (start, end) only
        self.pred = 0.0  # total number of predicted spans
        self.gold = 0.0  # total number of gold spans
        self.eps = eps

    def __call__(self, preds, golds):
        """Add the counts of a batch of sentences to the running totals.

        Args:
            preds: iterable with one entry per sentence, each a sequence of
                ``(start, end, label)`` tuples for the predicted spans.
            golds: same structure for the gold spans, aligned with ``preds``.

        Raises:
            ValueError: if ``preds`` and ``golds`` do not contain the same
                number of sentences.  Pairing them anyway would silently drop
                the surplus and report misleading scores.
        """
        preds, golds = list(preds), list(golds)
        # Validate before touching any counter so a failure leaves the
        # accumulated state unchanged.
        if len(preds) != len(golds):
            raise ValueError(
                f"number of predicted sentences ({len(preds)}) does not match "
                f"number of gold sentences ({len(golds)})"
            )
        for pred, gold in zip(preds, golds):
            lpred = Counter(pred)
            lgold = Counter(gold)
            upred = Counter([(i, j) for i, j, label in pred])
            ugold = Counter([(i, j) for i, j, label in gold])
            # Multiset intersection: a span counts as matched at most as many
            # times as it occurs on both sides.
            self.tp += sum((lpred & lgold).values())
            self.utp += sum((upred & ugold).values())
            self.pred += len(pred)
            self.gold += len(gold)

    def __repr__(self):
        return f"UP: {self.up:6.2%} UR: {self.ur:6.2%} UF: {self.uf:6.2%} P: {self.p:6.2%} R: {self.r:6.2%} F: {self.f:6.2%}"

    @property
    def score(self):
        """Labelled F-score; this is the value used by the comparisons."""
        return self.f

    @property
    def up(self):
        """Unlabelled precision."""
        return self.utp / (self.pred + self.eps)

    @property
    def ur(self):
        """Unlabelled recall."""
        return self.utp / (self.gold + self.eps)

    @property
    def uf(self):
        """Unlabelled F-score."""
        return 2 * self.utp / (self.pred + self.gold + self.eps)

    @property
    def p(self):
        """Labelled precision."""
        return self.tp / (self.pred + self.eps)

    @property
    def r(self):
        """Labelled recall."""
        return self.tp / (self.gold + self.eps)

    @property
    def f(self):
        """Labelled F-score."""
        return 2 * self.tp / (self.pred + self.gold + self.eps)
def read(path):
    """Read a column-formatted file and return the spans of each sentence.

    Sentences are separated by blank lines.  The last whitespace-separated
    column of every non-blank line is taken as the token's tag, and each
    sentence's tag sequence is passed to ``factorize``.

    A final sentence that is not followed by a blank line is still returned,
    and runs of consecutive blank lines do not produce empty sentences.

    Args:
        path: path of the file to read (decoded as UTF-8).

    Returns:
        A list with one ``factorize`` result per sentence, in file order.
    """
    sentences, tags = [], []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                tags.append(line.split()[-1])
            elif tags:
                # A blank line ends the sentence collected so far.
                sentences.append(factorize(tags))
                tags = []
    # Flush the trailing sentence when the file lacks a final blank line.
    if tags:
        sentences.append(factorize(tags))
    return sentences
def evaluate(fpred, fgold):
    """Score the predictions in ``fpred`` against ``fgold`` and print the result.

    Both files are loaded with ``read``; the resulting span lists are fed to a
    fresh ``SpanMetric``, which is then printed.
    """
    scorer = SpanMetric()
    # Arguments are evaluated left to right: predictions first, then gold.
    scorer(read(fpred), read(fgold))
    print(scorer)
if __name__ == "__main__":
    # Command-line entry point: parse the two file paths and print the metrics.
    parser = argparse.ArgumentParser(
        description='Output some metrics w.r.t. the predicted results.'
    )
    # Both paths are mandatory; without `required=True` a missing option would
    # reach open() as None and fail with an opaque TypeError.
    parser.add_argument('--fpred', '-s', required=True,
                        help='path to predicted result')
    parser.add_argument('--fgold', '-g', required=True,
                        help='path to gold dataset')
    args = parser.parse_args()
    evaluate(args.fpred, args.fgold)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment