Last active
November 13, 2017 07:41
-
-
Save zhenghaoz/36706eb8e6e9913c1edb32474f6ec17f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**6.3** 选择两个UCI数据集，分别用线性核和高斯核训练一个SVM，并与BP神经网络和C4.5决策树进行实验比较。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import graphviz as gv\n", | |
"from sklearn import datasets, svm\n", | |
"# sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in 0.20.\n", | |
"# model_selection replaces it and provides the same train_test_split, so import it\n", | |
"# under the old name to keep the `cross_validation.train_test_split(...)` calls below working.\n", | |
"try:\n", | |
"    from sklearn import model_selection as cross_validation\n", | |
"except ImportError:  # scikit-learn < 0.18 has no model_selection module\n", | |
"    from sklearn import cross_validation\n", | |
"from single_hidden_bp import SingleHiddenBP\n", | |
"from decision_tree import DecisionTree" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Iris Data Set**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iris = datasets.load_iris()\n", | |
"# Hold out a third of the samples for testing; the fixed seed makes the split reproducible.\n", | |
"(iris_data_train, iris_data_test,\n", | |
" iris_target_train, iris_target_test) = cross_validation.train_test_split(\n", | |
"    iris.data, iris.target, test_size=0.33, random_state=42)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"线性核SVM:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 61, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fit a linear-kernel SVM (fit returns the estimator) and show its test-set accuracy.\n", | |
"iris_linear_svm = svm.SVC(kernel='linear').fit(iris_data_train, iris_target_train)\n", | |
"np.mean(iris_linear_svm.predict(iris_data_test) == iris_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"高斯核SVM:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 60, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fit an RBF (Gaussian) kernel SVM and show its test-set accuracy.\n", | |
"iris_rbf_svm = svm.SVC(kernel='rbf').fit(iris_data_train, iris_target_train)\n", | |
"np.mean(iris_rbf_svm.predict(iris_data_test) == iris_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"BP神经网络:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One-hot encode the training labels: column j is 1 where the sample's label is iris_data_labels[j].\n", | |
"iris_data_labels = np.unique(iris_target_train)\n", | |
"iris_data_bp = (iris_target_train[:, np.newaxis] == iris_data_labels).astype(float)\n", | |
"# NOTE(review): assumes SingleHiddenBP initialises its weights from np.random; seeding makes reruns repeatable.\n", | |
"np.random.seed(42)\n", | |
"iris_bp = SingleHiddenBP(iris_data_train.shape[1], 10, len(iris_data_labels), learning_rate=1, learning_round=500)\n", | |
"iris_bp.fit(iris_data_train, iris_data_bp)\n", | |
"# The predicted class is the label whose output unit responds most strongly.\n", | |
"iris_bp_scores = np.array([iris_bp.predict(w) for w in iris_data_test])\n", | |
"np.mean(iris_data_labels[np.argmax(iris_bp_scores, axis=1)] == iris_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"C4.5决策树:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.98\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n", | |
" -->\n", | |
"<!-- Title: %3 Pages: 1 -->\n", | |
"<svg width=\"245pt\" height=\"218pt\"\n", | |
" viewBox=\"0.00 0.00 245.00 218.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 214)\">\n", | |
"<title>%3</title>\n", | |
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-214 241,-214 241,4 -4,4\"/>\n", | |
"<!-- petal length in cm -->\n", | |
"<g id=\"node1\" class=\"node\"><title>petal length in cm </title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"79\" cy=\"-192\" rx=\"77.1866\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"79\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">petal length in cm </text>\n", | |
"</g>\n", | |
"<!-- 0 -->\n", | |
"<g id=\"node2\" class=\"node\"><title>0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"27\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">0</text>\n", | |
"</g>\n", | |
"<!-- petal length in cm ->0 -->\n", | |
"<g id=\"edge3\" class=\"edge\"><title>petal length in cm ->0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M57.7704,-174.482C51.9878,-169.112 46.1752,-162.775 42,-156 37.7021,-149.026 34.5692,-140.743 32.3144,-132.946\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"35.676,-131.962 29.832,-123.125 28.8894,-133.678 35.676,-131.962\"/>\n", | |
"<text text-anchor=\"middle\" x=\"76\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=1.900000</text>\n", | |
"</g>\n", | |
"<!-- petal width in cm -->\n", | |
"<g id=\"node3\" class=\"node\"><title>petal width in cm </title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"147\" cy=\"-105\" rx=\"75.2868\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"147\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">petal width in cm </text>\n", | |
"</g>\n", | |
"<!-- petal length in cm ->petal width in cm -->\n", | |
"<g id=\"edge4\" class=\"edge\"><title>petal length in cm ->petal width in cm </title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M94.4425,-174.169C99.5085,-168.49 105.094,-162.054 110,-156 116.39,-148.113 123.077,-139.277 128.979,-131.252\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"131.971,-133.09 135.026,-122.945 126.311,-128.97 131.971,-133.09\"/>\n", | |
"<text text-anchor=\"middle\" x=\"151\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">>1.900000</text>\n", | |
"</g>\n", | |
"<!-- 1 -->\n", | |
"<g id=\"node4\" class=\"node\"><title>1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"101\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"101\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n", | |
"</g>\n", | |
"<!-- petal width in cm ->1 -->\n", | |
"<g id=\"edge1\" class=\"edge\"><title>petal width in cm ->1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M115.807,-88.4063C108.675,-83.2607 102.018,-76.8083 98,-69 94.4436,-62.0892 93.7361,-53.9096 94.3217,-46.1893\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"97.8121,-46.5061 95.7757,-36.1088 90.8839,-45.5067 97.8121,-46.5061\"/>\n", | |
"<text text-anchor=\"middle\" x=\"132\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=1.700000</text>\n", | |
"</g>\n", | |
"<!-- 2 -->\n", | |
"<g id=\"node5\" class=\"node\"><title>2</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"192\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"192\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n", | |
"</g>\n", | |
"<!-- petal width in cm ->2 -->\n", | |
"<g id=\"edge2\" class=\"edge\"><title>petal width in cm ->2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M159.053,-86.9052C162.767,-81.3187 166.736,-75.0047 170,-69 174.14,-61.384 178.129,-52.8322 181.546,-44.9839\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"184.867,-46.1159 185.537,-35.5422 178.42,-43.3902 184.867,-46.1159\"/>\n", | |
"<text text-anchor=\"middle\" x=\"207\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">>1.700000</text>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.files.Source at 0x7f05dc8a9c18>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"iris_decision_tree = DecisionTree(evaluator='c4.5')\n", | |
"iris_decision_tree.fit(iris_data_train, iris_target_train)\n", | |
"# DecisionTree.predict takes one sample at a time; gather the predictions into an array.\n", | |
"iris_tree_pred = np.array([iris_decision_tree.predict(w) for w in iris_data_test])\n", | |
"print(np.mean(iris_tree_pred == iris_target_test))\n", | |
"# Node labels, in the column order of iris.data.\n", | |
"gv.Source(iris_decision_tree.visualize(['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Wine Data Set**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"wine = datasets.load_wine()\n", | |
"# Same protocol as for iris: one third held out, fixed seed for a reproducible split.\n", | |
"(wine_data_train, wine_data_test,\n", | |
" wine_target_train, wine_target_test) = cross_validation.train_test_split(\n", | |
"    wine.data, wine.target, test_size=0.33, random_state=42)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"线性核SVM:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.98305084745762716" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Fit a linear-kernel SVM (fit returns the estimator) and show its test-set accuracy.\n", | |
"wine_linear_svm = svm.SVC(kernel='linear').fit(wine_data_train, wine_target_train)\n", | |
"np.mean(wine_linear_svm.predict(wine_data_test) == wine_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"高斯核SVM:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The wine features live on very different scales (e.g. Proline is in the hundreds, Hue is below 2),\n", | |
"# so with the default gamma the RBF distances are dominated by the largest-scale features; the\n", | |
"# unscaled model scored only ~0.42 on this split. Standardize using training-set statistics only\n", | |
"# (no information from the test set) before fitting and predicting.\n", | |
"wine_mean = wine_data_train.mean(axis=0)\n", | |
"wine_std = wine_data_train.std(axis=0)\n", | |
"wine_rbf_svm = svm.SVC(kernel='rbf')\n", | |
"wine_rbf_svm.fit((wine_data_train - wine_mean) / wine_std, wine_target_train)\n", | |
"np.mean(wine_rbf_svm.predict((wine_data_test - wine_mean) / wine_std) == wine_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"BP神经网络:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Raw wine features span very different scales (Proline is in the hundreds); the unscaled network\n", | |
"# scored only ~0.41 here, consistent with saturated hidden units (assuming SingleHiddenBP uses\n", | |
"# sigmoid-like activations -- not visible here). Standardize with training-set statistics only.\n", | |
"wine_mean = wine_data_train.mean(axis=0)\n", | |
"wine_std = wine_data_train.std(axis=0)\n", | |
"wine_data_train_scaled = (wine_data_train - wine_mean) / wine_std\n", | |
"wine_data_test_scaled = (wine_data_test - wine_mean) / wine_std\n", | |
"# One-hot encode the training labels: column j is 1 where the sample's label is wine_data_labels[j].\n", | |
"wine_data_labels = np.unique(wine_target_train)\n", | |
"wine_data_bp = (wine_target_train[:, np.newaxis] == wine_data_labels).astype(float)\n", | |
"# NOTE(review): assumes SingleHiddenBP initialises its weights from np.random; seeding makes reruns repeatable.\n", | |
"np.random.seed(42)\n", | |
"wine_bp = SingleHiddenBP(wine_data_train.shape[1], 13, len(wine_data_labels), learning_rate=1, learning_round=1000)\n", | |
"wine_bp.fit(wine_data_train_scaled, wine_data_bp)\n", | |
"# The predicted class is the label whose output unit responds most strongly.\n", | |
"wine_bp_scores = np.array([wine_bp.predict(w) for w in wine_data_test_scaled])\n", | |
"np.mean(wine_data_labels[np.argmax(wine_bp_scores, axis=1)] == wine_target_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"C4.5决策树:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.915254237288\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n", | |
" -->\n", | |
"<!-- Title: %3 Pages: 1 -->\n", | |
"<svg width=\"593pt\" height=\"305pt\"\n", | |
" viewBox=\"0.00 0.00 592.58 305.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\n", | |
"<title>%3</title>\n", | |
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-301 588.577,-301 588.577,4 -4,4\"/>\n", | |
"<!-- OD280/OD315 of diluted wines -->\n", | |
"<g id=\"node1\" class=\"node\"><title>OD280/OD315 of diluted wines</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"288.883\" cy=\"-279\" rx=\"124.278\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"288.883\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">OD280/OD315 of diluted wines</text>\n", | |
"</g>\n", | |
"<!-- Hue -->\n", | |
"<g id=\"node2\" class=\"node\"><title>Hue</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"230.883\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"230.883\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">Hue</text>\n", | |
"</g>\n", | |
"<!-- OD280/OD315 of diluted wines->Hue -->\n", | |
"<g id=\"edge11\" class=\"edge\"><title>OD280/OD315 of diluted wines->Hue</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M263.951,-261.152C257.68,-255.941 251.444,-249.77 246.883,-243 242.279,-236.168 238.932,-227.927 236.528,-220.123\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"239.856,-219.024 233.884,-210.273 233.096,-220.838 239.856,-219.024\"/>\n", | |
"<text text-anchor=\"middle\" x=\"280.883\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=2.190000</text>\n", | |
"</g>\n", | |
"<!-- Proline -->\n", | |
"<g id=\"node7\" class=\"node\"><title>Proline</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"356.883\" cy=\"-192\" rx=\"36.2938\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"356.883\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">Proline</text>\n", | |
"</g>\n", | |
"<!-- OD280/OD315 of diluted wines->Proline -->\n", | |
"<g id=\"edge12\" class=\"edge\"><title>OD280/OD315 of diluted wines->Proline</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M302.643,-260.799C312.67,-248.265 326.39,-231.116 337.522,-217.201\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"340.548,-219.02 344.062,-209.025 335.082,-214.648 340.548,-219.02\"/>\n", | |
"<text text-anchor=\"middle\" x=\"357.883\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">>2.190000</text>\n", | |
"</g>\n", | |
"<!-- Flavanoids -->\n", | |
"<g id=\"node3\" class=\"node\"><title>Flavanoids</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"93.8825\" cy=\"-105\" rx=\"49.2915\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"93.8825\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Flavanoids</text>\n", | |
"</g>\n", | |
"<!-- Hue->Flavanoids -->\n", | |
"<g id=\"edge5\" class=\"edge\"><title>Hue->Flavanoids</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M208.527,-181.551C193.988,-175.066 174.838,-165.873 158.883,-156 145.633,-147.802 131.833,-137.476 120.374,-128.339\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"122.458,-125.522 112.486,-121.944 118.05,-130.96 122.458,-125.522\"/>\n", | |
"<text text-anchor=\"middle\" x=\"192.883\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=0.890000</text>\n", | |
"</g>\n", | |
"<!-- Alcohol -->\n", | |
"<g id=\"node6\" class=\"node\"><title>Alcohol</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"230.883\" cy=\"-105\" rx=\"38.9931\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"230.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Alcohol</text>\n", | |
"</g>\n", | |
"<!-- Hue->Alcohol -->\n", | |
"<g id=\"edge6\" class=\"edge\"><title>Hue->Alcohol</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M230.883,-173.799C230.883,-162.163 230.883,-146.548 230.883,-133.237\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"234.383,-133.175 230.883,-123.175 227.383,-133.175 234.383,-133.175\"/>\n", | |
"<text text-anchor=\"middle\" x=\"260.883\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">>0.890000</text>\n", | |
"</g>\n", | |
"<!-- 2 -->\n", | |
"<g id=\"node4\" class=\"node\"><title>2</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.8825\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"89.8825\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n", | |
"</g>\n", | |
"<!-- Flavanoids->2 -->\n", | |
"<g id=\"edge1\" class=\"edge\"><title>Flavanoids->2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M48.6717,-97.4161C31.7207,-92.5084 14.1053,-83.9441 3.88252,-69 -11.4049,-46.6522 24.2862,-32.9241 54.2496,-25.6396\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"55.1262,-29.0299 64.1097,-23.4131 53.5843,-22.2018 55.1262,-29.0299\"/>\n", | |
"<text text-anchor=\"middle\" x=\"37.8825\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=1.490000</text>\n", | |
"</g>\n", | |
"<!-- 1 -->\n", | |
"<g id=\"node5\" class=\"node\"><title>1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"311.883\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"311.883\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n", | |
"</g>\n", | |
"<!-- Flavanoids->1 -->\n", | |
"<g id=\"edge2\" class=\"edge\"><title>Flavanoids->1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M87.3316,-86.7937C84.4912,-76.1448 83.3145,-62.8609 90.8825,-54 102.663,-40.2065 214.904,-27.9087 274.736,-22.2671\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"275.279,-25.7319 284.914,-21.324 274.633,-18.7617 275.279,-25.7319\"/>\n", | |
"<text text-anchor=\"middle\" x=\"120.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">>1.490000</text>\n", | |
"</g>\n", | |
"<!-- Alcohol->2 -->\n", | |
"<g id=\"edge4\" class=\"edge\"><title>Alcohol->2</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M210.229,-89.7029C195.12,-79.419 174.007,-65.4098 154.883,-54 143.532,-47.2281 130.778,-40.2325 119.594,-34.2923\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"121.113,-31.1368 110.633,-29.5789 117.854,-37.332 121.113,-31.1368\"/>\n", | |
"<text text-anchor=\"middle\" x=\"212.383\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">>13.190000</text>\n", | |
"</g>\n", | |
"<!-- Alcohol->1 -->\n", | |
"<g id=\"edge3\" class=\"edge\"><title>Alcohol->1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M241.546,-87.3887C248.453,-77.2268 257.98,-64.2537 267.883,-54 273.29,-48.4005 279.598,-42.8931 285.722,-37.9638\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"288.105,-40.5447 293.855,-31.6461 283.81,-35.0166 288.105,-40.5447\"/>\n", | |
"<text text-anchor=\"middle\" x=\"305.383\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=13.190000</text>\n", | |
"</g>\n", | |
"<!-- Proline->1 -->\n", | |
"<g id=\"edge9\" class=\"edge\"><title>Proline->1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M358.404,-173.707C360.06,-146.858 360.457,-94.0679 342.883,-54 340.776,-49.1984 337.76,-44.5916 334.426,-40.387\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"337.019,-38.0363 327.78,-32.8516 331.769,-42.6666 337.019,-38.0363\"/>\n", | |
"<text text-anchor=\"middle\" x=\"399.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\"><=726.500000</text>\n", | |
"</g>\n", | |
"<!-- Color intensity -->\n", | |
"<g id=\"node8\" class=\"node\"><title>Color intensity</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"520.883\" cy=\"-105\" rx=\"63.8893\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"520.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Color intensity</text>\n", | |
"</g>\n", | |
"<!-- Proline->Color intensity -->\n", | |
"<g id=\"edge10\" class=\"edge\"><title>Proline->Color intensity</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M381.081,-178.458C407.924,-164.546 451.579,-141.919 482.969,-125.65\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"484.611,-128.742 491.878,-121.033 481.389,-122.527 484.611,-128.742\"/>\n", | |
"<text text-anchor=\"middle\" x=\"485.383\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">>726.500000</text>\n", | |
"</g>\n", | |
"<!-- Color intensity->1 -->\n", | |
"<g id=\"edge7\" class=\"edge\"><title>Color intensity->1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M485.901,-89.7729C446.325,-73.6775 382.404,-47.681 344.036,-32.0766\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"344.979,-28.682 334.397,-28.1567 342.342,-35.1662 344.979,-28.682\"/>\n", | |
"<text text-anchor=\"middle\" x=\"463.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\"><=3.465000</text>\n", | |
"</g>\n", | |
"<!-- 0 -->\n", | |
"<g id=\"node9\" class=\"node\"><title>0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"520.883\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"520.883\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">0</text>\n", | |
"</g>\n", | |
"<!-- Color intensity->0 -->\n", | |
"<g id=\"edge8\" class=\"edge\"><title>Color intensity->0</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M520.883,-86.799C520.883,-75.1626 520.883,-59.5479 520.883,-46.2368\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"524.383,-46.1754 520.883,-36.1754 517.383,-46.1755 524.383,-46.1754\"/>\n", | |
"<text text-anchor=\"middle\" x=\"550.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">>3.465000</text>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.files.Source at 0x7f05dc865898>" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"wine_decision_tree = DecisionTree(evaluator='c4.5')\n", | |
"wine_decision_tree.fit(wine_data_train, wine_target_train)\n", | |
"# Predict sample by sample, then report the fraction classified correctly.\n", | |
"wine_tree_pred = [wine_decision_tree.predict(w) for w in wine_data_test]\n", | |
"print(np.mean(wine_tree_pred == wine_target_test))\n", | |
"# Node labels, in the column order of wine.data.\n", | |
"wine_feature_names = [\n", | |
"    'Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium',\n", | |
"    'Total phenols', 'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins',\n", | |
"    'Color intensity', 'Hue', 'OD280/OD315 of diluted wines', 'Proline',\n", | |
"]\n", | |
"gv.Source(wine_decision_tree.visualize(wine_feature_names))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment