Created
March 14, 2025 15:37
-
-
Save nickfox-taterli/9e4b297b260535f6311c2935ff5c80c4 to your computer and use it in GitHub Desktop.
GridSearchCV + Keras Sequential
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
from tensorflow import keras | |
from tensorflow.keras.models import Sequential | |
from tensorflow.keras.layers import Dense | |
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard | |
from scikeras.wrappers import KerasClassifier | |
from sklearn.model_selection import GridSearchCV | |
from datetime import datetime | |
def create_model(optimizer='adam', activation='relu', n_features=20):
    """Build and compile a small fully connected binary classifier.

    Args:
        optimizer: Optimizer name or instance forwarded to ``model.compile``.
        activation: Activation used by the two hidden ``Dense`` layers.
        n_features: Width of each input sample. Defaults to 20, matching the
            previously hard-coded input shape, so existing callers are
            unaffected.

    Returns:
        A compiled ``Sequential`` model with hidden layers of 16 and 8 units
        and a single sigmoid output unit, using binary cross-entropy loss and
        reporting accuracy.
    """
    model = Sequential([
        # Declare the input with an explicit Input object; passing
        # ``input_shape`` to a layer is deprecated in Keras 3.
        keras.Input(shape=(n_features,)),
        Dense(16, activation=activation),
        Dense(8, activation=activation),
        # One sigmoid unit: a single probability for binary classification.
        Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model
# Synthetic demo data: 500 samples x 20 features with 0/1 labels. The labels
# are drawn independently of X, so a cross-validated accuracy near 0.5 is
# expected; this script demonstrates the search mechanics, not a learnable task.
X = np.random.rand(500, 20)
y = np.random.randint(0, 2, 500)

# EarlyStopping watches ``val_loss``, which Keras only computes when a fit has
# validation data. The wrapper below sets ``validation_split`` so the callback
# (and ``restore_best_weights``) actually takes effect.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Timestamped TensorBoard directory so separate script runs do not overwrite
# each other.
# NOTE(review): all CV fits and the final refit within one run still write
# into this same directory.
log_dir = os.path.join("logs", "gridsearch_" + datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

# Wrap the model-building function for scikit-learn.
# - SciKeras takes the builder via ``model=``; ``build_fn`` is deprecated.
# - Only ``model__``-prefixed parameters are forwarded to create_model().
# - The callbacks are fixed rather than searched, so they live on the
#   estimator instead of in the parameter grid.
model = KerasClassifier(
    model=create_model,
    validation_split=0.2,
    callbacks=[early_stopping, tensorboard_callback],
    verbose=0,
)

# Search space. ``model__*`` keys are routed to create_model(); ``batch_size``
# and ``epochs`` are wrapper-level training parameters. A bare ``activation``
# key would make set_params raise, and a bare ``optimizer`` key would never
# reach create_model(), which compiles the model itself.
param_grid = {
    'model__optimizer': ['adam', 'sgd'],
    'model__activation': ['relu', 'tanh'],
    'batch_size': [16, 32],
    'epochs': [10, 20],
}

# 16 candidate combinations, each scored with 3-fold cross-validation.
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=3)
grid_result = grid.fit(X, y)

# Report the best hyperparameters, their mean CV accuracy, and the log directory.
print("最佳超参数:", grid_result.best_params_)
print("最佳准确率:", grid_result.best_score_)
print(f"TensorBoard 日志目录: {log_dir}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment