Created
January 8, 2016 19:25
-
-
Save ikegami-yukino/02d171304e01d8039470 to your computer and use it in GitHub Desktop.
scikit-learnの決定木をルールベースのコードに変換
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# scikit-learnの決定木をルールベースのコードに変換" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 変換コードのダウンロード" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['--2016-01-09 04:18:48-- https://raw.githubusercontent.com/ikegami-yukino/misc/master/machinelearning/dt2code.py',\n", | |
" 'Resolving raw.githubusercontent.com... 103.245.222.133',\n", | |
" 'Connecting to raw.githubusercontent.com|103.245.222.133|:443... connected.',\n", | |
" 'HTTP request sent, awaiting response... 200 OK',\n", | |
" 'Length: 2001 (2.0K) [text/plain]',\n", | |
" \"Saving to: 'dt2code.py.1'\",\n", | |
" '',\n", | |
" ' 0K . 100% 90.9M=0s',\n", | |
" '',\n", | |
" \"2016-01-09 04:18:48 (90.9 MB/s) - 'dt2code.py.1' saved [2001/2001]\",\n", | |
" '']" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%system wget https://raw.githubusercontent.com/ikegami-yukino/misc/master/machinelearning/dt2code.py" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## irisのデータで試してみる" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"def f(sepal_length=0, sepal_width=0, petal_length=0, petal_width=0):\n", | |
" \"\"\"\n", | |
" 0 -> setosa\n", | |
" 1 -> versicolor\n", | |
" 2 -> virginica\n", | |
" \"\"\"\n", | |
" if petal_length <= 2.45000004768: # samples=150\n", | |
" return 0 # samples=50\n", | |
" else:\n", | |
" if petal_width <= 1.75: # samples=100\n", | |
" if petal_length <= 4.94999980927: # samples=54\n", | |
" if petal_width <= 1.65000009537: # samples=48\n", | |
" return 1 # samples=47\n", | |
" else:\n", | |
" return 2 # samples=1\n", | |
" else:\n", | |
" if petal_width <= 1.54999995232: # samples=6\n", | |
" return 2 # samples=3\n", | |
" else:\n", | |
" if petal_length <= 5.44999980927: # samples=3\n", | |
" return 1 # samples=2\n", | |
" else:\n", | |
" return 2 # samples=1\n", | |
" else:\n", | |
" if petal_length <= 4.85000038147: # samples=46\n", | |
" if sepal_width <= 3.09999990463: # samples=3\n", | |
" return 2 # samples=2\n", | |
" else:\n", | |
" return 1 # samples=1\n", | |
" else:\n", | |
" return 2 # samples=43\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn.datasets import load_iris\n", | |
"from sklearn import tree\n", | |
"import dt2code\n", | |
"iris = load_iris()\n", | |
"clf = tree.DecisionTreeClassifier()\n", | |
"clf = clf.fit(iris.data, iris.target)\n", | |
"print(dt2code.dt2code(clf, iris.feature_names, iris.target_names))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 得られたコードはこんな風にして使える" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def predict(X):\n", | |
" def f(sepal_length=0, sepal_width=0, petal_length=0, petal_width=0):\n", | |
" \"\"\"\n", | |
" 0 -> setosa\n", | |
" 1 -> versicolor\n", | |
" 2 -> virginica\n", | |
" \"\"\"\n", | |
" if petal_length <= 2.45000004768: # samples=150\n", | |
" return 0 # samples=50\n", | |
" else:\n", | |
" if petal_width <= 1.75: # samples=100\n", | |
" if petal_length <= 4.94999980927: # samples=54\n", | |
" if petal_width <= 1.65000009537: # samples=48\n", | |
" return 1 # samples=47\n", | |
" else:\n", | |
" return 2 # samples=1\n", | |
" else:\n", | |
" if petal_width <= 1.54999995232: # samples=6\n", | |
" return 2 # samples=3\n", | |
" else:\n", | |
" if petal_length <= 5.44999980927: # samples=3\n", | |
" return 1 # samples=2\n", | |
" else:\n", | |
" return 2 # samples=1\n", | |
" else:\n", | |
" if petal_length <= 4.85000038147: # samples=46\n", | |
" if sepal_width <= 3.09999990463: # samples=3\n", | |
" return 2 # samples=2\n", | |
" else:\n", | |
" return 1 # samples=1\n", | |
" else:\n", | |
" return 2 # samples=43\n", | |
"\n", | |
" return list(map(lambda x: f(*x), X))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0, 1, 2]" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"predict([[5.1, 3.5, 1.4, 0.2], [5, 2, 3.5, 1], [6.3, 3.3, 6, 2.5]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0, 1, 2])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf.predict([[5.1, 3.5, 1.4, 0.2], [5, 2, 3.5, 1], [6.3, 3.3, 6, 2.5]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment