# Gist staybuzz/26f41e8c2cbb79803a8a11782f4df7cd by @staybuzz, created June 19, 2016
# coding: utf-8
# http://momijiame.tumblr.com/post/114751531866/python-iris-%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88%E3%82%92%E3%82%B5%E3%83%9D%E3%83%BC%E3%83%88%E3%83%99%E3%82%AF%E3%82%BF%E3%83%BC%E3%83%9E%E3%82%B7%E3%83%B3%E3%81%A7%E5%88%86%E9%A1%9E%E3%81%97%E3%81%A6%E3%81%BF%E3%82%8B
# https://github.com/levelfour/machine-learning-2014/wiki/%E7%AC%AC1%E5%9B%9E---iris-classification
# In[1]:
#get_ipython().magic('matplotlib inline')
# In[2]:
from sklearn.svm import LinearSVC
import numpy as np
# In[62]:
# Training data for Example 6.1: two-dimensional points, three per class.
data = [
    [1, 2], [1, 4], [2, 4],  # class "1"
    [2, 1], [5, 1], [4, 2],  # class "2"
]
# Ground-truth labels, one per row of `data` (kept as strings).
label = ["1"] * 3 + ["2"] * 3
print(data)
print(label)
# In[87]:
from sklearn.metrics import accuracy_score

# Rebind the module-level `data` as a float32 NumPy matrix for scikit-learn.
data = np.asarray(data, dtype=np.float32)

# Fit a linear support-vector classifier on the labelled training points.
classifier = LinearSVC()
classifier.fit(data, label)

# Score the model on the very points it was trained on. This is not a real
# generalization estimate -- only a sanity check that the classes separate.
result = classifier.predict(data)
accuracy = accuracy_score(label, result)
print(accuracy)
# ### グラフにプロットしたい
# In[67]:
import matplotlib.pyplot as plt
# In[68]:
#plt.scatter(data[:3,0], data[:3,1], c='red') # クラス1
#plt.scatter(data[3:,0], data[3:,1], c='blue') # クラス2
# In[69]:
# データの範囲でメッシュ状に点を取る
# Sample a regular lattice of points covering the data range plus a margin
# on every side; the classifier is evaluated on this lattice to shade its
# decision regions.
margin = 1.0
x_min = data[:, 0].min() - margin
x_max = data[:, 0].max() + margin
y_min = data[:, 1].min() - margin
y_max = data[:, 1].max() + margin
# A fine step keeps the shaded decision boundary smooth; a step of 1.0 gave
# only a handful of lattice nodes and a visibly jagged boundary.
grid_interval = 0.02
# np.arange excludes its stop value, so extend the stop by one step; without
# this the lattice ends at the data maximum and the upper margin is lost.
xx, yy = np.meshgrid(
    np.arange(x_min, x_max + grid_interval, grid_interval),
    np.arange(y_min, y_max + grid_interval, grid_interval),
)
# In[70]:
# Classify every lattice point, then fold the flat predictions back into the
# lattice's 2-D shape so they line up with xx / yy.
Z = classifier.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# The predicted labels are strings ("1", "2"), but contourf needs numeric
# heights. Map each prediction to its position in classifier.classes_
# (scikit-learn keeps classes_ sorted, which searchsorted requires).
Z_index = np.searchsorted(classifier.classes_, Z)

# Select each class's training points by label rather than by row position,
# so the scatter stays correct if the rows of `data` are reordered.
labels = np.asarray(label)
is_class1 = labels == "1"
is_class2 = labels == "2"

plt.figure(figsize=(10, 6), dpi=80)
# Shade the decision regions, then overlay the training points.
plt.contourf(xx, yy, Z_index, cmap=plt.cm.Paired, alpha=0.2)
plt.scatter(data[is_class1, 0], data[is_class1, 1], c='red')   # class 1
plt.scatter(data[is_class2, 0], data[is_class2, 1], c='blue')  # class 2
plt.grid()
plt.savefig("svm.jpg", dpi=80)
plt.show()
# In[ ]:
# In[ ]:
# In[ ]: