Last active
August 29, 2015 14:16
-
-
Save matsuken92/ad6d48574b92b54cefb6 to your computer and use it in GitHub Desktop.
手書き数字をpythonでもてあそぶ その2(識別する) ref: http://qiita.com/kenmatsu4/items/2d21466078917c200033
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
C = \{0, 1, 2, 3, 4, 5, 6, 7, 8, 9\}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
y_i = (y_1, y_2, \ldots, y_{784}) \quad (i = 0, 1, \ldots, 9)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
\hat{y}_i = \frac{1}{n_i}\sum_{j=1}^{n_i} y_j
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x_j = (x_1, x_2, \ldots, x_{784})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{\rm argmin}_i\,({\rm distance}_i) = {\rm argmin}_i\,\|\hat{y}_i - x_j\|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
from collections import defaultdict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Running per-label sum of image vectors; filled in by add_image().
image_dict = dict()


def add_image(label, image_vector):
    """Add one image vector to the running sum stored for ``label``.

    The vector is copied and converted to float64 before it is stored or
    added, so integer pixel data (e.g. uint8) cannot overflow while being
    summed, and vectors of mixed int/float dtype can be combined.

    Args:
        label: Hashable key identifying the class the image belongs to.
        image_vector: Array-like of pixel values; every vector added under
            the same label must have the same length.

    Returns:
        The module-level ``image_dict`` (label -> summed vector), which is
        updated in place.
    """
    vec = np.array(image_vector, dtype=np.float64)
    if label in image_dict:
        image_dict[label] += vec
    else:
        # First image for this label: the fresh float copy becomes the sum.
        image_dict[label] = vec
    return image_dict
# Number of rows seen per label; filled in by count_label().
label_dd = defaultdict(int)


def count_label(data):
    """Count how many rows of ``data`` carry each label.

    The label of a row is its first element (``row[0]``). Counts accumulate
    in the module-level ``label_dd``, so calling this again adds to the
    counts from earlier calls.

    Args:
        data: Iterable of indexable rows whose first element is the label.

    Returns:
        The module-level ``label_dd`` (label -> number of rows).
    """
    for row in data:
        label_dd[row[0]] += 1
    return label_dd
def plot_digits(X, Y, Z, size_x, size_y, counter, title):
    """Draw one grayscale image into a cell of the current figure's grid.

    Args:
        X: Grid of x coordinates handed to ``plt.pcolor``.
        Y: Grid of y coordinates handed to ``plt.pcolor``.
        Z: 2-D array of values to draw.
        size_x: Number of rows in the subplot grid.
        size_y: Number of columns in the subplot grid.
        counter: 1-based index of the subplot cell to draw into.
        title: Title shown above the cell.

    The axis limits are fixed at 0..27 on both axes.
    """
    plt.subplot(size_x, size_y, counter)
    plt.title(title)
    plt.xlim(0, 27)
    plt.ylim(0, 27)
    plt.pcolor(X, Y, Z)
    plt.gray()
    # Hide the tick labels. These must be booleans: the string "off" is
    # non-empty and therefore truthy, which leaves the labels visible.
    plt.tick_params(labelbottom=False, labelleft=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
size = 28  # images are reshaped to size x size (784 values per row)

# First CSV row is skipped as a header. ndmin=2 keeps a file with a single
# data row two-dimensional, so the per-row slicing below still works.
raw_data = np.loadtxt('train_master.csv', delimiter=',', skiprows=1, ndmin=2)

# draw digit images
plt.figure(figsize=(11, 6))

# data aggregation: column 0 is the label, columns 1..784 are the pixels
for row in raw_data:
    add_image(row[0], row[1:785])
count_dict = count_label(raw_data)

standardized_digit_dict = dict()  # holds the representative (mean) image per digit

# The coordinate grid is the same for every image, so build it once.
X, Y = np.meshgrid(range(size), range(size))

# Iterate labels in sorted order so the subplot cells appear in ascending
# label order rather than in the order labels first occur in the file.
for count, key in enumerate(sorted(image_dict), start=1):
    num = count_dict[key]
    # Per-label mean image: summed pixels divided by the number of samples.
    Z = image_dict[key].reshape(size, size) / num
    # Reverse the row order before plotting (presumably so pcolor, which
    # draws row 0 at the bottom, shows the digit upright).
    Z = Z[::-1, :]
    standardized_digit_dict[int(key)] = Z
    plot_digits(X, Y, Z, 2, 5, count, "")
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# First CSV row is skipped as a header. ndmin=2 keeps a file with a single
# data row two-dimensional, so test_data[i] is always a full row.
test_data = np.loadtxt('test_small.csv', delimiter=',', skiprows=1, ndmin=2)

# compare 1 tested digit vs average digits with norm
plt.figure(figsize=(10, 9))

# The coordinate grid is the same for every image, so build it once.
X, Y = np.meshgrid(range(size), range(size))

for i in range(1):  # try only the first test image
    Z = test_data[i].reshape(size, size)
    # Reverse the row order, matching how the mean images were stored.
    Z = Z[::-1, :]
    flat_Z = Z.flatten()
    plot_digits(X, Y, Z, 3, 4, 1, "tested")

    # Cell 1 holds the tested image; the mean images follow from cell 2 on,
    # in ascending digit order.
    for count, key in enumerate(sorted(standardized_digit_dict), start=1):
        X1 = standardized_digit_dict[key]
        # distance between this representative image and the image under test
        norm = np.linalg.norm(X1.flatten() - flat_Z)
        plot_digits(X, Y, X1, 3, 4, 1 + count, "d=%.3f" % norm)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# recognize digits
plt.figure(figsize=(15, 130))

# The coordinate grid is the same for every image, so build it once.
X, Y = np.meshgrid(range(size), range(size))

# Subplot grid: 5 columns and at least 40 rows. The row count only grows
# when there are more than 200 test images, which would otherwise push the
# subplot index past the end of a 40 x 5 grid.
n_cols = 5
n_rows = max(40, -(-len(test_data) // n_cols))  # ceiling division

for i, sample in enumerate(test_data):
    # Reverse the row order, matching how the mean images were stored.
    tested = sample.reshape(size, size)[::-1, :]
    flat_tested = tested.flatten()

    # (digit, distance) for every representative image, in digit order.
    norm_list = []
    for key in sorted(standardized_digit_dict):
        flat_sdd = standardized_digit_dict[key].flatten()
        norm_list.append((key, np.linalg.norm(flat_sdd - flat_tested)))
    norm_list = np.array(norm_list)

    # The predicted label is the digit whose mean image is closest.
    min_result = norm_list[np.argmin(norm_list[:, 1])]
    plot_digits(X, Y, tested, n_rows, n_cols, i + 1,
                "l=%d, n=%d" % (min_result[0], min_result[1]))
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment