これは、下記のサイトを参考にpython(scikit-learn)で決定木を試したソースコード
http://data-hacker.blogspot.jp/2014/05/pythonscikit-learn.html
これは、下記のサイトを参考にpython(scikit-learn)で決定木を試したソースコード
http://data-hacker.blogspot.jp/2014/05/pythonscikit-learn.html
| # -*- coding: utf-8 -*- | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn import tree | |
| # 教師データをロード | |
| df = pd.read_csv('xor_simple.csv'); | |
| data_array = df[['x', 'y']].values | |
| class_array = df['class'].values | |
| # 学習(決定木) | |
| clf = tree.DecisionTreeClassifier() | |
| clf = clf.fit(data_array, class_array) | |
| #学習後に、2つのデータを与えてそれらを分類。 | |
| #与えられた教師データの特徴から考えると | |
| # x=2.0, y=1.0 であれば、クラス「0」に分類されるはず。 | |
| # x=1.0, y= -0.5であれば、クラス「1」に分類されるはず。 | |
| result = clf.predict([[2., 1.], [1., -0.5]]) | |
| print result | |
| ### 決定境界の可視化 | |
| import matplotlib.pyplot as plt | |
| # Parameters for plot | |
| n_classes = 2 | |
| plot_colors = "br" | |
| plot_step = 0.05 | |
| #グラフ描画時の説明変数 x、yの最大値&最小値を算出。 | |
| #グラフ描画のメッシュを定義 | |
| x_min, x_max = data_array[:, 0].min() - 1, data_array[:, 0].max() + 1 | |
| y_min, y_max = data_array[:, 1].min() - 1, data_array[:, 1].max() + 1 | |
| xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), | |
| np.arange(y_min, y_max, plot_step)) | |
| #各メッシュ上での決定木による分類を計算 | |
| Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) | |
| Z = Z.reshape(xx.shape) | |
| #決定木による分類を等高線フィールドプロットでプロット | |
| cs = plt.contour(xx, yy, Z, cmap=plt.cm.Paired) | |
| plt.xlabel('x') | |
| plt.ylabel('y') | |
| plt.axis("tight") | |
| #教師データも重ねてプロット | |
| for i, color in zip(range(n_classes), plot_colors): | |
| idx = np.where(class_array == i) | |
| plt.scatter(data_array[idx, 0], data_array[idx, 1], c=color, label=['a','b'], | |
| cmap=plt.cm.Paired) | |
| plt.axis("tight") | |
| plt.show() |
| x | y | class | |
|---|---|---|---|
| -2.121580967 | -0.365665506 | 0 | |
| -1.797266776 | -1.618523073 | 0 | |
| -0.717394571 | -0.738177485 | 0 | |
| -0.830662087 | -1.058791442 | 0 | |
| -1.145322845 | -0.750618964 | 0 | |
| -1.193923462 | -0.606689289 | 0 | |
| -1.510185511 | -0.071761198 | 0 | |
| -1.204388261 | -0.894366562 | 0 | |
| -1.089685904 | -0.957121487 | 0 | |
| 0.120336282 | -0.822792142 | 0 | |
| -0.613491473 | -0.574858212 | 0 | |
| -0.42222762 | -1.156157105 | 0 | |
| -1.46093916 | -1.569185406 | 0 | |
| -0.979112372 | -1.907546145 | 0 | |
| -0.817423506 | -1.262125852 | 0 | |
| -0.697717895 | -1.693726266 | 0 | |
| -0.90938065 | -1.014607998 | 0 | |
| -0.684790594 | -0.715484122 | 0 | |
| -1.001101395 | -0.677211991 | 0 | |
| -0.773787445 | -0.781490351 | 0 | |
| 0.679649577 | 2.009710795 | 0 | |
| 1.087990067 | 1.128682442 | 0 | |
| -0.471220285 | 0.658977628 | 0 | |
| 0.75666077 | 1.253800712 | 0 | |
| 1.107170215 | 0.942305548 | 0 | |
| 0.547427043 | 1.25471797 | 0 | |
| 0.172337119 | 1.321508052 | 0 | |
| 0.75441344 | 0.577555831 | 0 | |
| 1.193582186 | 1.536259738 | 0 | |
| 2.010374491 | 0.793770601 | 0 | |
| 0.712591176 | 1.993967333 | 0 | |
| 1.575928337 | 0.465069191 | 0 | |
| 0.312669319 | 2.340031348 | 0 | |
| 0.474045066 | 0.882761763 | 0 | |
| 1.627177872 | 0.911410143 | 0 | |
| 0.581137661 | 1.335835314 | 0 | |
| 0.513237718 | 1.339515202 | 0 | |
| 2.08397179 | 0.525077456 | 0 | |
| 1.462142301 | 0.515768791 | 0 | |
| -0.808576845 | 1.678407316 | 1 | |
| -1.113784281 | 0.673034693 | 1 | |
| -1.755298852 | 1.081698325 | 1 | |
| -0.901312814 | 0.374982319 | 1 | |
| 0.468540221 | 1.0636286 | 1 | |
| -0.657399227 | 1.340507018 | 1 | |
| -1.444319158 | 0.84839978 | 1 | |
| 0.165544458 | 1.204151487 | 1 | |
| -1.886536914 | 0.865973372 | 1 | |
| -1.738291657 | 1.325458792 | 1 | |
| -1.187196553 | 0.958926432 | 1 | |
| -1.481173228 | 0.748988483 | 1 | |
| -1.230384608 | 1.538903521 | 1 | |
| -1.338228883 | 0.406475801 | 1 | |
| -1.501449356 | 0.691336663 | 1 | |
| -0.271862827 | 0.775127338 | 1 | |
| -0.5807241 | 1.011103064 | 1 | |
| -1.075546799 | 1.192832784 | 1 | |
| -1.290304416 | 1.306105006 | 1 | |
| -1.047113498 | 0.573244887 | 1 | |
| -1.829345211 | 1.526520998 | 1 | |
| 1.655654668 | -0.919084924 | 1 | |
| 0.354419891 | -1.514486921 | 1 | |
| 1.294325506 | -1.025373154 | 1 | |
| 1.059651738 | -0.329919638 | 1 | |
| 0.776688583 | -0.981650811 | 1 | |
| 0.852488155 | -1.050404224 | 1 | |
| 0.709031861 | -0.754367705 | 1 | |
| 0.286452122 | -0.675478837 | 1 | |
| 1.375405012 | -2.322473849 | 1 | |
| 0.802597412 | -1.04669331 | 1 | |
| 1.086726869 | -0.457528884 | 1 | |
| 1.176494782 | -1.678981604 | 1 | |
| 1.919258297 | -0.608970286 | 1 | |
| 0.416852043 | -1.265871275 | 1 | |
| 0.849027825 | -0.682913596 | 1 | |
| 0.838646307 | -0.948194085 | 1 | |
| 0.61533752 | -0.138612196 | 1 | |
| 1.938780666 | -1.05737308 | 1 | |
| 1.226448812 | -0.534711982 | 1 | |
| 1.059234006 | -0.374870802 | 1 |