Created
December 11, 2011 16:52
-
-
Save krrrr38/1461518 to your computer and use it in GitHub Desktop.
Fisher linear discriminant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#coding:utf-8 | |
# 4.1.4 フィッシャーの線形判別(p.185) | |
# 3クラスへの応用 | |
import numpy as np | |
from pylab import * | |
import sys | |
N = 150 # データ数 | |
def f(x, a, b): | |
# 決定境界の直線の方程式 | |
return a * x + b | |
def example5(): | |
# 訓練データを作成 | |
cls1 = [] | |
cls2 = [] | |
cls3 = [] | |
# データは正規分布に従って生成 | |
mean1 = [1, 3] | |
mean2 = [3, 1] | |
mean3 = [-1, -1] | |
cov1 = [[2.0, 0.0], [0.0, 0.1]] | |
cov2 = [[1.0, 0.0], [0.0, 1.0]] | |
# データ作成 | |
cls1.extend(np.random.multivariate_normal(mean1, cov1, N/3)) | |
cls2.extend(np.random.multivariate_normal(mean2, cov1, N/3)) | |
cls3.extend(np.random.multivariate_normal(mean3, cov2, N/3)) | |
# 書くクラスの平均をプロット | |
m1 = np.mean(cls1, axis=0) | |
m2 = np.mean(cls2, axis=0) | |
m3 = np.mean(cls3, axis=0) | |
plot([m1[0]], [m1[1]], 'b+') | |
plot([m2[0]], [m2[1]], 'r+') | |
plot([m3[0]], [m3[1]], 'c+') | |
print m1, m2, m3 | |
# 総クラス内共分散行列を計算 | |
Sw = zeros((2, 2)) | |
for n in range(len(cls1)): | |
xn = matrix(cls1[n]).reshape(2,1) | |
m1 = matrix(m1).reshape(2,1) | |
Sw += (xn - m1) * transpose(xn - m1) | |
for n in range(len(cls2)): | |
xn = matrix(cls2[n]).reshape(2,1) | |
m2 = matrix(m2).reshape(2,1) | |
Sw += (xn - m2) * transpose(xn - m2) | |
for n in range(len(cls3)): | |
xn = matrix(cls3[n]).reshape(2,1) | |
m3 = matrix(m3).reshape(2,1) | |
Sw += (xn - m3) * transpose(xn - m3) | |
Sw_inv = np.linalg.inv(Sw) | |
w1 = Sw_inv * (m2 - m1) # 1個目の識別 | |
w2 = Sw_inv * (m3 - m2) # 2個目の識別 | |
# 訓練データを描画# | |
x1, x2 = np.transpose(np.array(cls1)) | |
plot(x1, x2, 'bo') | |
x1, x2 = np.transpose(np.array(cls2)) | |
plot(x1, x2, 'ro') | |
x1, x2 = np.transpose(np.array(cls3)) | |
plot(x1, x2, 'co') | |
# 識別境界を描画 | |
# wは識別境界と直行するベクトル | |
a1 = -(w1[0,0] / w1[1,0]) # 識別直線1の傾き | |
a2 = -(w2[0,0] / w2[1,0]) # 識別直線2の傾き | |
# 傾きがaで平均の中点mを通る直線のy切片bを求める | |
m1 = (m1 + m2) / 2 | |
m2 = (m2 + m3) / 2 | |
b1 = -a1 * m1[0,0] + m1[1,0] # 識別直線1のy切片 | |
b2 = -a2 * m2[0,0] + m2[1,0] # 識別直線2のy切片 | |
x1 = np.linspace(-3, 6, 1000) | |
x2 = [f(x, a1, b1) for x in x1] | |
plot(x1, x2, 'g-') | |
x2 = [f(x, a2, b2) for x in x1] | |
plot(x1, x2, 'y-') | |
xlim(-3, 6) | |
ylim(-3, 4) | |
show() | |
if __name__ == "__main__": | |
example5() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment