Skip to content

Instantly share code, notes, and snippets.

@yzhangcs
Last active November 30, 2019 10:33
Show Gist options
  • Save yzhangcs/de4fccb70ebf69719f06abb2b4511d35 to your computer and use it in GitHub Desktop.
Save yzhangcs/de4fccb70ebf69719f06abb2b4511d35 to your computer and use it in GitHub Desktop.
Script for computing dependency-parsing metrics (SIB, GRD, UCM, LCM, UAS, LAS) of predicted results against gold data.
# -*- coding: utf-8 -*-
import argparse
import os
import unicodedata
import torch
def ispunct(token):
    """Return True if every character of ``token`` is Unicode punctuation.

    A character counts as punctuation when its Unicode general category
    starts with 'P'. An empty token yields True (vacuously).
    """
    for char in token:
        if not unicodedata.category(char).startswith('P'):
            return False
    return True
def isprojective(sequence):
    """Return True if the tree described by ``sequence`` is projective.

    ``sequence[k]`` is the head of token ``k + 1`` (0 is the root). Tokens
    with a negative head are ignored. The tree is rejected if any pair of
    arcs either (a) has one arc attaching the other's head while its own
    head lies inside the other's span, or (b) has crossing spans.
    """
    arcs = [(head, dep) for dep, head in enumerate(sequence, 1) if head >= 0]
    for idx, (h1, d1) in enumerate(arcs):
        left1, right1 = min(h1, d1), max(h1, d1)
        for h2, d2 in arcs[idx + 1:]:
            left2, right2 = min(h2, d2), max(h2, d2)
            # (a) one arc's dependent is the other's head, and its own head
            #     is covered by the other arc's span
            if (left1 <= h2 <= right1 and h1 == d2) or \
                    (left2 <= h1 <= right2 and h2 == d1):
                return False
            # (b) the spans partially overlap, i.e. the arcs cross
            overlaps = left1 < left2 < right1 or left1 < right2 < right1
            if overlaps and (left1 - left2) * (right1 - right2) > 0:
                return False
    return True
def numericalize_arcs(sequence):
    """Convert a sequence of head indices (e.g. strings) to a list of ints."""
    return list(map(int, sequence))
def numericalize_sibs(sequence):
    """Return, for every token, the 1-based index of its inner sibling or -1.

    ``sequence[k]`` is the head of token ``k + 1`` (int or int-like string).
    Two tokens are adjacent siblings when they share a head, lie on the same
    side of it, and no other such dependent lies between them. For a pair of
    left dependents the farther token records the nearer one; for a pair of
    right dependents the farther token records the nearer one too (so the
    entry always points towards the head). Tokens with a negative head, or
    whose head is themselves, have no siblings.

    Single pass: for each (head, side) we remember the most recent dependent,
    which is exactly the "first later token with the same head and side"
    that a pairwise scan would find.
    """
    heads = [int(h) for h in sequence]
    sibs = [-1] * len(heads)
    last_seen = {}  # (head, is_left_dependent) -> latest 1-based position
    for pos, head in enumerate(heads, 1):
        if head < 0 or head == pos:
            continue
        is_left = head > pos
        key = (head, is_left)
        prev = last_seen.get(key)
        if prev is not None:
            if is_left:
                # prev < pos < head: prev is farther from the head
                sibs[prev - 1] = pos
            else:
                # head < prev < pos: pos is farther from the head
                sibs[pos - 1] = prev
        last_seen[key] = pos
    return sibs
def numericalize_grds(sequence):
    """Return the grandparent (head of the head) of every token.

    ``sequence[k]`` is the head of token ``k + 1``. Dependents of the root
    (head 0) get the sentinel ``len(sequence) + 1``; tokens with a negative
    head get -1.
    """
    heads = [int(h) for h in sequence]
    # position 0 stands for the root, whose "head" is the sentinel value
    lookup = [len(heads) + 1] + heads
    return [lookup[h] if h >= 0 else -1 for h in heads]
def read(path):
    """Read a CoNLL-style file into per-sentence words, heads and relations.

    Sentences are separated by one or more blank lines. The last sentence is
    kept even when the file does not end with a blank line. Lines starting
    with '#' (CoNLL-U comments) are skipped. Every token line is split on
    whitespace. Column 1 is the word form, column 6 the head index and
    column 7 the dependency relation (0-based column numbers).

    Args:
        path: path to the CoNLL file (read as UTF-8).

    Returns:
        ``(words, arcs, rels)``: three lists with one entry per sentence.
        ``words[k]`` and ``rels[k]`` are lists of strings; ``arcs[k]`` is a
        list of int heads.
    """
    words, arcs, rels = [], [], []

    def flush(rows):
        # Append one sentence; empty runs (repeated blank lines) are ignored.
        if not rows:
            return
        columns = list(zip(*rows))
        words.append(list(columns[1]))
        arcs.append([int(head) for head in columns[6]])
        rels.append(list(columns[7]))

    sentence = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                flush(sentence)
                sentence = []
            elif not line.startswith('#'):
                sentence.append(line.split())
    # the final sentence need not be followed by a blank line
    flush(sentence)
    return words, arcs, rels
def evaluate(fgold, fpred, evalb=False, punct=False, proj=False):
    """Score predicted dependency trees against gold trees and print metrics.

    Both files are parsed with ``read`` and sentences are paired by position.
    Tokens whose gold head is negative are never scored. Punctuation tokens
    (see ``ispunct``) are excluded unless ``punct`` is set. Sentences left
    with no scorable token are skipped.

    Printed metrics:
        SIB  precision / recall / F1 over (token, head, sibling) triples of
             tokens that have a sibling,
        GRD  precision over (token, head, grandparent) triples; gold and
             predicted sets both hold one triple per scored token, so
             recall equals precision and is not printed,
        UCM / LCM  sentences whose scored tokens all have the correct head /
             correct head and relation,
        UAS / LAS  token-level unlabeled / labeled attachment scores.
    A ratio whose denominator is 0 is reported as 0.00%.

    Args:
        fgold: path to the gold CoNLL file.
        fpred: path to the predicted CoNLL file.
        evalb: also write per-sentence scores to
            ``os.path.splitext(fpred)[0] + '.evalb'``. Both score columns
            there hold the LAS.
        punct: also score punctuation tokens.
        proj: skip sentences whose gold tree is not projective.

    Raises:
        ValueError: if the files disagree on the number of sentences, or on
            the number of tokens in a sentence.
    """
    def ratio(num, den):
        # e.g. no sibling arcs or every sentence filtered out
        return num / den if den else 0.0

    words, gold_arcs, gold_rels = read(fgold)
    _, pred_arcs, pred_rels = read(fpred)
    if len(gold_arcs) != len(pred_arcs):
        raise ValueError(f"{fgold} has {len(gold_arcs)} sentences but "
                         f"{fpred} has {len(pred_arcs)}")

    sib_tp, sib_golds, sib_preds = 0, 0, 0
    grd_tp, grd_golds, grd_preds = 0, 0, 0
    n, n_ucm, n_lcm, c_arcs, c_rels, total = 0, 0, 0, 0, 0, 0
    evallines = []
    sentences = zip(words, gold_arcs, pred_arcs, gold_rels, pred_rels)
    for sent_id, (w_seq, g_arc, p_arc, g_rel, p_rel) in enumerate(sentences):
        if len(g_arc) != len(p_arc):
            raise ValueError(f"sentence {sent_id + 1}: gold has {len(g_arc)} "
                             f"tokens but prediction has {len(p_arc)}")
        # read() already yields int heads
        if proj and not isprojective(g_arc):
            continue
        # tokens that count towards the scores
        mask = torch.tensor([g >= 0 for g in g_arc])
        if not punct:
            mask &= torch.tensor([not ispunct(w) for w in w_seq])
        if not mask.any():
            continue
        arc_mask = torch.tensor([g == p for g, p in zip(g_arc, p_arc)]) & mask
        rel_mask = torch.tensor([g == p for g, p in zip(g_rel, p_rel)])
        rel_mask = rel_mask & arc_mask
        c_arc = arc_mask.sum().item()
        c_rel = rel_mask.sum().item()
        c_total = mask.sum().item()
        mask = mask.tolist()

        g_sib, p_sib = numericalize_sibs(g_arc), numericalize_sibs(p_arc)
        g_grd, p_grd = numericalize_grds(g_arc), numericalize_grds(p_arc)
        # only tokens that actually have a sibling (s > 0) form SIB triples
        g_arc_sib = {(i, a, s) for i, (a, s) in enumerate(zip(g_arc, g_sib))
                     if s > 0 and mask[i]}
        p_arc_sib = {(i, a, s) for i, (a, s) in enumerate(zip(p_arc, p_sib))
                     if s > 0 and mask[i]}
        sib_tp += len(g_arc_sib & p_arc_sib)
        sib_golds += len(g_arc_sib)
        sib_preds += len(p_arc_sib)
        g_arc_grd = {(i, a, s) for i, (a, s) in enumerate(zip(g_arc, g_grd))
                     if mask[i]}
        p_arc_grd = {(i, a, s) for i, (a, s) in enumerate(zip(p_arc, p_grd))
                     if mask[i]}
        grd_tp += len(g_arc_grd & p_arc_grd)
        grd_golds += len(g_arc_grd)
        grd_preds += len(p_arc_grd)

        n += 1
        n_ucm += c_arc == c_total
        n_lcm += c_rel == c_total
        c_arcs += c_arc
        c_rels += c_rel
        total += c_total
        # Sent.          Attachment      Correct        Scoring
        # ID Tokens  -   Unlab. Lab.   HEAD HEAD+DEPREL   tokens   - - - -
        # only considers the LAS here (c_total > 0 is guaranteed above)
        evallines.append(f" {sent_id+1:4d} {len(mask):4d} 0"
                         f" {c_rel/c_total*100:6.2f} {c_rel/c_total*100:6.2f}"
                         f" {c_arc:4d} {c_rel:4d} "
                         f" {c_total:4d} 0 0 0 0\n")

    print(f"SIB:\n"
          f" P: {sib_tp:5} / {sib_preds:5} = {ratio(sib_tp, sib_preds):6.2%}\n"
          f" R: {sib_tp:5} / {sib_golds:5} = {ratio(sib_tp, sib_golds):6.2%}\n"
          f" F: 2*P*R / (P+R) = "
          f"{ratio(2 * sib_tp, sib_preds + sib_golds):6.2%}\n"
          f"GRD:\n"
          f" P: {grd_tp:5} / {grd_preds:5} = {ratio(grd_tp, grd_preds):6.2%}\n"
          f"UCM: {n_ucm:5} / {n:5} = {ratio(n_ucm, n):6.2%}\n"
          f"LCM: {n_lcm:5} / {n:5} = {ratio(n_lcm, n):6.2%}\n"
          f"UAS: {c_arcs:5} / {total:5} = {ratio(c_arcs, total):6.2%}\n"
          f"LAS: {c_rels:5} / {total:5} = {ratio(c_rels, total):6.2%}\n")
    if evalb:
        evalb_path = os.path.splitext(fpred)[0] + '.evalb'
        print(evalb_path)
        with open(evalb_path, 'w') as f:
            f.writelines(evallines)
if __name__ == "__main__":
    # Command-line entry point: parse the file paths and flags, then score.
    parser = argparse.ArgumentParser(
        description='Output some metrics w.r.t. the predicted results.'
    )
    # both files are mandatory; evaluate() cannot proceed without them
    parser.add_argument('--fgold', '-g', required=True,
                        help='path to gold dataset')
    parser.add_argument('--fpred', '-s', required=True,
                        help='path to predicted result')
    parser.add_argument('--evalb', '-b', action='store_true',
                        help='produce output in a format similar to evalb')
    parser.add_argument('--punct', '-p', action='store_true',
                        help='also score on punctuation')
    parser.add_argument('--proj', action='store_true',
                        help='only score sentences whose gold tree is '
                             'projective')
    args = parser.parse_args()
    evaluate(args.fgold, args.fpred, args.evalb, args.punct, args.proj)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment