Last active
April 26, 2016 08:26
-
-
Save jstypka/abbf5df1fe5048ce8b92bb2fecbd1977 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import compress | |
import numpy as np | |
from sklearn.metrics import auc | |
def compute_precision_recall_auc(y_true, y_pred, precision_recall_fun):
    """
    Compute the area under a sample-averaged precision-recall curve.

    Every distinct score found in ``y_pred``, plus a sentinel of -1, is used
    as a decision threshold. For each threshold the scores are binarised with
    ``y_pred > threshold``, ``precision_recall_fun`` is evaluated on every
    sample, and the per-sample precisions and recalls are averaged into one
    point of the curve.

    :param y_true: ground truth, indexable by sample; ``y_true[row]`` is
        passed unchanged to ``precision_recall_fun``
    :param y_pred: predicted scores, one entry (or row) per sample; converted
        to a numpy array before thresholding.
        NOTE(review): the -1 sentinel presumably assumes non-negative
        scores - confirm with callers
    :param precision_recall_fun: callable taking ``(y_true_row,
        y_pred_bin_row)`` and returning a ``(precision, recall)`` pair
    :return: the value of ``auc(recall_means, precision_means)``
    """
    # Lists do not support element-wise ``>``; an ndarray does.
    y_pred = np.asarray(y_pred)
    n_samples = len(y_true)

    thresholds = np.append(np.unique(y_pred), [-1])
    thresholds.sort()

    # Scratch buffers, overwritten for every threshold.
    precision = np.zeros(n_samples)
    recall = np.zeros(n_samples)

    precision_means = np.zeros(len(thresholds))
    recall_means = np.zeros(len(thresholds))

    # Thresholds are visited from the highest to the lowest.
    for i, t in enumerate(reversed(thresholds)):
        y_pred_bin = y_pred > t
        # ``range`` rather than ``xrange`` so this runs on Python 2 and 3.
        for row in range(n_samples):
            precision[row], recall[row] = precision_recall_fun(
                y_true[row], y_pred_bin[row])
        precision_means[i] = np.mean(precision)
        recall_means[i] = np.mean(recall)

    return auc(recall_means, precision_means)
def descendant_precision_recall(y_t, y_p):
    """
    Compute precision and recall for one sample over expanded label sets.

    The labels selected by ``y_p`` and by ``y_t`` are each expanded with
    ``ontology.get_descendants_of_labels`` and the two resulting key sets are
    compared.

    Relies on the module-level ``labels`` and ``ontology`` objects.
    NOTE(review): ``labels`` is presumably ordered like the entries of
    ``y_t`` / ``y_p`` - confirm against the caller.

    :param y_t: truthy/falsy selectors over ``labels`` marking true labels
    :param y_p: truthy/falsy selectors over ``labels`` marking predicted
        labels
    :return: tuple ``(precision, recall)``; a value is 1.0 when the set it
        would be divided by is empty
    """
    masked_labels = set(compress(labels, y_p))
    pred_labels = set(
        ontology.get_descendants_of_labels(masked_labels).keys())

    masked_labels = set(compress(labels, y_t))
    true_labels = set(
        ontology.get_descendants_of_labels(masked_labels).keys())

    # float() forces true division: on Python 2, int / int truncates and
    # would collapse every score to 0 or 1.
    n_common = float(len(true_labels & pred_labels))

    precision = n_common / len(pred_labels) if pred_labels else 1.0
    recall = n_common / len(true_labels) if true_labels else 1.0
    return precision, recall
class Ontology(object):
    # ...
    # some other stuff here
    # ...

    def get_descendants_of_labels(self, starting_labels):
        """
        Walk the graph downwards (breadth-first) from the given labels.

        Only edges whose ``relation`` attribute is ``SKOS.narrower`` or
        ``SKOS.composite`` are followed. The canonical labels of the starting
        nodes themselves are included in the result at distance 0.

        :param starting_labels: iterable of labels to start from; each one is
            passed through ``parse_label`` and then ``get_uri_from_label``
        :return: dict mapping the canonical label of every reached node to
            the smallest number of edges separating it from a starting node,
            e.g. {'start': 0, 'node1': 1, 'node2': 2}
        :raises ValueError: if a starting label resolves to no URI
        """
        relations = {SKOS.narrower, SKOS.composite}

        # Materialise once: the labels are traversed twice below, and a
        # one-shot iterator would otherwise silently skip the validation.
        starting_labels = list(starting_labels)

        parsed = [self.parse_label(lab) for lab in starting_labels]
        uris = [self.get_uri_from_label(lab) for lab in parsed]

        for lab, uri in zip(starting_labels, uris):
            if not uri:
                raise ValueError('Label ' + lab + ' not in the ontology graph')

        distances = {}
        queue = deque((uri, 0) for uri in uris)
        while queue:
            node, distance = queue.popleft()
            node_label = self.get_canonical_label_from_uri(node)
            if node_label in distances:
                # The FIFO queue yields entries in non-decreasing distance
                # order, so the value already stored is the minimum.
                continue
            distances[node_label] = distance

            # NOTE(review): assumes ``self.graph`` is a NetworkX directed
            # graph. ``out_edges`` exists in NetworkX 1.x and 2.x, whereas
            # ``out_edges_iter`` was removed in 2.0.
            for _, child, attrs in self.graph.out_edges(nbunch=[node],
                                                        data=True):
                if attrs.get('relation') in relations:
                    queue.append((child, distance + 1))

        return distances
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment