Skip to content

Instantly share code, notes, and snippets.

@zhenghaoz
Last active November 13, 2017 07:41
Show Gist options
  • Save zhenghaoz/36706eb8e6e9913c1edb32474f6ec17f to your computer and use it in GitHub Desktop.
Save zhenghaoz/36706eb8e6e9913c1edb32474f6ec17f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**6.3** 选择两个UCI数据集，分别用线性核和高斯核训练一个SVM，并与BP神经网络和C4.5决策树进行实验比较。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import graphviz as gv\n",
"from sklearn import datasets, svm\n",
"# `sklearn.cross_validation` was deprecated in scikit-learn 0.18 and removed in 0.20;\n",
"# `model_selection` provides the same `train_test_split`. Bind it to the old name so the\n",
"# data-splitting cells below run on both old and current releases.\n",
"try:\n",
"    from sklearn import model_selection as cross_validation\n",
"except ImportError:  # scikit-learn < 0.18\n",
"    from sklearn import cross_validation\n",
"\n",
"# Project-local implementations (presumably shipped next to this notebook).\n",
"from single_hidden_bp import SingleHiddenBP\n",
"from decision_tree import DecisionTree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Iris Data Set**"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Hold out a third of the samples for testing; the fixed seed makes the split reproducible.\n",
"iris = datasets.load_iris()\n",
"(iris_data_train, iris_data_test,\n",
" iris_target_train, iris_target_test) = cross_validation.train_test_split(\n",
"    iris.data, iris.target, random_state=42, test_size=0.33)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"线性核SVM:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Linear-kernel SVM: fit on the training split, then report test-set accuracy.\n",
"iris_linear_svm = svm.SVC(kernel='linear').fit(iris_data_train, iris_target_train)\n",
"np.mean(iris_linear_svm.predict(iris_data_test) == iris_target_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"高斯核SVM:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gaussian (RBF) kernel SVM with default hyper-parameters; report test-set accuracy.\n",
"iris_rbf_svm = svm.SVC(kernel='rbf').fit(iris_data_train, iris_target_train)\n",
"np.mean(iris_rbf_svm.predict(iris_data_test) == iris_target_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BP神经网络:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.98\n"
]
}
],
"source": [
"# One-hot targets: column j is 1.0 where the sample's class equals iris_data_labels[j].\n",
"iris_data_labels = np.unique(iris_target_train)\n",
"iris_data_bp = np.array([iris_target_train == label for label in iris_data_labels], dtype=float).T\n",
"\n",
"# Arguments presumably (inputs, hidden units, outputs): one input per feature, one output per class.\n",
"iris_bp = SingleHiddenBP(iris_data_train.shape[1], 10, len(iris_data_labels),\n",
"                         learning_rate=1, learning_round=500)\n",
"iris_bp.fit(iris_data_train, iris_data_bp)\n",
"\n",
"# Predicted class = label whose output unit scores highest; print the test-set accuracy.\n",
"print(np.sum(\n",
"    iris_data_labels[np.argmax([iris_bp.predict(w) for w in iris_data_test], axis=1)]\n",
"    == iris_target_test) / len(iris_target_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"C4.5决策树:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.98\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"245pt\" height=\"218pt\"\n",
" viewBox=\"0.00 0.00 245.00 218.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 214)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-214 241,-214 241,4 -4,4\"/>\n",
"<!-- petal length in cm -->\n",
"<g id=\"node1\" class=\"node\"><title>petal length in cm </title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"79\" cy=\"-192\" rx=\"77.1866\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"79\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">petal length in cm </text>\n",
"</g>\n",
"<!-- 0 -->\n",
"<g id=\"node2\" class=\"node\"><title>0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">0</text>\n",
"</g>\n",
"<!-- petal length in cm &#45;&gt;0 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>petal length in cm &#45;&gt;0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M57.7704,-174.482C51.9878,-169.112 46.1752,-162.775 42,-156 37.7021,-149.026 34.5692,-140.743 32.3144,-132.946\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"35.676,-131.962 29.832,-123.125 28.8894,-133.678 35.676,-131.962\"/>\n",
"<text text-anchor=\"middle\" x=\"76\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=1.900000</text>\n",
"</g>\n",
"<!-- petal width in cm -->\n",
"<g id=\"node3\" class=\"node\"><title>petal width in cm </title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"147\" cy=\"-105\" rx=\"75.2868\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"147\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">petal width in cm </text>\n",
"</g>\n",
"<!-- petal length in cm &#45;&gt;petal width in cm -->\n",
"<g id=\"edge4\" class=\"edge\"><title>petal length in cm &#45;&gt;petal width in cm </title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M94.4425,-174.169C99.5085,-168.49 105.094,-162.054 110,-156 116.39,-148.113 123.077,-139.277 128.979,-131.252\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"131.971,-133.09 135.026,-122.945 126.311,-128.97 131.971,-133.09\"/>\n",
"<text text-anchor=\"middle\" x=\"151\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;1.900000</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node4\" class=\"node\"><title>1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"101\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"101\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n",
"</g>\n",
"<!-- petal width in cm &#45;&gt;1 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>petal width in cm &#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M115.807,-88.4063C108.675,-83.2607 102.018,-76.8083 98,-69 94.4436,-62.0892 93.7361,-53.9096 94.3217,-46.1893\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"97.8121,-46.5061 95.7757,-36.1088 90.8839,-45.5067 97.8121,-46.5061\"/>\n",
"<text text-anchor=\"middle\" x=\"132\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=1.700000</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node5\" class=\"node\"><title>2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"192\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"192\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n",
"</g>\n",
"<!-- petal width in cm &#45;&gt;2 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>petal width in cm &#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M159.053,-86.9052C162.767,-81.3187 166.736,-75.0047 170,-69 174.14,-61.384 178.129,-52.8322 181.546,-44.9839\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"184.867,-46.1159 185.537,-35.5422 178.42,-43.3902 184.867,-46.1159\"/>\n",
"<text text-anchor=\"middle\" x=\"207\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;1.700000</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x7f05dc8a9c18>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# C4.5 decision tree; DecisionTree.predict is called once per test sample.\n",
"iris_decision_tree = DecisionTree(evaluator='c4.5')\n",
"iris_decision_tree.fit(iris_data_train, iris_target_train)\n",
"# Build an ndarray explicitly instead of relying on a reflected list == ndarray comparison.\n",
"print(np.mean(np.array([iris_decision_tree.predict(w) for w in iris_data_test]) == iris_target_test))\n",
"# Feature names in load_iris() column order. No trailing spaces: these strings become the\n",
"# node labels of the rendered graph.\n",
"gv.Source(iris_decision_tree.visualize(['sepal length in cm', 'sepal width in cm',\n",
"                                        'petal length in cm', 'petal width in cm']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Wine Data Set**"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Same protocol as for Iris: one third held out, fixed seed for a reproducible split.\n",
"wine = datasets.load_wine()\n",
"(wine_data_train, wine_data_test,\n",
" wine_target_train, wine_target_test) = cross_validation.train_test_split(\n",
"    wine.data, wine.target, random_state=42, test_size=0.33)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"线性核SVM:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.98305084745762716"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Linear-kernel SVM on the raw (unscaled) features; report test-set accuracy.\n",
"wine_linear_svm = svm.SVC(kernel='linear').fit(wine_data_train, wine_target_train)\n",
"np.mean(wine_linear_svm.predict(wine_data_test) == wine_target_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"高斯核SVM:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The RBF kernel is built on Euclidean distances between samples, and the raw Wine features\n",
"# sit on very different scales (Proline is in the hundreds, Hue is around 1), so the largest\n",
"# features dominate the kernel; unscaled, this model scored only about 0.42. Standardize with\n",
"# statistics from the training split only, so no test-set information leaks into the model.\n",
"wine_mean = wine_data_train.mean(axis=0)\n",
"wine_std = wine_data_train.std(axis=0)\n",
"wine_data_train_std = (wine_data_train - wine_mean) / wine_std\n",
"wine_data_test_std = (wine_data_test - wine_mean) / wine_std\n",
"wine_rbf_svm = svm.SVC(kernel='rbf')\n",
"wine_rbf_svm.fit(wine_data_train_std, wine_target_train)\n",
"np.sum(wine_rbf_svm.predict(wine_data_test_std) == wine_target_test) / len(wine_data_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# BP neural network.\n",
"# On raw inputs this model scored about 0.41, barely above chance for 3 classes. Features such\n",
"# as Proline are orders of magnitude larger than the rest, which presumably saturates the\n",
"# hidden units. Standardize with statistics from the training split only.\n",
"wine_mean = wine_data_train.mean(axis=0)\n",
"wine_std = wine_data_train.std(axis=0)\n",
"wine_data_train_std = (wine_data_train - wine_mean) / wine_std\n",
"wine_data_test_std = (wine_data_test - wine_mean) / wine_std\n",
"\n",
"# One-hot targets: column j is 1.0 where the sample's class equals wine_data_labels[j].\n",
"wine_data_labels = np.unique(wine_target_train)\n",
"wine_data_bp = (wine_target_train[:, np.newaxis] == wine_data_labels).astype(float)\n",
"\n",
"# Arguments presumably (inputs, hidden units, outputs): one input per feature, one output per class.\n",
"wine_bp = SingleHiddenBP(wine_data_train.shape[1], 13, len(wine_data_labels),\n",
"                         learning_rate=1, learning_round=1000)\n",
"wine_bp.fit(wine_data_train_std, wine_data_bp)\n",
"# Predicted class = label whose output unit scores highest.\n",
"np.sum(wine_data_labels[np.argmax([wine_bp.predict(w) for w in wine_data_test_std], axis=1)]\n",
"       == wine_target_test) / len(wine_target_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"C4.5决策树:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.915254237288\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"593pt\" height=\"305pt\"\n",
" viewBox=\"0.00 0.00 592.58 305.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-301 588.577,-301 588.577,4 -4,4\"/>\n",
"<!-- OD280/OD315 of diluted wines -->\n",
"<g id=\"node1\" class=\"node\"><title>OD280/OD315 of diluted wines</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"288.883\" cy=\"-279\" rx=\"124.278\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"288.883\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">OD280/OD315 of diluted wines</text>\n",
"</g>\n",
"<!-- Hue -->\n",
"<g id=\"node2\" class=\"node\"><title>Hue</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"230.883\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"230.883\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">Hue</text>\n",
"</g>\n",
"<!-- OD280/OD315 of diluted wines&#45;&gt;Hue -->\n",
"<g id=\"edge11\" class=\"edge\"><title>OD280/OD315 of diluted wines&#45;&gt;Hue</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M263.951,-261.152C257.68,-255.941 251.444,-249.77 246.883,-243 242.279,-236.168 238.932,-227.927 236.528,-220.123\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"239.856,-219.024 233.884,-210.273 233.096,-220.838 239.856,-219.024\"/>\n",
"<text text-anchor=\"middle\" x=\"280.883\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=2.190000</text>\n",
"</g>\n",
"<!-- Proline -->\n",
"<g id=\"node7\" class=\"node\"><title>Proline</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"356.883\" cy=\"-192\" rx=\"36.2938\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"356.883\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">Proline</text>\n",
"</g>\n",
"<!-- OD280/OD315 of diluted wines&#45;&gt;Proline -->\n",
"<g id=\"edge12\" class=\"edge\"><title>OD280/OD315 of diluted wines&#45;&gt;Proline</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M302.643,-260.799C312.67,-248.265 326.39,-231.116 337.522,-217.201\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"340.548,-219.02 344.062,-209.025 335.082,-214.648 340.548,-219.02\"/>\n",
"<text text-anchor=\"middle\" x=\"357.883\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;2.190000</text>\n",
"</g>\n",
"<!-- Flavanoids -->\n",
"<g id=\"node3\" class=\"node\"><title>Flavanoids</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"93.8825\" cy=\"-105\" rx=\"49.2915\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"93.8825\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Flavanoids</text>\n",
"</g>\n",
"<!-- Hue&#45;&gt;Flavanoids -->\n",
"<g id=\"edge5\" class=\"edge\"><title>Hue&#45;&gt;Flavanoids</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M208.527,-181.551C193.988,-175.066 174.838,-165.873 158.883,-156 145.633,-147.802 131.833,-137.476 120.374,-128.339\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"122.458,-125.522 112.486,-121.944 118.05,-130.96 122.458,-125.522\"/>\n",
"<text text-anchor=\"middle\" x=\"192.883\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=0.890000</text>\n",
"</g>\n",
"<!-- Alcohol -->\n",
"<g id=\"node6\" class=\"node\"><title>Alcohol</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"230.883\" cy=\"-105\" rx=\"38.9931\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"230.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Alcohol</text>\n",
"</g>\n",
"<!-- Hue&#45;&gt;Alcohol -->\n",
"<g id=\"edge6\" class=\"edge\"><title>Hue&#45;&gt;Alcohol</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M230.883,-173.799C230.883,-162.163 230.883,-146.548 230.883,-133.237\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"234.383,-133.175 230.883,-123.175 227.383,-133.175 234.383,-133.175\"/>\n",
"<text text-anchor=\"middle\" x=\"260.883\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;0.890000</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node4\" class=\"node\"><title>2</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"89.8825\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"89.8825\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n",
"</g>\n",
"<!-- Flavanoids&#45;&gt;2 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>Flavanoids&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M48.6717,-97.4161C31.7207,-92.5084 14.1053,-83.9441 3.88252,-69 -11.4049,-46.6522 24.2862,-32.9241 54.2496,-25.6396\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"55.1262,-29.0299 64.1097,-23.4131 53.5843,-22.2018 55.1262,-29.0299\"/>\n",
"<text text-anchor=\"middle\" x=\"37.8825\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=1.490000</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node5\" class=\"node\"><title>1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"311.883\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"311.883\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n",
"</g>\n",
"<!-- Flavanoids&#45;&gt;1 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>Flavanoids&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M87.3316,-86.7937C84.4912,-76.1448 83.3145,-62.8609 90.8825,-54 102.663,-40.2065 214.904,-27.9087 274.736,-22.2671\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"275.279,-25.7319 284.914,-21.324 274.633,-18.7617 275.279,-25.7319\"/>\n",
"<text text-anchor=\"middle\" x=\"120.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;1.490000</text>\n",
"</g>\n",
"<!-- Alcohol&#45;&gt;2 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>Alcohol&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M210.229,-89.7029C195.12,-79.419 174.007,-65.4098 154.883,-54 143.532,-47.2281 130.778,-40.2325 119.594,-34.2923\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"121.113,-31.1368 110.633,-29.5789 117.854,-37.332 121.113,-31.1368\"/>\n",
"<text text-anchor=\"middle\" x=\"212.383\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;13.190000</text>\n",
"</g>\n",
"<!-- Alcohol&#45;&gt;1 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>Alcohol&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M241.546,-87.3887C248.453,-77.2268 257.98,-64.2537 267.883,-54 273.29,-48.4005 279.598,-42.8931 285.722,-37.9638\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"288.105,-40.5447 293.855,-31.6461 283.81,-35.0166 288.105,-40.5447\"/>\n",
"<text text-anchor=\"middle\" x=\"305.383\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=13.190000</text>\n",
"</g>\n",
"<!-- Proline&#45;&gt;1 -->\n",
"<g id=\"edge9\" class=\"edge\"><title>Proline&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M358.404,-173.707C360.06,-146.858 360.457,-94.0679 342.883,-54 340.776,-49.1984 337.76,-44.5916 334.426,-40.387\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"337.019,-38.0363 327.78,-32.8516 331.769,-42.6666 337.019,-38.0363\"/>\n",
"<text text-anchor=\"middle\" x=\"399.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=726.500000</text>\n",
"</g>\n",
"<!-- Color intensity -->\n",
"<g id=\"node8\" class=\"node\"><title>Color intensity</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"520.883\" cy=\"-105\" rx=\"63.8893\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"520.883\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">Color intensity</text>\n",
"</g>\n",
"<!-- Proline&#45;&gt;Color intensity -->\n",
"<g id=\"edge10\" class=\"edge\"><title>Proline&#45;&gt;Color intensity</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M381.081,-178.458C407.924,-164.546 451.579,-141.919 482.969,-125.65\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"484.611,-128.742 491.878,-121.033 481.389,-122.527 484.611,-128.742\"/>\n",
"<text text-anchor=\"middle\" x=\"485.383\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;726.500000</text>\n",
"</g>\n",
"<!-- Color intensity&#45;&gt;1 -->\n",
"<g id=\"edge7\" class=\"edge\"><title>Color intensity&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M485.901,-89.7729C446.325,-73.6775 382.404,-47.681 344.036,-32.0766\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"344.979,-28.682 334.397,-28.1567 342.342,-35.1662 344.979,-28.682\"/>\n",
"<text text-anchor=\"middle\" x=\"463.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&lt;=3.465000</text>\n",
"</g>\n",
"<!-- 0 -->\n",
"<g id=\"node9\" class=\"node\"><title>0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"520.883\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"520.883\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">0</text>\n",
"</g>\n",
"<!-- Color intensity&#45;&gt;0 -->\n",
"<g id=\"edge8\" class=\"edge\"><title>Color intensity&#45;&gt;0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M520.883,-86.799C520.883,-75.1626 520.883,-59.5479 520.883,-46.2368\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"524.383,-46.1754 520.883,-36.1754 517.383,-46.1755 524.383,-46.1754\"/>\n",
"<text text-anchor=\"middle\" x=\"550.883\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">&gt;3.465000</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x7f05dc865898>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# C4.5 tree: classify the test samples one at a time and print the accuracy.\n",
"wine_decision_tree = DecisionTree(evaluator='c4.5')\n",
"wine_decision_tree.fit(wine_data_train, wine_target_train)\n",
"print(np.sum(np.array([wine_decision_tree.predict(w) for w in wine_data_test]) == wine_target_test)\n",
"      / len(wine_data_test))\n",
"# Human-readable feature names, in load_wine() column order, for the rendered tree.\n",
"gv.Source(wine_decision_tree.visualize([\n",
"    'Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium',\n",
"    'Total phenols', 'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins',\n",
"    'Color intensity', 'Hue', 'OD280/OD315 of diluted wines', 'Proline',\n",
"]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment