Skip to content

Instantly share code, notes, and snippets.

@nickfox-taterli
Created March 14, 2025 15:37
Show Gist options
  • Save nickfox-taterli/9e4b297b260535f6311c2935ff5c80c4 to your computer and use it in GitHub Desktop.
Save nickfox-taterli/9e4b297b260535f6311c2935ff5c80c4 to your computer and use it in GitHub Desktop.
GridSearchCV + Keras Sequential
import numpy as np
import os
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
from datetime import datetime
# Build the Keras model that the grid search below tunes.
def create_model(optimizer='adam', activation='relu', n_features=20):
    """Build and compile a small fully connected binary classifier.

    Args:
        optimizer: Optimizer handed to ``model.compile`` (e.g. ``'adam'`` or
            ``'sgd'``, the values searched by this script).
        activation: Activation function of the two hidden ``Dense`` layers.
        n_features: Number of columns in each input row. Defaults to 20,
            matching the synthetic data generated by this script; existing
            callers that do not pass it are unaffected.

    Returns:
        A compiled ``Sequential`` model with hidden layers of 16 and 8 units
        and a single sigmoid output unit, compiled with binary cross-entropy
        loss and an accuracy metric.
    """
    model = Sequential([
        # Declare the input with an explicit Input object instead of passing
        # ``input_shape=`` to the first Dense layer (deprecated in Keras 3).
        keras.Input(shape=(n_features,)),
        Dense(16, activation=activation),
        Dense(8, activation=activation),
        Dense(1, activation='sigmoid'),  # single probability -> binary classification
    ])
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model
# Synthetic dataset: 500 rows of 20 features drawn uniformly from [0, 1),
# plus one random 0/1 label per row.
# NOTE: the labels are drawn independently of X, so cross-validated accuracy is
# expected to sit near chance (~0.5); this data only exercises the pipeline.
X, y = np.random.rand(500, 20), np.random.randint(0, 2, size=500)
# Stop a fit after 5 epochs without improvement in 'val_loss' and roll the
# model back to its best weights, to limit overfitting.
# NOTE: 'val_loss' is only logged when fit() receives validation data
# (e.g. validation_split); without it this callback has nothing to monitor.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Timestamped TensorBoard directory so a new run of this script does not
# overwrite the logs of a previous one. Weight histograms are written every epoch.
# NOTE(review): within one run, every grid candidate and CV fold shares this
# callback and therefore logs into the same directory.
log_dir = f"logs/gridsearch_{datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
# Wrap the model-building function so scikit-learn can treat it as an estimator.
# scikeras forwards only ``model__``-prefixed parameters to ``create_model``.
# An un-prefixed ``activation`` is not a wrapper parameter, so GridSearchCV's
# set_params would reject it. An un-prefixed ``optimizer`` only sets the
# wrapper's own compile option, which is presumably unused because
# create_model returns an already-compiled model.
model = KerasClassifier(
    model=create_model,          # ``build_fn`` is the deprecated spelling
    model__optimizer='adam',     # defaults; overridden per candidate by the grid
    model__activation='relu',
    # Hold out 20% of each training fold as validation data so that
    # EarlyStopping actually has a 'val_loss' to monitor.
    validation_split=0.2,
    # Identical for every candidate, so set once here (fit only) rather than as
    # a single-valued grid entry; this keeps best_params_ free of callback objects.
    fit__callbacks=[early_stopping, tensorboard_callback],
    verbose=0,
)

# Hyperparameter search space: 2 x 2 x 2 x 2 = 16 candidates.
param_grid = {
    'model__optimizer': ['adam', 'sgd'],
    'model__activation': ['relu', 'tanh'],
    'batch_size': [16, 32],
    'epochs': [10, 20],
}

# Exhaustive grid search; each candidate is scored with 3-fold cross-validation.
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=3)
grid_result = grid.fit(X, y)
# Report the winning candidate: its hyperparameters, its mean cross-validated
# score (accuracy), and the directory where TensorBoard logs were written.
print(f"最佳超参数: {grid_result.best_params_!s}")
print(f"最佳准确率: {grid_result.best_score_!s}")
print(f"TensorBoard 日志目录: {log_dir}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment