Created
December 20, 2017 06:59
-
-
Save yohm/553de46c407ed8ec1f004c0fa38e0676 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os.path | |
import caravan_dump | |
import numpy as np | |
from sklearn.linear_model import Ridge | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.preprocessing import PolynomialFeatures | |
from sklearn.model_selection import train_test_split | |
from sklearn.pipeline import make_pipeline | |
from sklearn.externals import joblib | |
# The script expects exactly one command-line argument: the dump file path.
if len(sys.argv[1:]) != 1:
    sys.stderr.write("Invalid argument\n")
    usage_line = " Usage: python %s dump.bin\n" % os.path.basename(__file__)
    sys.stderr.write(usage_line)
    raise Exception("Invalid command line argument")

# Load the dump given on the command line; the rest of the script reads it.
dump = caravan_dump.CaravanDump(sys.argv[1])
def parse_results(dump):
    """Collect the input points and results of the runs stored in ``dump``.

    Runs whose first result value equals 0.0 are skipped, and "skipped" is
    written to stderr for each of them.
    NOTE(review): presumably 0.0 marks a run without a valid result --
    confirm against the simulator that produced the dump.

    Args:
        dump: Object exposing ``runs`` (an iterable of mappings with
            "result" and "parentPSId" entries) and ``parameter_sets``
            (indexable by a run's "parentPSId"; each entry has a "point"
            entry holding the coordinates).

    Returns:
        A tuple ``(X, ys)`` of NumPy arrays. Row ``i`` of ``X`` is the
        parameter point (converted to floats) of the i-th kept run and row
        ``i`` of ``ys`` is that run's result.
    """
    points = []
    results = []
    for run in dump.runs:
        result = run["result"]
        if result[0] == 0.0:
            sys.stderr.write("skipped\n")
            continue
        results.append(result)
        parameter_set = dump.parameter_sets[run["parentPSId"]]
        # Use a dedicated loop variable: reusing the accumulator's name here
        # rebinds it under Python 2, where comprehension variables leak.
        points.append([float(coord) for coord in parameter_set["point"]])
    return (np.array(points), np.array(results))
# Imported once here rather than on every pass through the loop below.
from sklearn.model_selection import GridSearchCV

X, ys = parse_results(dump)
sys.stderr.write("parsed\n")

# Hyper-parameter grid searched for every output column. The keys follow the
# step names that make_pipeline derives from the lower-cased class names.
# The grid is identical for every column, so it is built once.
params = dict(
    polynomialfeatures__degree=[3, 4, 5],
    ridge__alpha=np.logspace(-5, 2, num=8),
)

# Fit one independent regression model per column of the result matrix.
for k in range(ys.shape[1]):
    y = ys[:, k]
    # Hold out 40% of the samples for the final score; the seed is fixed so
    # the split is reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.4, random_state=1234
    )
    # The degree and alpha given here are only starting values: the grid
    # search overrides both with the candidates in ``params``.
    pipe = make_pipeline(
        StandardScaler(),
        PolynomialFeatures(degree=5),
        Ridge(alpha=0.005, fit_intercept=True),
    )
    grid_search = GridSearchCV(pipe, param_grid=params, cv=5, n_jobs=4)
    grid_search.fit(X_train, y_train)
    print(grid_search.best_estimator_)
    # Score of the selected model on the held-out samples.
    print(grid_search.score(X_test, y_test))
    # Persist the fitted search object, one file per output column.
    joblib.dump(grid_search, "reg_model%d.pkl" % k)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment