Created
October 25, 2021 12:25
-
-
Save billju/8bc406bcc298d8cd8f83b787fae4b146 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

# Sample dataset: one row per person, every value stored as a string.
# Columns, in order, match the labels used when features are dropped
# further down: gender, department, salary, spouse.
A = np.array([
    ['男', '市場部', '24000', '無'],
    ['女', '研發部', '45000', '有'],
    ['男', '會計部', '45000', '無'],
    ['男', '研發部', '40000', '無'],
    ['女', '市場部', '24000', '有'],
    ['女', '研發部', '40000', '無'],
    ['男', '市場部', '24000', '有'],
])
def Hamming_distance(A):
    """Return the summed pairwise binary entropy of row agreement in ``A``.

    For every unordered pair of rows, the fraction ``q`` of columns in which
    the two rows hold equal values is computed, and the binary entropy
    ``-(q*log2(q) + (1-q)*log2(1-q))`` of that fraction is accumulated.
    Pairs that agree on all columns or on none contribute 0.

    Parameters
    ----------
    A : numpy.ndarray
        Two-dimensional array with ``n`` rows and ``p`` columns. Entries are
        only ever compared with ``==``.

    Returns
    -------
    numpy.float64
        Sum of the entropies over all row pairs ``(i, j)`` with ``i < j``.
    """
    n_rows, n_cols = A.shape
    matches = np.zeros((n_rows, n_rows))
    for j in range(n_cols):
        column = A[:, j]
        # Broadcasting compares every row against every other row for this
        # column; each equal pair adds one to its cell.
        matches += column[:, None] == column[None, :]
    # Keep only the strict upper triangle so each pair is counted once and
    # the diagonal (a row compared with itself) is dropped.
    matches = np.triu(matches, k=1)

    def safe_log2(x):
        # With ``where=`` alone NumPy leaves the skipped entries
        # uninitialised; an explicit zero-filled ``out`` makes 0*log2(0)
        # evaluate to exactly 0 instead of depending on garbage memory.
        return np.log2(x, out=np.zeros_like(x), where=x != 0)

    # Fraction of matching columns per pair, treated as a probability.
    prob = matches / n_cols
    # Binary entropy of each pair's agreement ratio.
    entropy = -(prob * safe_log2(prob) + (1 - prob) * safe_log2(1 - prob))
    return entropy.sum()
# Entropy with every feature present. The result is not stored or printed,
# so it is only visible when this line is evaluated interactively.
Hamming_distance(A)
# Recompute the entropy with each feature (column) removed in turn.
for i, col in enumerate(['性別', '部門', '薪資', '配偶']):
    H = Hamming_distance(np.delete(A, i, axis=1))
    print(col, H)  # the more H drops, the more important that feature is
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment