Skip to content

Instantly share code, notes, and snippets.

@razhangwei
Last active March 11, 2017 20:45
Show Gist options
  • Save razhangwei/85df7e3176a9038fe460 to your computer and use it in GitHub Desktop.
Save razhangwei/85df7e3176a9038fe460 to your computer and use it in GitHub Desktop.
sklearn: tune parameters using cross-validation
"""
This file uses cross validation to tune the parameters.

It fits ConstrainedMixHMM models for every combination in ``param_grid``.
Each fit is repeated ``run_config['run_times']`` times from different random
seeds, and the restart with the largest final ``monitor_.history`` value is
pickled under ``output_folder``. Models that are already pickled are loaded
instead of refitted.

Usage::

    python <this script> CV      # one model per (fold, grid point)
    python <this script> WHOLE   # one model per grid point on all data

An argument containing both "CV" and "WHOLE" runs both stages, CV first.
"""
import multiprocessing
import sys
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, ParameterGrid
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
from model.hmm import ConstrainedMixHMM
from utils.utils import load_data, send_email
# Global plotting / printing defaults for the whole script.
sns.set_style('whitegrid')
np.set_printoptions(precision=3)
# Use the fully qualified option name: the bare 'precision' shorthand also
# matches 'styler.format.precision' in pandas >= 1.4 and then raises
# OptionError ("Pattern matched multiple keys"). 'display.precision' is
# valid in old and new pandas alike.
pd.set_option('display.precision', 3)
task_name = "HMMCrossValidation"
input_folder = '../data/output'
output_folder = '../data/output/%s' % task_name
figure_folder = '../figure/%s' % task_name

# Keyword arguments forwarded unchanged to load_data().
data_config = dict(
    path='%s/RNZS_CC_F_include_missing.csv' % (input_folder),
    aMCI_only=True,
    split_aMCI_naMCI=False,
    split_early_clinic_MCI=True,
    include_missing_x=True,
    include_missing_y=True,
    impute_missing=False,
    reviewed_label_only=False,
)

# Controls for fit_model(): number of random restarts per configuration,
# whether to print progress, and whether to plot the score density.
run_config = dict(
    run_times=10,
    verbose=False,
    plot=False,
)

# Fixed ConstrainedMixHMM keyword arguments. The per-grid-point
# n_components / transmat_type values are merged in by the fitting code.
model_config = dict(
    monotonic_state=True,
    covariance_type='diag',
)

# Hyper-parameter grid walked by sklearn's ParameterGrid.
# list(range(...)) keeps this valid on Python 3, where range() is a lazy
# sequence that cannot be concatenated with a list.
param_grid = {
    'transmat_type': ['upper-bidiagonal'],
    'n_components': list(range(3, 15)) + [17, 20, 24],
}
def _fit(X, lengths, random_state, config):
    """Train a single ConstrainedMixHMM from one random initialisation.

    ``random_state`` seeds this particular restart. ``config`` supplies the
    remaining constructor keyword arguments. The result of ``fit`` is
    returned as-is.
    """
    hmm = ConstrainedMixHMM(random_state=random_state, **config)
    return hmm.fit(X, lengths)
def fit_model(parallel, X, lengths, config):
    """Fit the model several times from different random starts; keep the best.

    Parameters
    ----------
    parallel : joblib.Parallel
        An already-open Parallel pool used to run the restarts.
    X : array-like
        Observations, passed unchanged to ``ConstrainedMixHMM.fit``.
    lengths : array-like of int
        Sequence lengths, passed unchanged to ``ConstrainedMixHMM.fit``.
    config : dict
        Keyword arguments for the ``ConstrainedMixHMM`` constructor.

    Returns
    -------
    ConstrainedMixHMM
        The restart whose last ``monitor_.history`` entry is largest.
        Restart ``i`` uses ``random_state=i``, for ``i`` in
        ``range(run_config['run_times'])``.

    Raises
    ------
    ValueError
        If ``run_config['run_times']`` is smaller than 1.
    """
    if run_config['verbose']:
        print("Fitting the model...")
    models = parallel(
        delayed(_fit)(X, lengths, seed, config)
        for seed in range(run_config['run_times'])
    )
    if not models:
        raise ValueError("run_config['run_times'] must be at least 1")
    # The last monitor entry is each restart's final score; the plot title
    # below calls it the log-likelihood (TODO confirm in model.hmm).
    scores = np.array([m.monitor_.history[-1] for m in models])
    best = int(np.argmax(scores))
    if run_config['plot']:
        plt.figure()
        sns.kdeplot(scores)
        plt.plot(scores[best], 0, 'ro')  # mark the winning restart
        plt.title('kernel density estimation of log-likelihood')
    return models[best]
def _load_or_create_splits(path, n_splits, ids):
    """Return the K-fold splits cached at `path`, creating them if absent.

    KFold is shuffled without a fixed seed, so the splits are cached. This
    keeps fold ``i`` identical across runs, which the per-fold model files
    (keyed only by fold index) rely on.
    """
    import warnings

    try:
        k_splits = joblib.load(path)
    except IOError:
        k_splits = list(KFold(n_splits, shuffle=True).split(ids))
        joblib.dump(k_splits, path)
    else:
        if len(k_splits) != n_splits:
            warnings.warn(
                "%s holds %d folds but n_splits=%d; using the cached folds. "
                "Delete the file to regenerate them."
                % (path, len(k_splits), n_splits))
    return k_splits


def run_cross_validation(n_splits, notification=False):
    """Fit one model per (fold, grid point) to tune the number of hidden states.

    For every combination in ``param_grid`` and every fold, a model is
    trained with ``fit_model`` on that fold's training sequences only and
    pickled to ``output_folder``. Fits that are already pickled are skipped.
    Held-out scoring is not performed here.

    Parameters
    ----------
    n_splits : int
        Number of folds used when new splits are created. A cached split
        file with a different fold count is reused, with a warning.
    notification : bool
        Whether to report progress via ``send_email``.
    """
    import os

    X = load_data(**data_config)
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    # Folds are drawn over the first index level (one label per sequence).
    # NOTE(review): assumes X.index.levels[0] contains no unused labels left
    # over from filtering in load_data() -- confirm.
    subject_ids = X.index.levels[0]
    k_splits = _load_or_create_splits(
        "%s/cv_splits.pkl" % output_folder, n_splits, subject_ids)
    print(k_splits)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("Fitting %r using cross-validation..." % params)
            # Per-run copy so the module-level model_config is never mutated.
            config = dict(model_config,
                          n_components=params['n_components'],
                          transmat_type=params['transmat_type'])
            # tqdm wraps the list (not enumerate) so the bar knows the total.
            for i, (train_index, _) in enumerate(tqdm(k_splits)):
                model_path = "%s/cv=%d_transmat=%s_n=%d.pkl" % (
                    output_folder, i, params['transmat_type'],
                    params['n_components'])
                try:
                    # Loading doubles as the cache check: a loadable pickle
                    # means this fold / grid point is already done.
                    ConstrainedMixHMM.load(model_path)
                except IOError:
                    DBID_train = subject_ids[train_index].tolist()
                    X_train = X.loc[DBID_train]
                    lengths_train = X_train.groupby(level=0).size().values
                    if notification:
                        send_email("Cross Validation",
                                   "Start fitting %d-fold with %r." % (i, params))
                    model = fit_model(parallel, X_train, lengths_train, config)
                    model.save(model_path)
                    if notification:
                        send_email("Cross Validation",
                                   "Finished fitting %d-fold with %r." % (i, params))
def fit_model_on_whole_dataset():
    """Fit one model per grid point on the full dataset and pickle it.

    Each combination in ``param_grid`` is fitted with ``fit_model`` on all
    data returned by ``load_data(**data_config)``. ``fit_model`` keeps the
    best of ``run_config['run_times']`` random restarts. The result is saved
    to ``<output_folder>/model_type=<transmat_type>_n=<n_components>.pkl``.
    Combinations whose pickle already loads are skipped.
    """
    import os

    X = load_data(**data_config)
    # One length per first-level index label, in groupby's (sorted) key order.
    # NOTE(review): if ConstrainedMixHMM.fit expects lengths in the row order
    # of X (as hmmlearn does), this assumes X is sorted by that level -- confirm.
    lengths = X.groupby(level=0).size().values
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    n_jobs = min(int(multiprocessing.cpu_count() * 1.5), run_config['run_times'])
    with Parallel(n_jobs=n_jobs) as parallel:
        for params in tqdm(ParameterGrid(param_grid)):
            print("fitting %r ..." % params)
            filename = "%s/model_type=%s_n=%d.pkl" % (
                output_folder, params['transmat_type'], params['n_components'])
            try:
                # Loading doubles as the cache check: skip finished fits.
                ConstrainedMixHMM.load(filename)
            except IOError:
                # Per-run copy so the module-level model_config is not mutated.
                config = dict(model_config,
                              n_components=params['n_components'],
                              transmat_type=params['transmat_type'])
                model = fit_model(parallel, X, lengths, config)
                model.save(filename)
if __name__ == "__main__":
    # Exactly one mode argument is expected. It may contain "CV", "WHOLE" or
    # both (e.g. "CV+WHOLE"), in which case both stages run, CV first.
    # Validate explicitly rather than with assert, which -O strips.
    if len(sys.argv) != 2 or not any(mode in sys.argv[1] for mode in ("CV", "WHOLE")):
        sys.exit("usage: %s CV|WHOLE" % sys.argv[0])
    if "CV" in sys.argv[1]:
        run_cross_validation(n_splits=5, notification=True)
    if "WHOLE" in sys.argv[1]:
        fit_model_on_whole_dataset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment