@sdwfrost
Last active June 8, 2019 13:01
Jupyter notebook for MLJ using Iris data and DecisionTree
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table class=\"data-frame\"><thead><tr><th></th><th>SepalLength</th><th>SepalWidth</th><th>PetalLength</th><th>PetalWidth</th><th>Species</th></tr><tr><th></th><th>Float64</th><th>Float64</th><th>Float64</th><th>Float64</th><th>Categorical…</th></tr></thead><tbody><p>150 rows × 5 columns</p><tr><th>1</th><td>5.1</td><td>3.5</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>2</th><td>4.9</td><td>3.0</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>3</th><td>4.7</td><td>3.2</td><td>1.3</td><td>0.2</td><td>setosa</td></tr><tr><th>4</th><td>4.6</td><td>3.1</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>5</th><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>6</th><td>5.4</td><td>3.9</td><td>1.7</td><td>0.4</td><td>setosa</td></tr><tr><th>7</th><td>4.6</td><td>3.4</td><td>1.4</td><td>0.3</td><td>setosa</td></tr><tr><th>8</th><td>5.0</td><td>3.4</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>9</th><td>4.4</td><td>2.9</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>10</th><td>4.9</td><td>3.1</td><td>1.5</td><td>0.1</td><td>setosa</td></tr><tr><th>11</th><td>5.4</td><td>3.7</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>12</th><td>4.8</td><td>3.4</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>13</th><td>4.8</td><td>3.0</td><td>1.4</td><td>0.1</td><td>setosa</td></tr><tr><th>14</th><td>4.3</td><td>3.0</td><td>1.1</td><td>0.1</td><td>setosa</td></tr><tr><th>15</th><td>5.8</td><td>4.0</td><td>1.2</td><td>0.2</td><td>setosa</td></tr><tr><th>16</th><td>5.7</td><td>4.4</td><td>1.5</td><td>0.4</td><td>setosa</td></tr><tr><th>17</th><td>5.4</td><td>3.9</td><td>1.3</td><td>0.4</td><td>setosa</td></tr><tr><th>18</th><td>5.1</td><td>3.5</td><td>1.4</td><td>0.3</td><td>setosa</td></tr><tr><th>19</th><td>5.7</td><td>3.8</td><td>1.7</td><td>0.3</td><td>setosa</td></tr><tr><th>20</th><td>5.1</td><td>3.8</td><td>1.5</td><td>0.3</td><td>setosa</td></tr><tr><th>21</th><td>5.4</td><td>3.4</td><td>1.7</td><td>0.2</td><td>setosa</td></tr><tr><th>22</th><td>5.1</td><td>3.7</td><td>1.5</td><td>0.4</td><td>setosa</td></tr><tr><th>23</th><td>4.6</td><td>3.6</td><td>1.0</td><td>0.2</td><td>setosa</td></tr><tr><th>24</th><td>5.1</td><td>3.3</td><td>1.7</td><td>0.5</td><td>setosa</td></tr><tr><th>25</th><td>4.8</td><td>3.4</td><td>1.9</td><td>0.2</td><td>setosa</td></tr><tr><th>26</th><td>5.0</td><td>3.0</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>27</th><td>5.0</td><td>3.4</td><td>1.6</td><td>0.4</td><td>setosa</td></tr><tr><th>28</th><td>5.2</td><td>3.5</td><td>1.5</td><td>0.2</td><td>setosa</td></tr><tr><th>29</th><td>5.2</td><td>3.4</td><td>1.4</td><td>0.2</td><td>setosa</td></tr><tr><th>30</th><td>4.7</td><td>3.2</td><td>1.6</td><td>0.2</td><td>setosa</td></tr><tr><th>&vellip;</th><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td><td>&vellip;</td></tr></tbody></table>"
],
"text/latex": [
"\\begin{tabular}{r|ccccc}\n",
"\t& SepalLength & SepalWidth & PetalLength & PetalWidth & Species\\\\\n",
"\t\\hline\n",
"\t& Float64 & Float64 & Float64 & Float64 & Categorical…\\\\\n",
"\t\\hline\n",
"\t1 & 5.1 & 3.5 & 1.4 & 0.2 & setosa \\\\\n",
"\t2 & 4.9 & 3.0 & 1.4 & 0.2 & setosa \\\\\n",
"\t3 & 4.7 & 3.2 & 1.3 & 0.2 & setosa \\\\\n",
"\t4 & 4.6 & 3.1 & 1.5 & 0.2 & setosa \\\\\n",
"\t5 & 5.0 & 3.6 & 1.4 & 0.2 & setosa \\\\\n",
"\t6 & 5.4 & 3.9 & 1.7 & 0.4 & setosa \\\\\n",
"\t7 & 4.6 & 3.4 & 1.4 & 0.3 & setosa \\\\\n",
"\t8 & 5.0 & 3.4 & 1.5 & 0.2 & setosa \\\\\n",
"\t9 & 4.4 & 2.9 & 1.4 & 0.2 & setosa \\\\\n",
"\t10 & 4.9 & 3.1 & 1.5 & 0.1 & setosa \\\\\n",
"\t11 & 5.4 & 3.7 & 1.5 & 0.2 & setosa \\\\\n",
"\t12 & 4.8 & 3.4 & 1.6 & 0.2 & setosa \\\\\n",
"\t13 & 4.8 & 3.0 & 1.4 & 0.1 & setosa \\\\\n",
"\t14 & 4.3 & 3.0 & 1.1 & 0.1 & setosa \\\\\n",
"\t15 & 5.8 & 4.0 & 1.2 & 0.2 & setosa \\\\\n",
"\t16 & 5.7 & 4.4 & 1.5 & 0.4 & setosa \\\\\n",
"\t17 & 5.4 & 3.9 & 1.3 & 0.4 & setosa \\\\\n",
"\t18 & 5.1 & 3.5 & 1.4 & 0.3 & setosa \\\\\n",
"\t19 & 5.7 & 3.8 & 1.7 & 0.3 & setosa \\\\\n",
"\t20 & 5.1 & 3.8 & 1.5 & 0.3 & setosa \\\\\n",
"\t21 & 5.4 & 3.4 & 1.7 & 0.2 & setosa \\\\\n",
"\t22 & 5.1 & 3.7 & 1.5 & 0.4 & setosa \\\\\n",
"\t23 & 4.6 & 3.6 & 1.0 & 0.2 & setosa \\\\\n",
"\t24 & 5.1 & 3.3 & 1.7 & 0.5 & setosa \\\\\n",
"\t25 & 4.8 & 3.4 & 1.9 & 0.2 & setosa \\\\\n",
"\t26 & 5.0 & 3.0 & 1.6 & 0.2 & setosa \\\\\n",
"\t27 & 5.0 & 3.4 & 1.6 & 0.4 & setosa \\\\\n",
"\t28 & 5.2 & 3.5 & 1.5 & 0.2 & setosa \\\\\n",
"\t29 & 5.2 & 3.4 & 1.4 & 0.2 & setosa \\\\\n",
"\t30 & 4.7 & 3.2 & 1.6 & 0.2 & setosa \\\\\n",
"\t$\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ \\\\\n",
"\\end{tabular}\n"
],
"text/plain": [
"150×5 DataFrame\n",
"│ Row │ SepalLength │ SepalWidth │ PetalLength │ PetalWidth │ Species │\n",
"│ │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mCategorical…\u001b[39m │\n",
"├─────┼─────────────┼────────────┼─────────────┼────────────┼──────────────┤\n",
"│ 1 │ 5.1 │ 3.5 │ 1.4 │ 0.2 │ setosa │\n",
"│ 2 │ 4.9 │ 3.0 │ 1.4 │ 0.2 │ setosa │\n",
"│ 3 │ 4.7 │ 3.2 │ 1.3 │ 0.2 │ setosa │\n",
"│ 4 │ 4.6 │ 3.1 │ 1.5 │ 0.2 │ setosa │\n",
"│ 5 │ 5.0 │ 3.6 │ 1.4 │ 0.2 │ setosa │\n",
"│ 6 │ 5.4 │ 3.9 │ 1.7 │ 0.4 │ setosa │\n",
"│ 7 │ 4.6 │ 3.4 │ 1.4 │ 0.3 │ setosa │\n",
"│ 8 │ 5.0 │ 3.4 │ 1.5 │ 0.2 │ setosa │\n",
"│ 9 │ 4.4 │ 2.9 │ 1.4 │ 0.2 │ setosa │\n",
"│ 10 │ 4.9 │ 3.1 │ 1.5 │ 0.1 │ setosa │\n",
"⋮\n",
"│ 140 │ 6.9 │ 3.1 │ 5.4 │ 2.1 │ virginica │\n",
"│ 141 │ 6.7 │ 3.1 │ 5.6 │ 2.4 │ virginica │\n",
"│ 142 │ 6.9 │ 3.1 │ 5.1 │ 2.3 │ virginica │\n",
"│ 143 │ 5.8 │ 2.7 │ 5.1 │ 1.9 │ virginica │\n",
"│ 144 │ 6.8 │ 3.2 │ 5.9 │ 2.3 │ virginica │\n",
"│ 145 │ 6.7 │ 3.3 │ 5.7 │ 2.5 │ virginica │\n",
"│ 146 │ 6.7 │ 3.0 │ 5.2 │ 2.3 │ virginica │\n",
"│ 147 │ 6.3 │ 2.5 │ 5.0 │ 1.9 │ virginica │\n",
"│ 148 │ 6.5 │ 3.0 │ 5.2 │ 2.0 │ virginica │\n",
"│ 149 │ 6.2 │ 3.4 │ 5.4 │ 2.3 │ virginica │\n",
"│ 150 │ 5.9 │ 3.0 │ 5.1 │ 1.8 │ virginica │"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using MLJ\n",
"using RDatasets\n",
"iris = dataset(\"datasets\", \"iris\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"const X = iris[:, 1:4];\n",
"const y = iris[:, 5];"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import MLJModels ✔\n",
"import DecisionTree ✔\n",
"import MLJModels.DecisionTree_.DecisionTreeClassifier ✔\n"
]
}
],
"source": [
"@load DecisionTreeClassifier"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(pruning_purity = 1.0,\n",
" max_depth = 2,\n",
" min_samples_leaf = 1,\n",
" min_samples_split = 2,\n",
" min_purity_increase = 0.0,\n",
" n_subfeatures = 0.0,\n",
" display_depth = 5,\n",
" post_prune = false,\n",
" merge_purity_threshold = 0.9,)\u001b[34m @ 1…35\u001b[39m"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_model = DecisionTreeClassifier(max_depth=2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[34mMachine{DecisionTreeClassifier} @ 6…34\u001b[39m\n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree = machine(tree_model, X, y)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Training \u001b[34mMachine{DecisionTreeClassifier} @ 6…34\u001b[39m.\n",
"└ @ MLJ /home/simon/.julia/dev/MLJ/src/machines.jl:130\n"
]
},
{
"ename": "MethodError",
"evalue": "MethodError: no method matching build_tree(::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64, ::Float64)\nClosest candidates are:\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n build_tree(!Matched::Array{T,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T} at /home/simon/.julia/dev/DecisionTree/src/classification/main.jl:83\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n ...",
"output_type": "error",
"traceback": [
"MethodError: no method matching build_tree(::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64, ::Float64)\nClosest candidates are:\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n build_tree(!Matched::Array{T,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T} at /home/simon/.julia/dev/DecisionTree/src/classification/main.jl:83\n build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27\n ...",
"",
"Stacktrace:",
" [1] fit(::DecisionTreeClassifier, ::Int64, ::DataFrame, ::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}) at /home/simon/.julia/dev/MLJModels/src/DecisionTree.jl:110",
" [2] #fit!#3(::Array{Int64,1}, ::Int64, ::Bool, ::Function, ::Machine{DecisionTreeClassifier}) at /home/simon/.julia/dev/MLJ/src/machines.jl:131",
" [3] (::getfield(StatsBase, Symbol(\"#kw##fit!\")))(::NamedTuple{(:rows,),Tuple{Array{Int64,1}}}, ::typeof(fit!), ::Machine{DecisionTreeClassifier}) at ./none:0",
" [4] top-level scope at In[6]:2"
]
}
],
"source": [
"train, test = partition(eachindex(y), 0.7, shuffle=true); # 70:30 split\n",
"fit!(tree, rows=train)\n",
"yhat = predict(tree, X[test,:]);\n",
"misclassification_rate(yhat, y[test])"
]
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}