Last active
March 11, 2017 20:45
-
-
Save razhangwei/85df7e3176a9038fe460 to your computer and use it in GitHub Desktop.
sklearn: tune parameters using cross-validation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Tune HMM hyper-parameters (the number of hidden states) by cross validation.

Run as a script with a single argument: if it contains "CV", a model is fitted
on the training part of every cross-validation fold; if it contains "WHOLE",
a model is fitted on the whole dataset. Either way one model is fitted per
combination in ``param_grid``. Fitted models are cached under
``output_folder`` and reused on later runs.
"""
import multiprocessing
import os
import sys

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from sklearn.model_selection import KFold, ParameterGrid
from tqdm import tqdm

from model.hmm import ConstrainedMixHMM
from utils.utils import load_data, send_email
# Global plotting / printing preferences for this script.
sns.set_style('whitegrid')
np.set_printoptions(precision=3)
# Use the fully-qualified option key: the bare pattern 'precision' also
# matches 'styler.format.precision' in newer pandas and raises OptionError.
pd.set_option('display.precision', 3)
# Experiment name and on-disk locations (relative to the working directory).
task_name = "HMMCrossValidation"
input_folder = '../data/output'
output_folder = '../data/output/%s' % task_name
figure_folder = '../figure/%s' % task_name

# Keyword arguments forwarded unchanged to utils.utils.load_data.
data_config = dict(
    path='%s/RNZS_CC_F_include_missing.csv' % input_folder,
    aMCI_only=True,
    split_aMCI_naMCI=False,
    split_early_clinic_MCI=True,
    include_missing_x=True,
    include_missing_y=True,
    impute_missing=False,
    reviewed_label_only=False,
)

# Settings read by fit_model.
run_config = dict(
    run_times=10,   # random restarts per fit; the best-scoring model is kept
    verbose=False,  # print a message before each fit
    plot=False,     # draw a KDE of the per-restart scores
)

# Base keyword arguments for ConstrainedMixHMM; 'n_components' and
# 'transmat_type' are supplied per grid point by the fitting routines.
model_config = dict(
    monotonic_state=True,
    covariance_type='diag',
)

# Hyper-parameter grid explored with sklearn's ParameterGrid.
param_grid = {
    'transmat_type': ['upper-bidiagonal'],
    # list(...) is required on Python 3, where range() is a lazy sequence
    # that cannot be concatenated with a list.
    'n_components': list(range(3, 15)) + [17, 20, 24],
}
def _fit(X, lengths, random_state, config):
    """Build a ConstrainedMixHMM from one random seed and fit it.

    ``config`` is expanded into the constructor's keyword arguments alongside
    ``random_state``; the value returned by ``fit(X, lengths)`` is returned.
    """
    hmm = ConstrainedMixHMM(random_state=random_state, **config)
    return hmm.fit(X, lengths)
def fit_model(parallel, X, lengths, config):
    """Fit the model from several random starts and return the best one.

    Parameters
    ----------
    parallel : callable
        An open ``joblib.Parallel`` instance (as created by the callers in
        this file) used to run the restarts concurrently.
    X : array-like
        Training observations, passed unchanged to ``ConstrainedMixHMM.fit``.
    lengths : array-like of int
        Length of each sequence in ``X``, passed unchanged to ``fit``.
    config : dict
        Keyword arguments for ``ConstrainedMixHMM``; must not contain
        ``random_state``, which is set to the restart index.

    Returns
    -------
    ConstrainedMixHMM
        The restart whose last ``monitor_.history`` value is largest.

    Raises
    ------
    ValueError
        If ``run_config['run_times']`` is smaller than 1.
    """
    n_runs = run_config['run_times']
    if n_runs < 1:
        raise ValueError("run_config['run_times'] must be >= 1, got %r" % n_runs)
    if run_config['verbose']:
        print("Fitting the model...")
    # Restart i is seeded with random_state=i, so the set of starts is
    # reproducible from run to run.
    models = parallel(
        delayed(_fit)(X, lengths, i, config) for i in range(n_runs)
    )
    # Last value recorded by each model's convergence monitor; presumably the
    # final training log-likelihood (the plot below labels it as such).
    scores = [m.monitor_.history[-1] for m in models]
    if run_config['plot']:
        plt.figure()
        sns.kdeplot(np.array(scores))
        plt.plot(np.max(scores), 0, 'ro')  # mark the winning restart
        plt.title('kernel density estimation of log-likelihood')
    return models[np.argmax(scores)]
def run_cross_validation(n_splits, notification=False):
    """Fit and cache one model per (grid point, training fold).

    Used to tune the number of hidden states. For every combination in
    ``param_grid`` and every cross-validation fold, a model is fitted on the
    fold's training subjects and saved under ``output_folder``. Models already
    on disk are not refitted, so an interrupted run can be resumed. Scoring on
    the held-out folds is not done here.

    Parameters
    ----------
    n_splits : int
        Number of folds for cross validation.
    notification : bool
        Whether to send an e-mail (via ``send_email``) before and after each
        fold is fitted.

    Raises
    ------
    ValueError
        If the splits cached in ``cv_splits.pkl`` contain a different number
        of folds than ``n_splits``.
    """
    X = load_data(**data_config)
    # Folds are assigned over first-level index values so all rows sharing
    # an ID land in the same fold.
    # NOTE(review): assumes X.index.levels[0] holds only IDs present in X
    # (no unused MultiIndex levels) - confirm load_data guarantees this.
    subject_ids = X.index.levels[0]

    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    # The splits are cached so every (resumed) run sees the same folds; the
    # per-fold model files below are keyed by fold index.
    splits_path = "%s/cv_splits.pkl" % output_folder
    try:
        k_splits = joblib.load(splits_path)
    except IOError:
        k_splits = list(KFold(n_splits, shuffle=True).split(subject_ids))
        joblib.dump(k_splits, splits_path)
    if len(k_splits) != n_splits:
        raise ValueError(
            "%s holds %d folds but n_splits=%d; delete it and the cached "
            "fold models to re-split." % (splits_path, len(k_splits), n_splits))
    print(k_splits)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("Fitting %r using cross-validation..." % params)
            # Copy so the module-level model_config is never mutated.
            config = dict(model_config,
                          n_components=params['n_components'],
                          transmat_type=params['transmat_type'])
            for i, (train_index, _) in tqdm(enumerate(k_splits),
                                            total=len(k_splits)):
                model_path = "%s/cv=%d_transmat=%s_n=%d.pkl" % (
                    output_folder, i, params['transmat_type'],
                    params['n_components'])
                # Skip folds whose model has already been fitted and saved.
                try:
                    ConstrainedMixHMM.load(model_path)
                except IOError:
                    pass
                else:
                    continue

                train_ids = subject_ids[train_index].tolist()
                X_train = X.loc[train_ids]
                lengths_train = X_train.groupby(level=0).size().values

                if notification:
                    send_email("Cross Validation",
                               "Start fitting %d-fold with %r." % (i, params))
                model = fit_model(parallel, X_train, lengths_train, config)
                model.save(model_path)
                if notification:
                    send_email("Cross Validation",
                               "Finished fitting %d-fold with %r." % (i, params))
def fit_model_on_whole_dataset():
    """Fit and cache one model per ``param_grid`` combination on all data.

    Each model is saved under ``output_folder``; combinations whose model file
    already exists are skipped, so the sweep can be resumed after interruption.
    """
    X = load_data(**data_config)
    # Number of rows per first-level index value, i.e. one length per sequence.
    # NOTE(review): groupby sorts its keys, so this assumes X's rows are
    # ordered by that first index level - confirm load_data returns them so.
    lengths = X.groupby(level=0).size().values

    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("fitting %r ..." % params)
            filename = "%s/model_type=%s_n=%d.pkl" % (
                output_folder, params['transmat_type'], params['n_components'])
            # Skip combinations whose model has already been fitted and saved.
            try:
                ConstrainedMixHMM.load(filename)
            except IOError:
                pass
            else:
                continue
            # Copy so the module-level model_config is never mutated.
            config = dict(model_config,
                          n_components=params['n_components'],
                          transmat_type=params['transmat_type'])
            model = fit_model(parallel, X, lengths, config)
            model.save(filename)
if __name__ == "__main__":
    # The single argument selects what to run: it may contain "CV"
    # (per-fold fits), "WHOLE" (full-dataset fits), or both, e.g. "CV+WHOLE".
    usage = "usage: %s {CV|WHOLE|CV+WHOLE}" % sys.argv[0]
    # Explicit checks rather than `assert`, which `python -O` strips out.
    if len(sys.argv) != 2:
        sys.exit(usage)
    mode = sys.argv[1]
    if "CV" not in mode and "WHOLE" not in mode:
        sys.exit(usage)
    if "CV" in mode:
        run_cross_validation(n_splits=5, notification=True)
    if "WHOLE" in mode:
        fit_model_on_whole_dataset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment