Skip to content

Instantly share code, notes, and snippets.

@akkijp
Last active June 10, 2016 17:02
Show Gist options
  • Save akkijp/4d5df0412d2676f9ca00ffbae31beaee to your computer and use it in GitHub Desktop.
Save akkijp/4d5df0412d2676f9ca00ffbae31beaee to your computer and use it in GitHub Desktop.
python(scikit-learn)で決定木を試したソースコード
# -*- coding: utf-8 -*-
"""Fit a decision tree to the XOR-style sample data and plot its decision boundary.

The script reads ``xor_simple.csv`` (columns ``x``, ``y`` and ``class``),
trains a ``DecisionTreeClassifier`` on it, prints the predicted classes for
two probe points, and then draws the decision boundary together with the
training samples.
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import tree

# Load the training data.
df = pd.read_csv('xor_simple.csv')
data_array = df[['x', 'y']].values
class_array = df['class'].values

# Train the decision tree.
clf = tree.DecisionTreeClassifier()
clf = clf.fit(data_array, class_array)

# After training, classify two probe points.
# Judging from the characteristics of the training data:
#   x=2.0, y=1.0  should be assigned to class 0;
#   x=1.0, y=-0.5 should be assigned to class 1.
result = clf.predict([[2., 1.], [1., -0.5]])
# Function-call form so the script runs on Python 3 as well as Python 2.
print(result)

### Visualise the decision boundary

# Parameters for plot
n_classes = 2
plot_colors = "br"
plot_step = 0.05

# Compute the min/max of the explanatory variables x and y (padded by 1 on
# each side) and define the mesh used for drawing.
x_min, x_max = data_array[:, 0].min() - 1, data_array[:, 0].max() + 1
y_min, y_max = data_array[:, 1].min() - 1, data_array[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                     np.arange(y_min, y_max, plot_step))

# Classify every mesh point with the trained tree, then restore the grid shape.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Draw the tree's classification over the mesh as contour lines.
cs = plt.contour(xx, yy, Z, cmap=plt.cm.Paired)
plt.xlabel('x')
plt.ylabel('y')
plt.axis("tight")

# Overlay the training samples, one colour per class.
for i, color in zip(range(n_classes), plot_colors):
    # A boolean mask selects the rows of class ``i`` directly as 1-D arrays.
    mask = class_array == i
    # ``cmap`` is omitted: matplotlib ignores it when ``c`` is a literal colour.
    plt.scatter(data_array[mask, 0], data_array[mask, 1], c=color,
                label="class %d" % i)

# Show which colour corresponds to which class.
plt.legend()
plt.axis("tight")
plt.show()
x y class
-2.121580967 -0.365665506 0
-1.797266776 -1.618523073 0
-0.717394571 -0.738177485 0
-0.830662087 -1.058791442 0
-1.145322845 -0.750618964 0
-1.193923462 -0.606689289 0
-1.510185511 -0.071761198 0
-1.204388261 -0.894366562 0
-1.089685904 -0.957121487 0
0.120336282 -0.822792142 0
-0.613491473 -0.574858212 0
-0.42222762 -1.156157105 0
-1.46093916 -1.569185406 0
-0.979112372 -1.907546145 0
-0.817423506 -1.262125852 0
-0.697717895 -1.693726266 0
-0.90938065 -1.014607998 0
-0.684790594 -0.715484122 0
-1.001101395 -0.677211991 0
-0.773787445 -0.781490351 0
0.679649577 2.009710795 0
1.087990067 1.128682442 0
-0.471220285 0.658977628 0
0.75666077 1.253800712 0
1.107170215 0.942305548 0
0.547427043 1.25471797 0
0.172337119 1.321508052 0
0.75441344 0.577555831 0
1.193582186 1.536259738 0
2.010374491 0.793770601 0
0.712591176 1.993967333 0
1.575928337 0.465069191 0
0.312669319 2.340031348 0
0.474045066 0.882761763 0
1.627177872 0.911410143 0
0.581137661 1.335835314 0
0.513237718 1.339515202 0
2.08397179 0.525077456 0
1.462142301 0.515768791 0
-0.808576845 1.678407316 1
-1.113784281 0.673034693 1
-1.755298852 1.081698325 1
-0.901312814 0.374982319 1
0.468540221 1.0636286 1
-0.657399227 1.340507018 1
-1.444319158 0.84839978 1
0.165544458 1.204151487 1
-1.886536914 0.865973372 1
-1.738291657 1.325458792 1
-1.187196553 0.958926432 1
-1.481173228 0.748988483 1
-1.230384608 1.538903521 1
-1.338228883 0.406475801 1
-1.501449356 0.691336663 1
-0.271862827 0.775127338 1
-0.5807241 1.011103064 1
-1.075546799 1.192832784 1
-1.290304416 1.306105006 1
-1.047113498 0.573244887 1
-1.829345211 1.526520998 1
1.655654668 -0.919084924 1
0.354419891 -1.514486921 1
1.294325506 -1.025373154 1
1.059651738 -0.329919638 1
0.776688583 -0.981650811 1
0.852488155 -1.050404224 1
0.709031861 -0.754367705 1
0.286452122 -0.675478837 1
1.375405012 -2.322473849 1
0.802597412 -1.04669331 1
1.086726869 -0.457528884 1
1.176494782 -1.678981604 1
1.919258297 -0.608970286 1
0.416852043 -1.265871275 1
0.849027825 -0.682913596 1
0.838646307 -0.948194085 1
0.61533752 -0.138612196 1
1.938780666 -1.05737308 1
1.226448812 -0.534711982 1
1.059234006 -0.374870802 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment