Skip to content

Instantly share code, notes, and snippets.

@darden1
Created January 26, 2018 13:49
Show Gist options
  • Save darden1/ac585ec2a36032d474f71c0a0b449022 to your computer and use it in GitHub Desktop.
Save darden1/ac585ec2a36032d474f71c0a0b449022 to your computer and use it in GitHub Desktop.
my_multi_class_logistic_regression2.py
# -*- coding: utf-8 -*-
"""Train a hand-rolled multi-class logistic regression on the Iris dataset
and plot the learning curves (loss and accuracy) for train/test splits.

NOTE(review): ``MultiClassLogisticRegression`` is not defined in this file —
it is presumably provided by a companion module of the same gist. Its
``fit(X, Y, batch_size, epochs, mu, validation_data=..., verbose=...)``
signature and its ``loss`` / ``val_loss`` / ``acc`` / ``val_acc`` history
attributes are assumed from the usage below — confirm against that module.
"""
import numpy as np  # fix: np.newaxis / np.arange are used below but numpy was never imported
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

# Iris dataset: 150 samples, 4 features, 3 classes
iris = datasets.load_iris()
X = iris.data
y = iris.target

# One-hot encode the integer class labels (shape (n,) -> (n, 3))
ohe = OneHotEncoder()
Y = ohe.fit_transform(y[:, np.newaxis]).toarray()

# Shuffle samples (fixed seed for reproducibility)
X, Y = shuffle(X, Y, random_state=0)

# Split into training and test sets (70% / 30%)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)

# Standardize features; the scaler is fit on the training set only
# to avoid leaking test-set statistics.
ss = StandardScaler()
X_train = ss.fit_transform(X_train)
X_test = ss.transform(X_test)

# Instantiate the self-made logistic regression classifier
clf = MultiClassLogisticRegression()
batch_size = int(len(X_train) * 0.2)  # mini-batch size: 20% of the training set
epochs = 500                          # number of training epochs
mu = 0.6                              # learning rate

# Run training; validation_data lets fit() record per-epoch test metrics too
clf.fit(X_train, Y_train, batch_size, epochs, mu, validation_data=(X_test, Y_test), verbose=1)

# Plot the learning curves: loss on top, accuracy below
plt.figure(figsize=(10, 7))
plt.subplot(211)
plt.title("learning log (loss)")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.plot(np.arange(len(clf.loss)), clf.loss, label="train")
plt.plot(np.arange(len(clf.loss)), clf.val_loss, label="test")
plt.legend(loc="best")
plt.grid(True)
plt.subplot(212)
plt.title("learning log (accuracy)")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.plot(np.arange(len(clf.loss)), clf.acc, label="train")
plt.plot(np.arange(len(clf.loss)), clf.val_acc, label="test")
plt.legend(loc="best")
plt.grid(True)
plt.tight_layout()
plt.show()

# Final accuracy after the last epoch
print("acc_train: " + str(clf.acc[-1]) + " acc_test: " + str(clf.val_acc[-1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment