# Created October 20, 2017 14:53
# Save walkingpendulum/931019ea171c0ca242438ec18318f017 to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals
# hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from itertools import combinations

import numpy as np
# Training ("learn") samples for class A: 20 rows x 4 features.
# NOTE(review): these look like the first 20 Iris-setosa rows of the classic Iris
# dataset (sepal length, sepal width, petal length, petal width) — confirm.
Alearn = np.array(
    [
        [5.1, 3.5, 1.4, 0.2],
        [4.9, 3.0, 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5.0, 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5.0, 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3.0, 1.4, 0.1],
        [4.3, 3.0, 1.1, 0.1],
        [5.8, 4.0, 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
    ],
    dtype=float,
)
# Held-out ("exam") samples for class A: 30 rows x 4 features.
# NOTE(review): these appear to be the remaining 30 Iris-setosa rows — confirm.
Aexam = np.array(
    [
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1.0, 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5.0, 3.0, 1.6, 0.2],
        [5.0, 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5.0, 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3.0, 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5.0, 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5.0, 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3.0, 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5.0, 3.3, 1.4, 0.2],
    ],
    dtype=float,
)
# Training ("learn") samples for class B: 20 rows x 4 features.
# NOTE(review): these look like the first 20 Iris-versicolor rows — confirm.
Blearn = np.array(
    [
        [7.0, 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4.0, 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1.0],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5.0, 2.0, 3.5, 1.0],
        [5.9, 3.0, 4.2, 1.5],
        [6.0, 2.2, 4.0, 1.0],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3.0, 4.5, 1.5],
        [5.8, 2.7, 4.1, 1.0],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
    ],
    dtype=float,
)
# Held-out ("exam") samples for class B: 30 rows x 4 features.
# NOTE(review): these appear to be the remaining 30 Iris-versicolor rows — confirm.
Bexam = np.array(
    [
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4.0, 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3.0, 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3.0, 5.0, 1.7],
        [6.0, 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1.0],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1.0],
        [5.8, 2.7, 3.9, 1.2],
        [6.0, 2.7, 5.1, 1.6],
        [5.4, 3.0, 4.5, 1.5],
        [6.0, 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3.0, 4.1, 1.3],
        [5.5, 2.5, 4.0, 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3.0, 4.6, 1.4],
        [5.8, 2.6, 4.0, 1.2],
        [5.0, 2.3, 3.3, 1.0],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3.0, 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3.0, 1.1],
        [5.7, 2.8, 4.1, 1.3],
    ],
    dtype=float,
)
# Training ("learn") samples for class C: 20 rows x 4 features.
# NOTE(review): these look like the first 20 Iris-virginica rows — confirm.
Clearn = np.array(
    [
        [6.3, 3.3, 6.0, 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3.0, 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3.0, 5.8, 2.2],
        [7.6, 3.0, 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2.0],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3.0, 5.5, 2.1],
        [5.7, 2.5, 5.0, 2.0],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3.0, 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6.0, 2.2, 5.0, 1.5],
    ],
    dtype=float,
)
# Held-out ("exam") samples for class C: 30 rows x 4 features.
# NOTE(review): these appear to be the remaining 30 Iris-virginica rows — confirm.
Cexam = np.array(
    [
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2.0],
        [7.7, 2.8, 6.7, 2.0],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6.0, 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3.0, 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3.0, 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2.0],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3.0, 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6.0, 3.0, 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3.0, 5.2, 2.3],
        [6.3, 2.5, 5.0, 1.9],
        [6.5, 3.0, 5.2, 2.0],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3.0, 5.1, 1.8],
    ],
    dtype=float,
)
# Module-level all-ones weight row of shape (1, n_features + 1); the extra slot
# pairs with the constant column of ones appended to the samples.
# NOTE(review): the __main__ section builds its own initial weights and does not
# read this name; the functions below take ``phi`` as a parameter that shadows it.
phi = np.ones((1, Alearn.shape[1] + 1))
def predict(phi, X_augmented):
    """Label each row of ``X_augmented`` with the linear rule ``phi . x >= 0``.

    Args:
        phi: weight vector broadcastable against one row of ``X_augmented``
            (a ``(d,)`` vector or a ``(1, d)`` row both work).
        X_augmented: 2-D array of samples, one per row.

    Returns:
        An ``(n, 1)`` integer column: 1 where the row's score is non-negative,
        0 where it is negative.
    """
    # Elementwise product then a row sum gives one score per sample.
    scores = np.multiply(X_augmented, phi).sum(axis=1)
    # A score of exactly 0 counts as class 1.
    is_class_one = scores >= 0
    return is_class_one.astype(int)[:, np.newaxis]
def updated_phi(phi_old, X_augmented, y_true, y_predicted):
    """Apply one batch perceptron step and return the new weight vector.

    Args:
        phi_old: current weights.
        X_augmented: ``(n, d)`` samples.
        y_true: ``(n, 1)`` reference labels in {0, 1}.
        y_predicted: ``(n, 1)`` labels produced with ``phi_old``.

    Returns:
        ``phi_old`` plus the sum of misclassified class-1 samples minus the sum
        of misclassified class-0 samples.
    """
    mistaken = y_true != y_predicted
    # Rows outside each mask are zeroed by the boolean multiply, so the column
    # sums below add up only the misclassified samples of that class.
    missed_ones = X_augmented * (mistaken & (y_true == 1))
    missed_zeros = X_augmented * (mistaken & (y_true == 0))
    step = missed_zeros.sum(axis=0) - missed_ones.sum(axis=0)
    return phi_old - step
def loss(phi, X_augmented, y_true, y_predicted):
    """Return the perceptron criterion over the misclassified samples.

    The value is the sum of ``phi . x`` over class-0 samples labelled 1, plus
    the sum of ``-phi . x`` over class-1 samples labelled 0.  When
    ``y_predicted`` comes from ``predict(phi, X_augmented)`` every term is
    non-negative.  Note that a class-0 sample with ``phi . x == 0`` is
    misclassified by ``predict`` yet contributes 0 here.

    Args:
        phi: weight vector; ``(d,)``, ``(1, d)`` and ``(d, 1)`` are all accepted.
        X_augmented: ``(n, d)`` samples.
        y_true: ``(n, 1)`` reference labels in {0, 1}.
        y_predicted: ``(n, 1)`` predicted labels.

    Returns:
        The criterion as a NumPy float scalar.
    """
    # Flatten so a (1, d) row vector works here just as it does in ``predict``;
    # for a (d,) vector this is a no-op and the result is unchanged.
    phi = np.ravel(phi)
    mistaken = y_true != y_predicted
    X_err_1 = X_augmented * (mistaken & (y_true == 1))
    X_err_0 = X_augmented * (mistaken & (y_true == 0))
    value = (X_err_0.dot(phi) - X_err_1.dot(phi)).sum()
    return value
def train(phi_init, X_augmented, y_true, max_iter=1000):
    """Fit linear weights with batch perceptron updates.

    Each iteration applies ``updated_phi`` using the current predictions and
    stops as soon as no training sample is misclassified.

    Args:
        phi_init: initial weight vector.
        X_augmented: ``(n, d)`` training samples.
        y_true: ``(n, 1)`` training labels in {0, 1}.
        max_iter: maximum number of update steps (default 1000, the previously
            hard-coded limit).

    Returns:
        The weights after the last update.  If ``max_iter`` is 0 or negative,
        no update is made and ``phi_init`` is returned.  If the data are not
        separated within ``max_iter`` steps, the last iterate is returned.
    """
    phi = phi_init
    y_predicted = predict(phi, X_augmented)
    for _ in range(max_iter):
        phi = updated_phi(phi, X_augmented, y_true, y_predicted)
        # Reuse these predictions for the next update instead of recomputing.
        y_predicted = predict(phi, X_augmented)
        # Stop on zero misclassifications rather than zero loss: a class-0
        # sample lying exactly on the boundary (phi . x == 0) is predicted as 1
        # but contributes 0 to the perceptron loss, so "loss == 0" could stop
        # with a sample still misclassified.
        if not np.any(y_predicted != y_true):
            break
    return phi
if __name__ == '__main__':
    # For every pair of classes, train a perceptron on the "learn" split and
    # evaluate it on the "exam" split.  The first class of each pair is label 0,
    # the second is label 1.
    datasets = {
        'A': (Alearn, Aexam),
        'B': (Blearn, Bexam),
        'C': (Clearn, Cexam),
    }

    def _augment(X):
        # Append a constant column of ones so the last weight acts as a bias.
        return np.concatenate((X, np.ones((len(X), 1))), axis=1)

    def _pair_labels(first, second):
        # Column of 0s for the rows of ``first`` followed by 1s for ``second``.
        return np.concatenate(
            (np.zeros((len(first), 1)), np.ones((len(second), 1))), axis=0
        )

    results = {}
    pairs = combinations(datasets.items(), 2)
    for (name_0, (learn_0, exam_0)), (name_1, (learn_1, exam_1)) in pairs:
        X_augmented = _augment(np.concatenate((learn_0, learn_1), axis=0))
        y_learn = _pair_labels(learn_0, learn_1)
        # One starting weight per augmented column, rather than a hard-coded 5.
        phi_init = np.ones(X_augmented.shape[1])
        phi_star = train(phi_init, X_augmented, y_learn)

        exam_augmented = _augment(np.concatenate((exam_0, exam_1), axis=0))
        y_exam = _pair_labels(exam_0, exam_1)
        y_predicted = predict(phi_star, exam_augmented)
        loss_ = loss(phi_star, X_augmented=exam_augmented,
                     y_predicted=y_predicted, y_true=y_exam)
        mistaken = y_exam != y_predicted
        mis_0 = (mistaken & (y_exam == 0)).sum()
        mis_1 = (mistaken & (y_exam == 1)).sum()
        results[' vs '.join((name_0, name_1))] = (
            'loss={:.2f}, misclassified zero class:{}, first class:{}'.format(
                loss_, mis_0, mis_1
            )
        )

    for item in results.items():
        print("%s: %s" % item)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.