Last active
June 16, 2019 14:11
-
-
Save BrambleXu/52b0aaf10987015a078d36c97729dace to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
# Load the training set. The first line of the CSV is skipped (presumably a
# header); columns 0-1 are the two input features and column 2 is the label
# (compared against 1 and 0 elsewhere in this script).
data = np.loadtxt("non_linear_data.csv", delimiter=',', skiprows=1)
train_x, train_y = data[:, :2], data[:, 2]

# Optional: scatter the raw training points, one marker per class.
# plt.plot(train_x[train_y == 1, 0], train_x[train_y == 1, 1], 'o')
# plt.plot(train_x[train_y == 0, 0], train_x[train_y == 0, 1], 'x')
# plt.show()

# Random starting parameters, one per design-matrix column
# (bias, x1, x2, x1**2 -- see to_matrix).
theta = np.random.randn(4)
# Standardization: per-column mean and (population) standard deviation are
# computed once from the training features and reused for every transform.
# NOTE(review): a constant feature column would give sigma == 0 and a
# division by zero below.
mu, sigma = train_x.mean(axis=0), train_x.std(axis=0)


def standardizer(x):
    """Return ``x`` shifted by the training mean and scaled by the training std."""
    centered = x - mu
    return centered / sigma


std_x = standardizer(train_x)
# Build the design matrix: bias column (x0), the features, and x1 squared (x3).
def to_matrix(x):
    """Return the design matrix for a quadratic decision boundary.

    The columns are, in order: a constant 1 (bias), every column of ``x``,
    and the square of the first column of ``x``.

    x: (n, k) feature array -- (n, 2) in this script
    return: (n, k + 2) array -- (n, 4) in this script
    """
    bias = np.ones(len(x))
    first_squared = x[:, 0] ** 2
    # column_stack turns the 1-D arrays into single columns and keeps the
    # 2-D feature block as-is.
    return np.column_stack((bias, x, first_squared))
mat_x = to_matrix(std_x) # design matrix of the standardized training data, shape (n_samples, 4)
# sigmoid function
def f(x, params=None):
    """Return the logistic-regression probability for each row of ``x``.

    Computes ``1 / (1 + exp(-(x @ params)))``.

    x: (n, 4) design matrix (as produced by ``to_matrix``)
    params: (4,) parameter vector. When omitted, the module-level ``theta``
        is used, so existing calls ``f(x)`` behave exactly as before.
    return: (n,) array of probabilities in [0, 1]
    """
    if params is None:
        # Resolved at call time so each call sees the latest ``theta``
        # rebound by the training loop.
        params = theta
    return 1 / (1 + np.exp(-np.dot(x, params)))
# classify samples as 0 or 1
def classify(x):
    """Predict a hard 0/1 label for each row of the design matrix ``x``.

    A row is labelled 1 when its sigmoid probability ``f(x)`` is at least
    0.5, otherwise 0.

    return: (n,) integer array of 0s and 1s
    """
    # ``np.int`` was merely an alias for the builtin ``int``; it was
    # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError), so the
    # builtin is used directly with identical semantics.
    return (f(x) >= 0.5).astype(int)
# number of full-batch gradient-descent updates
epoch = 2000
# learning rate
ETA = 1e-3
# training-set accuracy recorded after every update
accuracies = []

# update parameters by batch gradient descent
for _ in range(epoch):
    # Gradient of the summed cross-entropy (negative log-likelihood) loss:
    #   f(mat_x) - train_y : (n,)
    #   mat_x              : (n, 4)
    #   np.dot(...)        : (n,) . (n, 4) -> (4,)
    theta = theta - ETA * np.dot(f(mat_x) - train_y, mat_x)
    result = classify(mat_x) == train_y  # boolean array, e.g. [True, False, ...]
    # The mean of a boolean array is the fraction of correct predictions;
    # np.mean stays vectorized instead of looping via the builtin sum().
    accuracy = np.mean(result)
    accuracies.append(accuracy)
# Optional: draw the learned decision boundary in standardized feature space.
# The boundary is theta0 + theta1*x1 + theta2*x2 + theta3*x1**2 = 0, so
# x2 = -(theta0 + theta1*x1 + theta3*x1**2) / theta2.
# x1 = np.linspace(-2, 2, 100)
# x2 = - (theta[0] + x1 * theta[1] + theta[3] * x1**2) / theta[2]
# plt.plot(std_x[train_y == 1, 0], std_x[train_y == 1, 1], 'o')  # class 1 training points
# plt.plot(std_x[train_y == 0, 0], std_x[train_y == 0, 1], 'x')  # class 0 training points
# plt.plot(x1, x2, linestyle='dashed')  # learned boundary
# plt.show()

# Plot training accuracy per iteration. Given only y values, plt.plot uses
# 0 .. len(y) - 1 as the x coordinates, i.e. the iteration index.
plt.plot(accuracies)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment