Created April 23, 2010 18:19
Save jelsas/376926 to your computer and use it in GitHub Desktop.
Formats a submission file suitable for upload to the Yahoo LETOR Challenge
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
#!/usr/bin/python
# This script takes the predictions and input vectors
# (eg. set1.test.txt) and produces a file suitable for
# submission to the Yahoo LETOR Challenge web interface.
#
# The PREDICTIONS_FILE should just be a list of scores,
# one per line, corresponding to the lines in INPUT_FILE.
# Only the first whitespace-separated token of each
# PREDICTIONS_FILE line is read as the score; the query id
# is the second whitespace-separated token of each
# INPUT_FILE line.
#
# Output: one line per query, listing the 1-based predicted
# rank of each of that query's documents (highest score = 1).
try:
    from itertools import izip  # Python 2
except ImportError:
    izip = zip  # Python 3: the builtin zip is already lazy
from itertools import groupby
from optparse import OptionParser

parser = OptionParser(
    usage='usage: %prog [options] PREDICTIONS_FILE INPUT_FILE')
(options, args) = parser.parse_args()
if len(args) != 2:
    # parser.error prints the usage line to stderr and exits (status 2)
    parser.error('Must specify PREDICTIONS_FILE and INPUT_FILE')
preds_file, input_file = args
def _strict_pairs(qid_iter, score_iter):
    """Yield (qid, score) pairs, advancing both iterators in lockstep.

    izip stops at the shorter input after having already pulled one element
    from the other, so a PREDICTIONS_FILE that is exactly one line short
    would silently drop a qid. Here a length mismatch is reported via
    parser.error as soon as either input runs out before the other.
    """
    missing = object()
    while True:
        qid = next(qid_iter, missing)
        score = next(score_iter, missing)
        if qid is missing and score is missing:
            return
        if qid is missing or score is missing:
            parser.error('scores length != qids length')
        yield qid, score


def _predicted_ranks(score_list):
    """Return the 1-based rank of each score; the highest score gets rank 1.

    Equal scores keep their input order: sorted() is stable, including
    with reverse=True.
    """
    order = sorted(range(len(score_list)),
                   key=score_list.__getitem__,
                   reverse=True)
    # invert the permutation in O(n): order[position] is the index of the
    # document placed at that position, so that document's rank is position
    ranks = [0] * len(score_list)
    for position, index in enumerate(order, 1):
        ranks[index] = position
    return ranks


with open(preds_file) as preds_fh:
    with open(input_file) as input_fh:
        # iterator for the scores on each line of the preds_file
        scores = (float(line.split(None, 1)[0]) for line in preds_fh)
        # iterator for the qids on each line of the input_file
        qids = (line.split(None, 2)[1] for line in input_fh)
        # pair line-by-line & group by qid. groupby only merges *adjacent*
        # lines with the same qid, so this assumes INPUT_FILE keeps each
        # query's lines contiguous -- NOTE(review): confirm for the data.
        data = groupby(_strict_pairs(qids, scores), lambda pair: pair[0])
        # go through each query, printing one line of ranks per query
        for q, q_pairs in data:
            ranks = _predicted_ranks([score for _, score in q_pairs])
            print(' '.join(str(r) for r in ranks))
# make sure we got to the end of the scores & qids iterators:
# once every query has been printed, both should be exhausted.
def check_iter_at_end(iter):
    """Report an input-length mismatch if *iter* still yields an element.

    Args:
        iter: the scores or qids iterator. The name shadows the builtin
            ``iter``; it is kept so existing keyword callers still work.

    Raises:
        SystemExit: via ``parser.error`` when an element is left over.

    Note: this can only see elements that the pairing step left
    unconsumed.
    """
    leftover = object()
    # next(it, default) works on Python 2.6+ and 3, unlike it.next()
    if next(iter, leftover) is not leftover:
        parser.error('scores length != qids length')
# Both inputs must be fully consumed; scores is checked before qids.
for remaining_iter in (scores, qids):
    check_iter_at_end(remaining_iter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.