Last active
April 16, 2018 10:45
-
-
Save darden1/ef5b19004380b98beb9dc2684461cab5 to your computer and use it in GitHub Desktop.
train_with_sklearn_mlp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.neural_network import MLPClassifier | |
from sklearn.metrics import accuracy_score | |
# 教師データを通常の表記に戻す(sikit-learnの多クラスロジスティック回帰ではonehotにしなくてよい) | |
y_train = np.array([np.argmax(yi) for yi in Y_train]) | |
y_test = np.array([np.argmax(yi) for yi in Y_test]) | |
batch_size = int(len(X_train)*0.2) # ミニバッチサイズ | |
epochs = 100 # エポック数 | |
mu = 0.05 # 学習率 | |
clf_sk = MLPClassifier(hidden_layer_sizes=(10,10,10), | |
activation='relu', # 活性化関数 | |
solver='sgd', # 確率的勾配効果法 | |
alpha=1e-10, # 正則化係数1e-10⇒正則化入れない | |
batch_size=batch_size, # バッチサイズ | |
learning_rate='constant', # 学習率を学習途中で変更しない | |
learning_rate_init=mu, # 学習率 | |
momentum=0.0, # モメンタム入れない | |
max_iter=epochs, # 総エポック数 | |
shuffle=False, # エポック毎にデータをシャッフルしない | |
random_state=8, # 乱数シード | |
tol=-1, # tol=-1でmax_iterまでループを回す。best_loss>loss-tolが2回続くと自動的にループが終了するので。 | |
verbose=False) # Lossの履歴をprintしない。 | |
# 学習実施 | |
clf_sk.fit(X_train, y_train) | |
# 正答率 | |
acc_train = accuracy_score(y_train, clf_sk.predict(X_train)) | |
acc_test = accuracy_score(y_test, clf_sk.predict(X_test)) | |
print("acc_train: "+ str(acc_train) + " acc_test: "+ str(acc_test)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment