Created
April 9, 2017 04:54
-
-
Save dolphinsue319/7deea9c11023a216da9b39af4796448b to your computer and use it in GitHub Desktop.
用 scikit-learn 實作 logistic regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
# In[ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Enable inline plots in the notebook. `InteractiveShell.magic()` is
# deprecated, so call the supported `run_line_magic` API instead.
get_ipython().run_line_magic('matplotlib', 'inline')
# ### Build two data frames, X and Y. X is the cause and Y is the effect.
# ### In plain terms: Y is 1 whenever the value of X is less than 51.
# In[24]:
X = pd.DataFrame(np.linspace(0, 100, num=1000))
# Label every row in one vectorized step instead of looping over
# single-element row arrays and relying on their truthiness.
Y = pd.DataFrame(np.where(X[0].values < 51.0, 1, 0))
# Show the labels against the single feature as red dots.
plt.plot(X, Y, 'ro')
plt.show()
# # Split the data into a train set and a test set: 70% train, 30% test
# In[25]:
from sklearn.model_selection import train_test_split
# A fixed random_state makes the shuffle (and therefore the plots below)
# reproducible from run to run.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.3, random_state=0
)
# ### Train and predict
# In[26]:
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
# Y_train is a single-column frame; flatten it to a 1-D array for fit().
lr.fit(X_train, Y_train.values.ravel())
# Predict on the training rows themselves and plot the predicted labels.
Y_pred = lr.predict(X_train)
plt.plot(X_train, Y_pred, 'ro')
plt.show()
# ### Predict with the test-set data
# In[27]:
# Y_pred is rebound here: it now holds predictions for the held-out rows.
Y_pred = lr.predict(X_test)
plt.plot(X_test, Y_pred, 'ro')
plt.show()
# ## <span style="color:blue">Add one more feature; this time its values lie between 0 and 1</span>
# In[28]:
# Assign the array directly (no DataFrame wrapper, so no dependence on index
# alignment) and size it from X so the two columns cannot drift apart.
X[1] = np.linspace(0, 1, num=len(X))
# ### The two features have very different value ranges, so normalize them
# In[29]:
from sklearn.preprocessing import StandardScaler
# Split BEFORE fitting the scaler: fitting on the full data set would leak
# the test rows' mean/variance into the preprocessing step.
X_raw_train, X_raw_test, Y_nor_train, Y_nor_test = train_test_split(
    X, Y, test_size=0.3, random_state=0
)
sc = StandardScaler()
sc.fit(X_raw_train)
X_nor_train = sc.transform(X_raw_train)
X_nor_test = sc.transform(X_raw_test)
# Kept so the name X_nor stays defined; it is the whole data set scaled with
# the train-only statistics.
X_nor = sc.transform(X)
# In[32]:
# Refit the same LogisticRegression instance on the scaled two-feature data.
lr.fit(X_nor_train, Y_nor_train.values.ravel())
Y_pred = lr.predict(X_nor_train)
# X_nor_train has two columns, so each feature is drawn as its own series
# against the predicted labels.
plt.plot(X_nor_train, Y_pred, 'ro')
plt.show()
# ### Use the test-set data to see how well the model predicts.
# In[33]:
Y_test_pred = lr.predict(X_nor_test)
plt.plot(X_nor_test, Y_test_pred, 'ro')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment