Last active
June 8, 2019 13:01
-
-
Save sdwfrost/450cd038405a20970bce307cc44403b8 to your computer and use it in GitHub Desktop.
Jupyter notebook for MLJ using Iris data and DecisionTree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table class=\"data-frame\"><thead><tr><th></th><th>SepalLength</th><th>SepalWidth</th><th>PetalLength</th><th>PetalWidth</th><th>Species</th></tr><tr><th></th><th>Float64</th><th>Float64</th><th>Float64</th><th>Float64</th><th>Categorical…</th></tr></thead><tbody><p>150 rows × 5 columns</p><tr><th>1</th><td>5.1</td><td>3.5</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>2</th><td>4.9</td><td>3.0</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>3</th><td>4.7</td><td>3.2</td><td>1.3</td><td>0.2</td><td>setosa</td></tr><tr><th>4</th><td>4.6</td><td>3.1</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>5</th><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>6</th><td>5.4</td><td>3.9</td><td>1.7</td><td>0.4</td><td>setosa</td></tr><tr><th>7</th><td>4.6</td><td>3.4</td><td>1.4</td><td>0.3</td><td>setosa</td></tr><tr><th>8</th><td>5.0</td><td>3.4</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>9</th><td>4.4</td><td>2.9</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>10</th><td>4.9</td><td>3.1</td><td>1.5</td><td>0.1</td><td>setosa</td></tr><tr><th>11</th><td>5.4</td><td>3.7</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>12</th><td>4.8</td><td>3.4</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>13</th><td>4.8</td><td>3.0</td><td>1.4</td><td>0.1</td><td>setosa</td></tr><tr><th>14</th><td>4.3</td><td>3.0</td><td>1.1</td><td>0.1</td><td>setosa</td></tr><tr><th>15</th><td>5.8</td><td>4.0</td><td>1.2</td><td>0.2</td><td>setosa</td></tr><tr><th>16</th><td>5.7</td><td>4.4</td><td>1.5</td><td>0.4</td><td>setosa</td></tr><tr><th>17</th><td>5.4</td><td>3.9</td><td>1.3</td><td>0.4</td><td>setosa</td></tr><tr><th>18</th><td>5.1</td><td>3.5</td><td>1.4</td><td>0.3</td><td>setosa</td></tr><tr><th>19</th><td>5.7</td><td>3.8</td><td>1.7</td><td>0.3</td><td>setosa</td></tr><tr><th>20</th><td>5.1</td><td>3.8</td><td>1.5</td><td>0.3</td><td>setosa</td></tr><tr><th>21</th><td>5.4</td><td>3.4</td><td>1.7</td><td>0
.2</td><td>setosa</td></tr><tr><th>22</th><td>5.1</td><td>3.7</td><td>1.5</td><td>0.4</td><td>setosa</td></tr><tr><th>23</th><td>4.6</td><td>3.6</td><td>1.0</td><td>0.2</td><td>setosa</td></tr><tr><th>24</th><td>5.1</td><td>3.3</td><td>1.7</td><td>0.5</td><td>setosa</td></tr><tr><th>25</th><td>4.8</td><td>3.4</td><td>1.9</td><td>0.2</td><td>setosa</td></tr><tr><th>26</th><td>5.0</td><td>3.0</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>27</th><td>5.0</td><td>3.4</td><td>1.6</td><td>0.4</td><td>setosa</td></tr><tr><th>28</th><td>5.2</td><td>3.5</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>29</th><td>5.2</td><td>3.4</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>30</th><td>4.7</td><td>3.2</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>⋮</th><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td></tr></tbody></table>" | |
], | |
"text/latex": [ | |
"\\begin{tabular}{r|ccccc}\n", | |
"\t& SepalLength & SepalWidth & PetalLength & PetalWidth & Species\\\\\n", | |
"\t\\hline\n", | |
"\t& Float64 & Float64 & Float64 & Float64 & Categorical…\\\\\n", | |
"\t\\hline\n", | |
"\t1 & 5.1 & 3.5 & 1.4 & 0.2 & setosa \\\\\n", | |
"\t2 & 4.9 & 3.0 & 1.4 & 0.2 & setosa \\\\\n", | |
"\t3 & 4.7 & 3.2 & 1.3 & 0.2 & setosa \\\\\n", | |
"\t4 & 4.6 & 3.1 & 1.5 & 0.2 & setosa \\\\\n", | |
"\t5 & 5.0 & 3.6 & 1.4 & 0.2 & setosa \\\\\n", | |
"\t6 & 5.4 & 3.9 & 1.7 & 0.4 & setosa \\\\\n", | |
"\t7 & 4.6 & 3.4 & 1.4 & 0.3 & setosa \\\\\n", | |
"\t8 & 5.0 & 3.4 & 1.5 & 0.2 & setosa \\\\\n", | |
"\t9 & 4.4 & 2.9 & 1.4 & 0.2 & setosa \\\\\n", | |
"\t10 & 4.9 & 3.1 & 1.5 & 0.1 & setosa \\\\\n", | |
"\t11 & 5.4 & 3.7 & 1.5 & 0.2 & setosa \\\\\n", | |
"\t12 & 4.8 & 3.4 & 1.6 & 0.2 & setosa \\\\\n", | |
"\t13 & 4.8 & 3.0 & 1.4 & 0.1 & setosa \\\\\n", | |
"\t14 & 4.3 & 3.0 & 1.1 & 0.1 & setosa \\\\\n", | |
"\t15 & 5.8 & 4.0 & 1.2 & 0.2 & setosa \\\\\n", | |
"\t16 & 5.7 & 4.4 & 1.5 & 0.4 & setosa \\\\\n", | |
"\t17 & 5.4 & 3.9 & 1.3 & 0.4 & setosa \\\\\n", | |
"\t18 & 5.1 & 3.5 & 1.4 & 0.3 & setosa \\\\\n", | |
"\t19 & 5.7 & 3.8 & 1.7 & 0.3 & setosa \\\\\n", | |
"\t20 & 5.1 & 3.8 & 1.5 & 0.3 & setosa \\\\\n", | |
"\t21 & 5.4 & 3.4 & 1.7 & 0.2 & setosa \\\\\n", | |
"\t22 & 5.1 & 3.7 & 1.5 & 0.4 & setosa \\\\\n", | |
"\t23 & 4.6 & 3.6 & 1.0 & 0.2 & setosa \\\\\n", | |
"\t24 & 5.1 & 3.3 & 1.7 & 0.5 & setosa \\\\\n", | |
"\t25 & 4.8 & 3.4 & 1.9 & 0.2 & setosa \\\\\n", | |
"\t26 & 5.0 & 3.0 & 1.6 & 0.2 & setosa \\\\\n", | |
"\t27 & 5.0 & 3.4 & 1.6 & 0.4 & setosa \\\\\n", | |
"\t28 & 5.2 & 3.5 & 1.5 & 0.2 & setosa \\\\\n", | |
"\t29 & 5.2 & 3.4 & 1.4 & 0.2 & setosa \\\\\n", | |
"\t30 & 4.7 & 3.2 & 1.6 & 0.2 & setosa \\\\\n", | |
"\t$\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ \\\\\n", | |
"\\end{tabular}\n" | |
], | |
"text/plain": [ | |
"150×5 DataFrame\n", | |
"│ Row │ SepalLength │ SepalWidth │ PetalLength │ PetalWidth │ Species │\n", | |
"│ │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mCategorical…\u001b[39m │\n", | |
"├─────┼─────────────┼────────────┼─────────────┼────────────┼──────────────┤\n", | |
"│ 1 │ 5.1 │ 3.5 │ 1.4 │ 0.2 │ setosa │\n", | |
"│ 2 │ 4.9 │ 3.0 │ 1.4 │ 0.2 │ setosa │\n", | |
"│ 3 │ 4.7 │ 3.2 │ 1.3 │ 0.2 │ setosa │\n", | |
"│ 4 │ 4.6 │ 3.1 │ 1.5 │ 0.2 │ setosa │\n", | |
"│ 5 │ 5.0 │ 3.6 │ 1.4 │ 0.2 │ setosa │\n", | |
"│ 6 │ 5.4 │ 3.9 │ 1.7 │ 0.4 │ setosa │\n", | |
"│ 7 │ 4.6 │ 3.4 │ 1.4 │ 0.3 │ setosa │\n", | |
"│ 8 │ 5.0 │ 3.4 │ 1.5 │ 0.2 │ setosa │\n", | |
"│ 9 │ 4.4 │ 2.9 │ 1.4 │ 0.2 │ setosa │\n", | |
"│ 10 │ 4.9 │ 3.1 │ 1.5 │ 0.1 │ setosa │\n", | |
"⋮\n", | |
"│ 140 │ 6.9 │ 3.1 │ 5.4 │ 2.1 │ virginica │\n", | |
"│ 141 │ 6.7 │ 3.1 │ 5.6 │ 2.4 │ virginica │\n", | |
"│ 142 │ 6.9 │ 3.1 │ 5.1 │ 2.3 │ virginica │\n", | |
"│ 143 │ 5.8 │ 2.7 │ 5.1 │ 1.9 │ virginica │\n", | |
"│ 144 │ 6.8 │ 3.2 │ 5.9 │ 2.3 │ virginica │\n", | |
"│ 145 │ 6.7 │ 3.3 │ 5.7 │ 2.5 │ virginica │\n", | |
"│ 146 │ 6.7 │ 3.0 │ 5.2 │ 2.3 │ virginica │\n", | |
"│ 147 │ 6.3 │ 2.5 │ 5.0 │ 1.9 │ virginica │\n", | |
"│ 148 │ 6.5 │ 3.0 │ 5.2 │ 2.0 │ virginica │\n", | |
"│ 149 │ 6.2 │ 3.4 │ 5.4 │ 2.3 │ virginica │\n", | |
"│ 150 │ 5.9 │ 3.0 │ 5.1 │ 1.8 │ virginica │" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Load MLJ and fetch the iris data frame via RDatasets\n", | |
"using MLJ, RDatasets\n", | |
"iris = dataset(\"datasets\", \"iris\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Features: the four Float64 measurement columns. Target: the categorical Species column.\n", | |
"const X = iris[:, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]];\n", | |
"const y = iris[:, :Species];" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"import MLJModels ✔\n", | |
"import DecisionTree ✔\n", | |
"import MLJModels.DecisionTree_.DecisionTreeClassifier ✔\n" | |
] | |
} | |
], | |
"source": [ | |
"# Import the DecisionTreeClassifier model type (and the DecisionTree package it wraps)\n", | |
"@load(DecisionTreeClassifier)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"DecisionTreeClassifier(pruning_purity = 1.0,\n", | |
" max_depth = 2,\n", | |
" min_samples_leaf = 1,\n", | |
" min_samples_split = 2,\n", | |
" min_purity_increase = 0.0,\n", | |
" n_subfeatures = 0.0,\n", | |
" display_depth = 5,\n", | |
" post_prune = false,\n", | |
" merge_purity_threshold = 0.9,)\u001b[34m @ 1…35\u001b[39m" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# A depth-limited tree (max_depth = 2). Every other hyperparameter keeps its default.\n", | |
"tree_model = DecisionTreeClassifier(; max_depth = 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\u001b[34mMachine{DecisionTreeClassifier} @ 6…34\u001b[39m\n" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Bind the model to the full data. The training rows are selected later via fit!(..., rows=...).\n", | |
"tree = machine(tree_model, X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: Training \u001b[34mMachine{DecisionTreeClassifier} @ 6…34\u001b[39m.\n", | |
"└ @ MLJ /home/simon/.julia/dev/MLJ/src/machines.jl:130\n" | |
] | |
}, | |
{ | |
"ename": "MethodError", | |
"evalue": "MethodError: no method matching build_tree(::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64, ::Float64)\nClosest candidates are:\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n build_tree(!Matched::Array{T,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T} at /home/simon/.julia/dev/DecisionTree/src/classification/main.jl:83\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n ...", | |
"output_type": "error", | |
"traceback": [ | |
"MethodError: no method matching build_tree(::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64, ::Float64)\nClosest candidates are:\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n build_tree(!Matched::Array{T,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T} at /home/simon/.julia/dev/DecisionTree/src/classification/main.jl:83\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n ...", | |
"", | |
"Stacktrace:", | |
" [1] fit(::DecisionTreeClassifier, ::Int64, ::DataFrame, ::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}) at /home/simon/.julia/dev/MLJModels/src/DecisionTree.jl:110", | |
" [2] #fit!#3(::Array{Int64,1}, ::Int64, ::Bool, ::Function, ::Machine{DecisionTreeClassifier}) at /home/simon/.julia/dev/MLJ/src/machines.jl:131", | |
" [3] (::getfield(StatsBase, Symbol(\"#kw##fit!\")))(::NamedTuple{(:rows,),Tuple{Array{Int64,1}}}, ::typeof(fit!), ::Machine{DecisionTreeClassifier}) at ./none:0", | |
" [4] top-level scope at In[6]:2" | |
] | |
} | |
], | |
"source": [ | |
"# Seed the global RNG so the shuffled 70:30 split is identical on every Restart & Run All.\n", | |
"# NOTE(review): assumes partition shuffles with the global RNG - TODO confirm for your MLJ version.\n", | |
"using Random\n", | |
"Random.seed!(42)\n", | |
"train, test = partition(eachindex(y), 0.7, shuffle=true); # 70:30 split\n", | |
"# NOTE(review): the saved output shows fit! raising a MethodError inside build_tree\n", | |
"# (a CategoricalArray target reaches DecisionTree.build_tree). That is a package-version\n", | |
"# mismatch between the dev MLJModels and DecisionTree checkouts, not a problem in this cell.\n", | |
"fit!(tree, rows=train)\n", | |
"# misclassification_rate compares class labels, so ask for point (mode) predictions\n", | |
"# rather than the possibly probabilistic output of plain predict.\n", | |
"yhat = predict_mode(tree, X[test,:]);\n", | |
"misclassification_rate(yhat, y[test])" | |
] | |
} | |
], | |
"metadata": { | |
"@webio": { | |
"lastCommId": null, | |
"lastKernelId": null | |
}, | |
"kernelspec": { | |
"display_name": "Julia 1.1.0", | |
"language": "julia", | |
"name": "julia-1.1" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment