Last active April 27, 2016 13:03
-
-
Save jstypka/a14aee497b527f8c41c2ba67f166d345 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def count_ones(n):
    """Return the number of one bits in the binary representation of ``n``.

    Elsewhere in this file a non-negative int serves as a bit set over label
    indices, so for such an int this is the number of labels in the set
    (e.g. how many keywords are relevant to a sample).

    Args:
        n: a non-negative integer.

    Returns:
        The count of set bits in ``n``.

    Raises:
        ValueError: if ``n`` is negative. Python ints have no fixed width, so
            a negative number has infinitely many one bits; the previous
            ``n &= n - 1`` loop never terminated on such input.
    """
    if n < 0:
        raise ValueError("count_ones() requires a non-negative integer, got %r" % (n,))
    # bin() formats the whole number in one C-level pass, instead of one
    # Python-level big-int subtraction and AND per set bit.
    return bin(n).count("1")
def compute_descendant_metric(y_true, y_pred, labels, ontology):
    """Return the area under a descendant-expanded precision/recall curve.

    Every label is expanded to the descendants that ``ontology`` reports for
    it (restricted to ``labels``), both in the ground truth and in each
    accepted prediction. Predictions from all samples are accepted one at a
    time, highest score first, with ties broken by sample index and then by
    column index. After each acceptance the precision and recall of the
    affected sample are recomputed, and their means over all samples form
    one point of the curve.

    Args:
        y_true: one row per sample; ``y_true[i][k]`` is truthy when
            ``labels[k]`` is relevant to sample ``i``.
        y_pred: one row of scores per sample; column ``k`` scores
            ``labels[k]``. Presumably the same number of rows as ``y_true``
            -- TODO confirm.
        labels: sequence of label identifiers; fixes the column order.
        ontology: object providing
            ``get_descendants_of_label(label, filtered_by=set_of_labels)``,
            whose result has ``.keys()``; every key must be one of ``labels``
            (otherwise ``KeyError``).

    Returns:
        ``auc(recall_means, precision_means)`` where the curve starts at
        (recall 0, precision 1) and gains one point per accepted prediction.
    """
    from itertools import compress

    label_mapping = {lab: idx for idx, lab in enumerate(labels)}  # label -> column index
    label_set = set(labels)

    # desc_masks[c] has bit k set when labels[k] is among the descendants the
    # ontology reports for labels[c]. Computed once per label instead of once
    # per (sample, label) occurrence; this assumes get_descendants_of_label
    # returns the same answer every time -- TODO confirm.
    desc_masks = []
    for lab in labels:
        mask = 0
        for desc in ontology.get_descendants_of_label(lab, filtered_by=label_set).keys():
            mask |= 1 << label_mapping[desc]
        desc_masks.append(mask)

    # Each sample's expanded ground truth as a bit set over label indices:
    # bit p set means labels[p] is relevant to that sample.
    true_sets = []
    for flags in y_true:
        bits = 0
        for mask in compress(desc_masks, flags):
            bits |= mask
        true_sets.append(bits)
    true_sizes = [count_ones(bits) for bits in true_sets]

    # Every (sample, label) prediction, most confident first. Sorting these
    # tuples yields exactly the order the former heap popped them in.
    order = sorted(
        (-score, row, col)
        for row, scores in enumerate(y_pred)
        for col, score in enumerate(scores)
    )

    n_samples = len(true_sets)
    pred_sets = [0] * len(y_pred)  # accepted predictions, same bit-set encoding
    pred_sizes = [0] * len(y_pred)
    # Nothing is predicted yet: precision 1 and recall 0 for every sample.
    # NOTE(review): the update below assigns recall 1 to samples with no
    # relevant labels, so such a sample jumps from 0 to 1 on its first
    # prediction -- confirm this starting point is intended.
    precision = [1.0] * n_samples
    recall = [0.0] * n_samples
    # Running sums replace the former per-step np.mean over all samples,
    # which made the sweep quadratic. Results may differ from np.mean in the
    # last digits because of floating-point rounding.
    precision_sum = float(n_samples)
    recall_sum = 0.0
    precision_means, recall_means = [1], [0]
    for _, row, col in order:
        new_bits = desc_masks[col] & ~pred_sets[row]  # descendants not yet predicted
        if new_bits:
            pred_sets[row] |= new_bits
            pred_sizes[row] += count_ones(new_bits)
        hits = count_ones(true_sets[row] & pred_sets[row])
        pred_size = pred_sizes[row]
        true_size = true_sizes[row]
        # float() forces true division on Python 2 too; there a plain int / int
        # (without `from __future__ import division`) truncates to 0 or 1.
        new_precision = float(hits) / pred_size if pred_size else 1.0
        new_recall = float(hits) / true_size if true_size else 1.0
        precision_sum += new_precision - precision[row]
        recall_sum += new_recall - recall[row]
        precision[row] = new_precision
        recall[row] = new_recall
        precision_means.append(precision_sum / n_samples)
        recall_means.append(recall_sum / n_samples)
    return auc(recall_means, precision_means)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment