# Sklearn Yahoo LTRC 2010 Benchmark script
# Gist pprett/2027241, created March 13, 2012 06:29.
# NOTE: the original page warned that this file may contain hidden or
# bidirectional Unicode text that could be interpreted or compiled differently
# than it appears; review it in an editor that reveals hidden Unicode characters.
from itertools import groupby
from time import time

import numpy as np
import svmlight_loader
from sklearn.ensemble import GradientBoostingRegressor
ROOT_DIR = '/home/pprett/corpora/yahoo-ltrc-2010/data'


def _load_dense(path, n_features=700):
    """Load an SVMlight-format file as a dense, column-major float32 matrix.

    Parameters
    ----------
    path : str
        Path to the SVMlight/LETOR text file.
    n_features : int
        Number of feature columns of the resulting dense matrix.

    Returns
    -------
    X : numpy.ndarray, shape (n_samples, n_features), Fortran (column-major) order
    y : target values as returned by ``svmlight_loader.load_svmlight_file``
    """
    X, y = svmlight_loader.load_svmlight_file(path,
                                              n_features=n_features,
                                              dtype=np.float32,
                                              buffer_mb=500)
    # Assumes X is a scipy.sparse matrix (the original called .toarray() on it).
    # Densifying straight into Fortran order avoids the extra full-size copy
    # that np.asfortranarray() makes on a C-ordered dense array.
    return X.toarray(order='F'), y


X_train, y_train = _load_dense(ROOT_DIR + '/set1.train.txt')
X_test, y_test = _load_dense(ROOT_DIR + '/set1.test.txt')
# Hyper-parameter names follow the released scikit-learn API
# (pre-release ``min_split`` -> ``min_samples_split``,
#  ``learn_rate`` -> ``learning_rate``).
clf = GradientBoostingRegressor(n_estimators=200,
                                max_depth=3,
                                min_samples_split=15,
                                learning_rate=0.05)

t0 = time()


def monitor(i, model, local_vars):
    """Print the boosting iteration index and elapsed wall-clock seconds.

    scikit-learn invokes the monitor as ``monitor(i, self, locals())`` after
    each boosting stage. Returning a falsy value (here ``None``) lets training
    continue.
    """
    print("Iteration %d\t%ds" % (i, time() - t0))


clf.fit(X_train, y_train, monitor=monitor)
y_pred = clf.predict(X_test)
np.save(ROOT_DIR + '/y_pred.npy', y_pred)
################################################################################
# Get labels and qids from the test set.
#
# Each line of the SVMlight/LETOR test file starts with "<label> qid:<id>";
# only those two leading fields are needed here.


def _read_label_qid(path):
    """Return the ``(label, qid)`` integer pair of every non-blank line of *path*.

    Replaces the former IPython shell escape (``!cut -f -2 -d' ' ...``), which
    is not valid Python and read a relative path instead of ``ROOT_DIR``.

    Raises
    ------
    ValueError
        If a line does not start with ``<label> qid:<id>``.
    """
    pairs = []
    with open(path) as fh:
        for line in fh:
            fields = line.split(None, 2)
            if not fields:
                continue  # tolerate blank lines
            label, qid_field = fields[:2]
            if not qid_field.startswith('qid:'):
                raise ValueError("expected 'qid:<id>' as second field, got %r"
                                 % qid_field)
            pairs.append((int(label), int(qid_field[len('qid:'):])))
    return pairs


def _group_by_qid(label_qid, preds):
    """Group ``(label, score)`` pairs into one list per query.

    ``itertools.groupby`` only merges *consecutive* rows with the same key, so
    this assumes each query's documents are contiguous in the test file (as
    the original script did).
    """
    rows = zip(label_qid, preds)
    return [[(label, score) for (label, _), score in group]
            for _, group in groupby(rows, key=lambda row: row[0][1])]


def _rank_scores(doc_scores):
    """Return 1-based ranks for *doc_scores*; the highest score gets rank 1."""
    ranks = np.zeros(len(doc_scores), dtype=int)
    ranks[np.argsort(doc_scores)[::-1]] = np.arange(1, len(doc_scores) + 1)
    return ranks


label_qid = _read_label_qid(ROOT_DIR + '/set1.test.txt')
# zip() silently truncates to the shorter input, so fail loudly instead.
if len(label_qid) != len(y_pred):
    raise ValueError("test file has %d rows but y_pred has %d"
                     % (len(label_qid), len(y_pred)))

# Group labels and scores by qid.
data = _group_by_qid(label_qid, y_pred)
labels = [[label for label, _ in d] for d in data]
scores = [[score for _, score in d] for d in data]
ranks = [_rank_scores(s) for s in scores]

# Use ``evaluate.evaluate_submission(labels, ranks, k=10)`` to get NDCG and MRR
# scores; ``evaluate.py`` is available on the Yahoo LTRC 2010 website.
# NOTE(review): confirm which metric each returned value corresponds to.
# scores (0.45268919338794184, 0.77109134293735371)
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.