Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ahmedshahriar/c09b4e5ea8d4609026f7f8c759747c07 to your computer and use it in GitHub Desktop.
Save ahmedshahriar/c09b4e5ea8d4609026f7f8c759747c07 to your computer and use it in GitHub Desktop.
A simple function that takes a model and a parameter space as input and tunes the model's hyperparameters with nested cross-validation, using either grid search or randomized search as selected.
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import (
    GridSearchCV,
    KFold,
    RandomizedSearchCV,
    StratifiedKFold,
)
# Outer cross-validation splitter used by tune_hyperparameter: 10 shuffled,
# stratified folds with a fixed seed so every run sees the same splits.
kf = StratifiedKFold(
    n_splits=10,
    shuffle=True,
    random_state=42,
)
def _make_search(search_method, estimator, search_space, cv, n_iter):
    """Build an unfitted GridSearchCV/RandomizedSearchCV scored by accuracy."""
    if search_method == "grid":
        return GridSearchCV(
            estimator=estimator,
            param_grid=search_space,
            scoring='accuracy',
            n_jobs=-1,
            cv=cv,
            verbose=0,
            refit=True,
        )
    # search_method == "random" (validated by the caller)
    return RandomizedSearchCV(
        estimator=estimator,
        param_distributions=search_space,
        n_iter=n_iter,
        # Match the grid branch and the outer-fold metric (accuracy_score).
        scoring='accuracy',
        n_jobs=-1,
        cv=cv,
        verbose=0,
        random_state=1,
        refit=True,
    )


def tune_hyperparameter(search_method, estimator, search_space, X=None, y=None,
                        outer_cv=None, inner_cv=None, n_iter=10):
    """Tune ``estimator`` with nested cross-validation and print its accuracy.

    For every split of ``outer_cv``, a hyperparameter search is run on the
    training part using ``inner_cv``. The search is ``GridSearchCV`` or
    ``RandomizedSearchCV``, both scored by accuracy and both with
    ``refit=True``. The refitted best model is then scored with
    ``accuracy_score`` on the held-out part. Per-fold results and the
    mean/std of the outer-fold accuracies are printed.

    Parameters
    ----------
    search_method : {"grid", "random"}
        Which search strategy to use.
    estimator : scikit-learn estimator
        Model to tune. Accuracy scoring and stratified splitting imply a
        classifier is expected.
    search_space : dict or list of dict
        Passed as ``param_grid`` (grid) or ``param_distributions`` (random).
    X, y : array-like, optional
        Features and targets. They must support NumPy-style indexing
        (``X[idx, :]`` and ``y[idx]``), e.g. NumPy arrays. When omitted, the
        module-level ``X`` and ``y`` are used, as in the original version.
    outer_cv : cross-validator, optional
        Outer splitter. Defaults to the module-level ``kf``.
    inner_cv : cross-validator, optional
        Inner splitter used by the search. Defaults to
        ``KFold(n_splits=5, shuffle=True, random_state=1)``.
    n_iter : int, default 10
        Number of sampled candidates for randomized search.

    Returns
    -------
    estimator
        The best estimator from the LAST outer fold, refitted on that fold's
        training part only (not on the full dataset).

    Raises
    ------
    ValueError
        If ``search_method`` is neither ``"grid"`` nor ``"random"``.
    NameError
        If ``X``/``y`` are not passed and no module-level ``X``/``y`` exist.
    """
    # Validate first. Previously an unknown method left ``clf`` unbound and
    # surfaced as a confusing UnboundLocalError mid-loop.
    if search_method not in ("grid", "random"):
        raise ValueError(
            f"search_method must be 'grid' or 'random', got {search_method!r}"
        )

    # Backward compatibility: fall back to the module-level data the original
    # function read implicitly.
    namespace = globals()
    try:
        if X is None:
            X = namespace["X"]
        if y is None:
            y = namespace["y"]
    except KeyError as err:
        raise NameError(
            f"name {err.args[0]!r} is not defined; pass X and y explicitly"
        ) from err

    if outer_cv is None:
        outer_cv = kf
    if inner_cv is None:
        # Stateless with a fixed int seed, so one instance serves every fold.
        inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)

    outer_results = []
    best_model = None
    for train_ix, test_ix in outer_cv.split(X, y):
        X_train, X_test = X[train_ix, :], X[test_ix, :]
        y_train, y_test = y[train_ix], y[test_ix]

        search = _make_search(search_method, estimator, search_space,
                              inner_cv, n_iter)
        result = search.fit(X_train, y_train)
        # refit=True: best_estimator_ was refit on this fold's whole training part.
        best_model = result.best_estimator_

        # Score on the held-out outer fold.
        acc = accuracy_score(y_test, best_model.predict(X_test))
        outer_results.append(acc)
        # est = mean inner-CV score of the best candidate.
        print('acc=%.3f, est=%.3f, cfg=%s' % (acc, result.best_score_, result.best_params_))

    print('Accuracy: %.3f (%.3f)' % (np.mean(outer_results), np.std(outer_results)))
    print("Best", search_method, "Model : ", best_model)
    print("-" * 50, '\n\n')
    return best_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment