Skip to content

Instantly share code, notes, and snippets.

@krrrr38
Created December 11, 2011 16:52
Show Gist options
  • Save krrrr38/1461518 to your computer and use it in GitHub Desktop.
Save krrrr38/1461518 to your computer and use it in GitHub Desktop.
Fisher linear discriminant
#coding:utf-8
# 4.1.4 フィッシャーの線形判別(p.185)
# 3クラスへの応用
import numpy as np
from pylab import *
import sys
N = 150 # データ数
def f(x, a, b):
# 決定境界の直線の方程式
return a * x + b
def example5():
# 訓練データを作成
cls1 = []
cls2 = []
cls3 = []
# データは正規分布に従って生成
mean1 = [1, 3]
mean2 = [3, 1]
mean3 = [-1, -1]
cov1 = [[2.0, 0.0], [0.0, 0.1]]
cov2 = [[1.0, 0.0], [0.0, 1.0]]
# データ作成
cls1.extend(np.random.multivariate_normal(mean1, cov1, N/3))
cls2.extend(np.random.multivariate_normal(mean2, cov1, N/3))
cls3.extend(np.random.multivariate_normal(mean3, cov2, N/3))
# 書くクラスの平均をプロット
m1 = np.mean(cls1, axis=0)
m2 = np.mean(cls2, axis=0)
m3 = np.mean(cls3, axis=0)
plot([m1[0]], [m1[1]], 'b+')
plot([m2[0]], [m2[1]], 'r+')
plot([m3[0]], [m3[1]], 'c+')
print m1, m2, m3
# 総クラス内共分散行列を計算
Sw = zeros((2, 2))
for n in range(len(cls1)):
xn = matrix(cls1[n]).reshape(2,1)
m1 = matrix(m1).reshape(2,1)
Sw += (xn - m1) * transpose(xn - m1)
for n in range(len(cls2)):
xn = matrix(cls2[n]).reshape(2,1)
m2 = matrix(m2).reshape(2,1)
Sw += (xn - m2) * transpose(xn - m2)
for n in range(len(cls3)):
xn = matrix(cls3[n]).reshape(2,1)
m3 = matrix(m3).reshape(2,1)
Sw += (xn - m3) * transpose(xn - m3)
Sw_inv = np.linalg.inv(Sw)
w1 = Sw_inv * (m2 - m1) # 1個目の識別
w2 = Sw_inv * (m3 - m2) # 2個目の識別
# 訓練データを描画#
x1, x2 = np.transpose(np.array(cls1))
plot(x1, x2, 'bo')
x1, x2 = np.transpose(np.array(cls2))
plot(x1, x2, 'ro')
x1, x2 = np.transpose(np.array(cls3))
plot(x1, x2, 'co')
# 識別境界を描画
# wは識別境界と直行するベクトル
a1 = -(w1[0,0] / w1[1,0]) # 識別直線1の傾き
a2 = -(w2[0,0] / w2[1,0]) # 識別直線2の傾き
# 傾きがaで平均の中点mを通る直線のy切片bを求める
m1 = (m1 + m2) / 2
m2 = (m2 + m3) / 2
b1 = -a1 * m1[0,0] + m1[1,0] # 識別直線1のy切片
b2 = -a2 * m2[0,0] + m2[1,0] # 識別直線2のy切片
x1 = np.linspace(-3, 6, 1000)
x2 = [f(x, a1, b1) for x in x1]
plot(x1, x2, 'g-')
x2 = [f(x, a2, b2) for x in x1]
plot(x1, x2, 'y-')
xlim(-3, 6)
ylim(-3, 4)
show()
if __name__ == "__main__":
example5()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment