Skip to content

Instantly share code, notes, and snippets.

@darden1
Created April 11, 2018 13:35
Show Gist options
  • Select an option

  • Save darden1/250eec77acba0f55bda89fd532b60fd2 to your computer and use it in GitHub Desktop.

Select an option

Save darden1/250eec77acba0f55bda89fd532b60fd2 to your computer and use it in GitHub Desktop.
train_with_my_mlp.py
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
# NOTE(review): MultiLayerPerceptron (used below) is neither defined nor
# imported in this script; presumably it comes from the author's companion
# module — import it here before running.
# Load the Iris dataset: feature matrix X and integer class labels y.
iris = datasets.load_iris()
X, y = iris.data, iris.target

# One-hot encode the class labels to use as the network's targets.
# OneHotEncoder expects 2-D input, so y is reshaped into a single column;
# the encoder's output is converted to a dense array with .toarray().
ohe = OneHotEncoder()
Y = ohe.fit_transform(y.reshape(-1, 1)).toarray()
# Shuffle features and one-hot labels together with a fixed seed.
# (train_test_split also shuffles by default; this explicit shuffle is kept
# so the resulting split stays exactly the same.)
X, Y = shuffle(X, Y, random_state=0)

# Hold out 30% of the samples as the test split.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.3, random_state=0
)

# Standardize features using statistics fitted on the training split only,
# then apply that same transform to the test split.
ss = StandardScaler().fit(X_train)
X_train, X_test = ss.transform(X_train), ss.transform(X_test)
# Training hyperparameters.
# Mini-batch size is 20% of the training samples. int() rounds down, so a
# training set with fewer than 5 samples would otherwise get a batch size
# of 0; clamp it to at least 1.
batch_size = max(1, int(len(X_train) * 0.2))
epochs = 100  # number of training epochs
mu = 0.05  # learning rate

# Build the classifier.
# NOTE(review): MultiLayerPerceptron is neither defined nor imported in this
# script; presumably it comes from the author's companion module — confirm.
# hidden_layer_sizes=(10, 10, 10) presumably means three hidden layers of
# 10 units each.
clf = MultiLayerPerceptron(hidden_layer_sizes=(10, 10, 10), activation="relu", random_state=10)

# Train, passing the held-out split as validation data. The script later
# reads clf.loss, clf.val_loss, clf.acc and clf.val_acc, so fit() is expected
# to populate those per-epoch histories.
clf.fit(X_train, Y_train, batch_size, epochs, mu,
        validation_data=(X_test, Y_test), verbose=0)
def _plot_learning_curve(position, metric, train_values, test_values):
    """Draw one subplot comparing a training metric with its test counterpart.

    Args:
        position: Three-digit subplot spec passed to ``plt.subplot`` (e.g. 211).
        metric: Metric name used in the subplot title and the y-axis label.
        train_values: Sequence of values recorded on the training data.
        test_values: Sequence of values recorded on the validation data.
    """
    plt.subplot(position)
    plt.title("learning log (" + metric + ")")
    plt.xlabel("epoch")
    plt.ylabel(metric)
    # Each series is plotted against its own index range, so a length
    # mismatch between the train and test histories cannot make plt.plot raise.
    plt.plot(range(len(train_values)), train_values, label="train")
    plt.plot(range(len(test_values)), test_values, label="test")
    plt.legend(loc="best")
    plt.grid(True)


# Plot the learning curves (loss on top, accuracy below) and show the figure.
plt.figure(figsize=(10, 7))
_plot_learning_curve(211, "loss", clf.loss, clf.val_loss)
_plot_learning_curve(212, "accuracy", clf.acc, clf.val_acc)
plt.tight_layout()
plt.show()
# Report the last recorded accuracy on the training and test splits.
print("acc_train: %s acc_test: %s" % (clf.acc[-1], clf.val_acc[-1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment