Created
January 16, 2025 22:33
-
-
Save gabfssilva/767e93262e2e5aff6e13c9b73d01d2f5 to your computer and use it in GitHub Desktop.
Keras + scikit-learn + grid search
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras import Layer, Model | |
from sklearn.datasets import make_classification | |
import keras | |
from keras.src.layers import Dense, Input | |
from sklearn.model_selection import GridSearchCV | |
# Synthetic classification data set: 10,000 samples with 20 features,
# 16 of them informative. The fixed seed keeps runs reproducible.
X, y = make_classification(
    n_samples=10000,
    n_features=20,
    n_informative=16,
    random_state=0,
)

# Cast features to float32 and labels to int64 before fitting.
X = X.astype(np.float32)
y = y.astype(np.int64)
def dynamic_model(X, y, loss, layers, out_activation):
    """Build and compile a basic MLP whose input and output sizes follow the data.

    Args:
        X: Feature array; ``X.shape[1]`` sets the input width.
        y: Target array; the output layer gets ``y.shape[1]`` units when ``y``
            has more than one dimension, otherwise a single unit.
        loss: Loss passed to ``model.compile``.
        layers: Iterable of dicts, one per hidden layer, each with a
            ``'dimensions'`` entry (unit count) and an ``'activation'`` entry.
        out_activation: Activation used by the output layer.

    Returns:
        The Keras ``Model``, compiled with the ``"rmsprop"`` optimizer.

    Raises:
        ValueError: If a ``'dimensions'`` entry cannot be converted to ``int``.
    """
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer in layers:
        # Dense needs an integer unit count. Coerce here so that grids which
        # spell the size as a numeric string (e.g. "20") still build.
        units = int(layer['dimensions'])
        hidden = Dense(units, activation=layer['activation'])(hidden)

    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = [Dense(n_outputs, activation=out_activation)(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model
# scikit-learn compatible classifier; the network itself is built by
# dynamic_model, with the remaining arguments supplied via model_kwargs.
clf = keras.wrappers.SKLearnClassifier(model=dynamic_model)
def gridparams(params):
    """
    Expand a dictionary of candidate values into every parameter combination.

    Args:
        params (dict): Maps each parameter name to an iterable of candidate values,
            e.g., {'color': ['red', 'blue'], 'size': ['S', 'M']}

    Returns:
        list: One dictionary per combination. Keys that appear earlier in
            ``params`` vary fastest,
            e.g., [{'color': 'red', 'size': 'S'}, {'color': 'blue', 'size': 'S'},
                   {'color': 'red', 'size': 'M'}, {'color': 'blue', 'size': 'M'}]
            An empty ``params`` yields ``[{}]``; a parameter with no candidate
            values yields ``[]``. Values are not copied, so combinations share
            the original value objects.
    """
    combinations = [{}]
    for key, values in params.items():
        # Extend every partial combination built so far with each candidate
        # for this key; the candidate loop is the outer one, which keeps the
        # earlier keys cycling fastest.
        combinations = [
            {**combination, key: value}
            for value in values
            for combination in combinations
        ]
    return combinations
# Search space for GridSearchCV. Each candidate is one complete
# ``model_kwargs`` dict, produced by expanding the option lists below.
# Layer sizes are integers because Dense unit counts must be positive ints.
params = {
    "model_kwargs": gridparams({
        # NOTE(review): "tanh" outputs are not a probability distribution, yet
        # the only loss searched is categorical cross-entropy; confirm that
        # this candidate is intentional.
        "out_activation": ["softmax", "tanh"],
        "loss": ["categorical_crossentropy"],
        "layers": [
            [{
                "dimensions": 20,
                "activation": "relu"
            }, {
                "dimensions": 20,
                "activation": "relu"
            }, {
                "dimensions": 20,
                "activation": "softmax"
            }],
            # -----------------
            [{
                "dimensions": 20,
                "activation": "relu"
            }, {
                "dimensions": 20,
                "activation": "tanh"
            }, {
                "dimensions": 5,
                "activation": "softmax"
            }],
            # -----------------
            [{
                "dimensions": 20,
                "activation": "relu"
            }, {
                "dimensions": 20,
                "activation": "tanh"
            }, {
                "dimensions": 10,
                "activation": "tanh"
            }, {
                "dimensions": 5,
                "activation": "tanh"
            }, {
                "dimensions": 5,
                "activation": "tanh"
            }]],
    })
}
# Exhaustive search over every candidate in ``params``, scored by accuracy
# with 10-fold cross-validation; refit=True asks for a final fit of the
# best candidate.
gs = GridSearchCV(
    estimator=clf,
    param_grid=params,
    scoring='accuracy',
    cv=10,
    refit=True,
)

gs.fit(X, y)

# Report the best cross-validated score and the parameters that achieved it.
print(gs.best_score_, gs.best_params_)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment