Skip to content

Instantly share code, notes, and snippets.

@ikegami-yukino
Created January 8, 2016 19:25
Show Gist options
  • Save ikegami-yukino/02d171304e01d8039470 to your computer and use it in GitHub Desktop.
Save ikegami-yukino/02d171304e01d8039470 to your computer and use it in GitHub Desktop.
scikit-learnの決定木をルールベースのコードに変換
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# scikit-learnの決定木をルールベースのコードに変換"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 変換コードのダウンロード"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"['--2016-01-09 04:18:48-- https://raw.githubusercontent.com/ikegami-yukino/misc/master/machinelearning/dt2code.py',\n",
" 'Resolving raw.githubusercontent.com... 103.245.222.133',\n",
" 'Connecting to raw.githubusercontent.com|103.245.222.133|:443... connected.',\n",
" 'HTTP request sent, awaiting response... 200 OK',\n",
" 'Length: 2001 (2.0K) [text/plain]',\n",
" \"Saving to: 'dt2code.py.1'\",\n",
" '',\n",
" ' 0K . 100% 90.9M=0s',\n",
" '',\n",
" \"2016-01-09 04:18:48 (90.9 MB/s) - 'dt2code.py.1' saved [2001/2001]\",\n",
" '']"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# -O overwrites any existing dt2code.py. Without it, re-running this cell makes wget\n",
"# save the download as dt2code.py.1 and `import dt2code` keeps using the old copy.\n",
"%system wget -O dt2code.py https://raw.githubusercontent.com/ikegami-yukino/misc/master/machinelearning/dt2code.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## irisのデータで試してみる"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def f(sepal_length=0, sepal_width=0, petal_length=0, petal_width=0):\n",
" \"\"\"\n",
" 0 -> setosa\n",
" 1 -> versicolor\n",
" 2 -> virginica\n",
" \"\"\"\n",
" if petal_length <= 2.45000004768: # samples=150\n",
" return 0 # samples=50\n",
" else:\n",
" if petal_width <= 1.75: # samples=100\n",
" if petal_length <= 4.94999980927: # samples=54\n",
" if petal_width <= 1.65000009537: # samples=48\n",
" return 1 # samples=47\n",
" else:\n",
" return 2 # samples=1\n",
" else:\n",
" if petal_width <= 1.54999995232: # samples=6\n",
" return 2 # samples=3\n",
" else:\n",
" if petal_length <= 5.44999980927: # samples=3\n",
" return 1 # samples=2\n",
" else:\n",
" return 2 # samples=1\n",
" else:\n",
" if petal_length <= 4.85000038147: # samples=46\n",
" if sepal_width <= 3.09999990463: # samples=3\n",
" return 2 # samples=2\n",
" else:\n",
" return 1 # samples=1\n",
" else:\n",
" return 2 # samples=43\n",
"\n"
]
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn import tree\n",
"import dt2code\n",
"\n",
"iris = load_iris()\n",
"# Fix the seed: scikit-learn permutes the features at every split, so without a\n",
"# fixed random_state equally good splits can be chosen differently between runs\n",
"# and the generated rules would change.\n",
"clf = tree.DecisionTreeClassifier(random_state=0)\n",
"clf = clf.fit(iris.data, iris.target)\n",
"# Emit the fitted tree as Python source (nested if/else returning class indices).\n",
"print(dt2code.dt2code(clf, iris.feature_names, iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 得られたコードはこんな風にして使える"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def predict(X):\n",
"    \"\"\"Classify iris samples with the rules generated from the decision tree.\n",
"\n",
"    X is an iterable of rows ordered as (sepal_length, sepal_width,\n",
"    petal_length, petal_width). Returns a list holding one class index per\n",
"    row (0 = setosa, 1 = versicolor, 2 = virginica).\n",
"    \"\"\"\n",
"    def f(sepal_length=0, sepal_width=0, petal_length=0, petal_width=0):\n",
"        \"\"\"\n",
"        0 -> setosa\n",
"        1 -> versicolor\n",
"        2 -> virginica\n",
"        \"\"\"\n",
"        # Generated rules, kept verbatim; the samples=N comments come from the generator.\n",
"        if petal_length <= 2.45000004768:  # samples=150\n",
"            return 0  # samples=50\n",
"        else:\n",
"            if petal_width <= 1.75:  # samples=100\n",
"                if petal_length <= 4.94999980927:  # samples=54\n",
"                    if petal_width <= 1.65000009537:  # samples=48\n",
"                        return 1  # samples=47\n",
"                    else:\n",
"                        return 2  # samples=1\n",
"                else:\n",
"                    if petal_width <= 1.54999995232:  # samples=6\n",
"                        return 2  # samples=3\n",
"                    else:\n",
"                        if petal_length <= 5.44999980927:  # samples=3\n",
"                            return 1  # samples=2\n",
"                        else:\n",
"                            return 2  # samples=1\n",
"            else:\n",
"                if petal_length <= 4.85000038147:  # samples=46\n",
"                    if sepal_width <= 3.09999990463:  # samples=3\n",
"                        return 2  # samples=2\n",
"                    else:\n",
"                        return 1  # samples=1\n",
"                else:\n",
"                    return 2  # samples=43\n",
"\n",
"    # Each row is unpacked positionally, so column order must match f's signature.\n",
"    return [f(*row) for row in X]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the rule-based function on three hand-picked samples.\n",
"predict([\n",
"    [5.1, 3.5, 1.4, 0.2],\n",
"    [5, 2, 3.5, 1],\n",
"    [6.3, 3.3, 6, 2.5],\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1, 2])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reference predictions from the fitted scikit-learn tree for the same three samples.\n",
"clf.predict([\n",
"    [5.1, 3.5, 1.4, 0.2],\n",
"    [5, 2, 3.5, 1],\n",
"    [6.3, 3.3, 6, 2.5],\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment