これは、下記のサイトを参考にpython(scikit-learn)で決定木を試したソースコード
http://data-hacker.blogspot.jp/2014/05/pythonscikit-learn.html
これは、下記のサイトを参考にpython(scikit-learn)で決定木を試したソースコード
http://data-hacker.blogspot.jp/2014/05/pythonscikit-learn.html
# -*- coding: utf-8 -*- | |
import pandas as pd | |
import numpy as np | |
from sklearn import tree | |
# 教師データをロード | |
df = pd.read_csv('xor_simple.csv'); | |
data_array = df[['x', 'y']].values | |
class_array = df['class'].values | |
# 学習(決定木) | |
clf = tree.DecisionTreeClassifier() | |
clf = clf.fit(data_array, class_array) | |
#学習後に、2つのデータを与えてそれらを分類。 | |
#与えられた教師データの特徴から考えると | |
# x=2.0, y=1.0 であれば、クラス「0」に分類されるはず。 | |
# x=1.0, y= -0.5であれば、クラス「1」に分類されるはず。 | |
result = clf.predict([[2., 1.], [1., -0.5]]) | |
print result | |
### 決定境界の可視化 | |
import matplotlib.pyplot as plt | |
# Parameters for plot | |
n_classes = 2 | |
plot_colors = "br" | |
plot_step = 0.05 | |
#グラフ描画時の説明変数 x、yの最大値&最小値を算出。 | |
#グラフ描画のメッシュを定義 | |
x_min, x_max = data_array[:, 0].min() - 1, data_array[:, 0].max() + 1 | |
y_min, y_max = data_array[:, 1].min() - 1, data_array[:, 1].max() + 1 | |
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), | |
np.arange(y_min, y_max, plot_step)) | |
#各メッシュ上での決定木による分類を計算 | |
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) | |
Z = Z.reshape(xx.shape) | |
#決定木による分類を等高線フィールドプロットでプロット | |
cs = plt.contour(xx, yy, Z, cmap=plt.cm.Paired) | |
plt.xlabel('x') | |
plt.ylabel('y') | |
plt.axis("tight") | |
#教師データも重ねてプロット | |
for i, color in zip(range(n_classes), plot_colors): | |
idx = np.where(class_array == i) | |
plt.scatter(data_array[idx, 0], data_array[idx, 1], c=color, label=['a','b'], | |
cmap=plt.cm.Paired) | |
plt.axis("tight") | |
plt.show() |
x | y | class | |
---|---|---|---|
-2.121580967 | -0.365665506 | 0 | |
-1.797266776 | -1.618523073 | 0 | |
-0.717394571 | -0.738177485 | 0 | |
-0.830662087 | -1.058791442 | 0 | |
-1.145322845 | -0.750618964 | 0 | |
-1.193923462 | -0.606689289 | 0 | |
-1.510185511 | -0.071761198 | 0 | |
-1.204388261 | -0.894366562 | 0 | |
-1.089685904 | -0.957121487 | 0 | |
0.120336282 | -0.822792142 | 0 | |
-0.613491473 | -0.574858212 | 0 | |
-0.42222762 | -1.156157105 | 0 | |
-1.46093916 | -1.569185406 | 0 | |
-0.979112372 | -1.907546145 | 0 | |
-0.817423506 | -1.262125852 | 0 | |
-0.697717895 | -1.693726266 | 0 | |
-0.90938065 | -1.014607998 | 0 | |
-0.684790594 | -0.715484122 | 0 | |
-1.001101395 | -0.677211991 | 0 | |
-0.773787445 | -0.781490351 | 0 | |
0.679649577 | 2.009710795 | 0 | |
1.087990067 | 1.128682442 | 0 | |
-0.471220285 | 0.658977628 | 0 | |
0.75666077 | 1.253800712 | 0 | |
1.107170215 | 0.942305548 | 0 | |
0.547427043 | 1.25471797 | 0 | |
0.172337119 | 1.321508052 | 0 | |
0.75441344 | 0.577555831 | 0 | |
1.193582186 | 1.536259738 | 0 | |
2.010374491 | 0.793770601 | 0 | |
0.712591176 | 1.993967333 | 0 | |
1.575928337 | 0.465069191 | 0 | |
0.312669319 | 2.340031348 | 0 | |
0.474045066 | 0.882761763 | 0 | |
1.627177872 | 0.911410143 | 0 | |
0.581137661 | 1.335835314 | 0 | |
0.513237718 | 1.339515202 | 0 | |
2.08397179 | 0.525077456 | 0 | |
1.462142301 | 0.515768791 | 0 | |
-0.808576845 | 1.678407316 | 1 | |
-1.113784281 | 0.673034693 | 1 | |
-1.755298852 | 1.081698325 | 1 | |
-0.901312814 | 0.374982319 | 1 | |
0.468540221 | 1.0636286 | 1 | |
-0.657399227 | 1.340507018 | 1 | |
-1.444319158 | 0.84839978 | 1 | |
0.165544458 | 1.204151487 | 1 | |
-1.886536914 | 0.865973372 | 1 | |
-1.738291657 | 1.325458792 | 1 | |
-1.187196553 | 0.958926432 | 1 | |
-1.481173228 | 0.748988483 | 1 | |
-1.230384608 | 1.538903521 | 1 | |
-1.338228883 | 0.406475801 | 1 | |
-1.501449356 | 0.691336663 | 1 | |
-0.271862827 | 0.775127338 | 1 | |
-0.5807241 | 1.011103064 | 1 | |
-1.075546799 | 1.192832784 | 1 | |
-1.290304416 | 1.306105006 | 1 | |
-1.047113498 | 0.573244887 | 1 | |
-1.829345211 | 1.526520998 | 1 | |
1.655654668 | -0.919084924 | 1 | |
0.354419891 | -1.514486921 | 1 | |
1.294325506 | -1.025373154 | 1 | |
1.059651738 | -0.329919638 | 1 | |
0.776688583 | -0.981650811 | 1 | |
0.852488155 | -1.050404224 | 1 | |
0.709031861 | -0.754367705 | 1 | |
0.286452122 | -0.675478837 | 1 | |
1.375405012 | -2.322473849 | 1 | |
0.802597412 | -1.04669331 | 1 | |
1.086726869 | -0.457528884 | 1 | |
1.176494782 | -1.678981604 | 1 | |
1.919258297 | -0.608970286 | 1 | |
0.416852043 | -1.265871275 | 1 | |
0.849027825 | -0.682913596 | 1 | |
0.838646307 | -0.948194085 | 1 | |
0.61533752 | -0.138612196 | 1 | |
1.938780666 | -1.05737308 | 1 | |
1.226448812 | -0.534711982 | 1 | |
1.059234006 | -0.374870802 | 1 |