Created
February 1, 2020 09:08
-
-
Save vigsterkr/48a5f0523528bbab85bec04464ca2b6c to your computer and use it in GitHub Desktop.
ShogunML with SciRuby stack
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"false" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"require 'daru'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<b> Daru::DataFrame(150x5) </b>\n", | |
"<table>\n", | |
" <thead>\n", | |
" \n", | |
" <tr>\n", | |
" <th></th>\n", | |
" \n", | |
" <th>sepal_length</th>\n", | |
" \n", | |
" <th>sepal_width</th>\n", | |
" \n", | |
" <th>petal_length</th>\n", | |
" \n", | |
" <th>petal_width</th>\n", | |
" \n", | |
" <th>species</th>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
"</thead>\n", | |
" <tbody>\n", | |
" \n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>3.5</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" \n", | |
" <td>4.9</td>\n", | |
" \n", | |
" <td>3.0</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" \n", | |
" <td>4.7</td>\n", | |
" \n", | |
" <td>3.2</td>\n", | |
" \n", | |
" <td>1.3</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" \n", | |
" <td>4.6</td>\n", | |
" \n", | |
" <td>3.1</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" \n", | |
" <td>5.0</td>\n", | |
" \n", | |
" <td>3.6</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" \n", | |
" <td>5.4</td>\n", | |
" \n", | |
" <td>3.9</td>\n", | |
" \n", | |
" <td>1.7</td>\n", | |
" \n", | |
" <td>0.4</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" \n", | |
" <td>4.6</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.3</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" \n", | |
" <td>5.0</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" \n", | |
" <td>4.4</td>\n", | |
" \n", | |
" <td>2.9</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" \n", | |
" <td>4.9</td>\n", | |
" \n", | |
" <td>3.1</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.1</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>10</td>\n", | |
" \n", | |
" <td>5.4</td>\n", | |
" \n", | |
" <td>3.7</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>11</td>\n", | |
" \n", | |
" <td>4.8</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.6</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>12</td>\n", | |
" \n", | |
" <td>4.8</td>\n", | |
" \n", | |
" <td>3.0</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.1</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>13</td>\n", | |
" \n", | |
" <td>4.3</td>\n", | |
" \n", | |
" <td>3.0</td>\n", | |
" \n", | |
" <td>1.1</td>\n", | |
" \n", | |
" <td>0.1</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>14</td>\n", | |
" \n", | |
" <td>5.8</td>\n", | |
" \n", | |
" <td>4.0</td>\n", | |
" \n", | |
" <td>1.2</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>15</td>\n", | |
" \n", | |
" <td>5.7</td>\n", | |
" \n", | |
" <td>4.4</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.4</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>16</td>\n", | |
" \n", | |
" <td>5.4</td>\n", | |
" \n", | |
" <td>3.9</td>\n", | |
" \n", | |
" <td>1.3</td>\n", | |
" \n", | |
" <td>0.4</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>17</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>3.5</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.3</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>18</td>\n", | |
" \n", | |
" <td>5.7</td>\n", | |
" \n", | |
" <td>3.8</td>\n", | |
" \n", | |
" <td>1.7</td>\n", | |
" \n", | |
" <td>0.3</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>19</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>3.8</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.3</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>20</td>\n", | |
" \n", | |
" <td>5.4</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.7</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>21</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>3.7</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.4</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>22</td>\n", | |
" \n", | |
" <td>4.6</td>\n", | |
" \n", | |
" <td>3.6</td>\n", | |
" \n", | |
" <td>1.0</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>23</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>3.3</td>\n", | |
" \n", | |
" <td>1.7</td>\n", | |
" \n", | |
" <td>0.5</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>24</td>\n", | |
" \n", | |
" <td>4.8</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.9</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>25</td>\n", | |
" \n", | |
" <td>5.0</td>\n", | |
" \n", | |
" <td>3.0</td>\n", | |
" \n", | |
" <td>1.6</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>26</td>\n", | |
" \n", | |
" <td>5.0</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.6</td>\n", | |
" \n", | |
" <td>0.4</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>27</td>\n", | |
" \n", | |
" <td>5.2</td>\n", | |
" \n", | |
" <td>3.5</td>\n", | |
" \n", | |
" <td>1.5</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>28</td>\n", | |
" \n", | |
" <td>5.2</td>\n", | |
" \n", | |
" <td>3.4</td>\n", | |
" \n", | |
" <td>1.4</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
" <tr>\n", | |
" <td>29</td>\n", | |
" \n", | |
" <td>4.7</td>\n", | |
" \n", | |
" <td>3.2</td>\n", | |
" \n", | |
" <td>1.6</td>\n", | |
" \n", | |
" <td>0.2</td>\n", | |
" \n", | |
" <td>setosa</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
"\n", | |
" \n", | |
" <tr>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" <td>...</td>\n", | |
" \n", | |
" </tr>\n", | |
"\n", | |
" \n", | |
"\n", | |
" <tr>\n", | |
" <td>149</td>\n", | |
" \n", | |
" <td>5.9</td>\n", | |
" \n", | |
" <td>3.0</td>\n", | |
" \n", | |
" <td>5.1</td>\n", | |
" \n", | |
" <td>1.8</td>\n", | |
" \n", | |
" <td>virginica</td>\n", | |
" \n", | |
" </tr>\n", | |
" \n", | |
"</tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"#<Daru::DataFrame(150x5)>\n", | |
" sepal_leng sepal_widt petal_leng petal_widt species\n", | |
" 0 5.1 3.5 1.4 0.2 setosa\n", | |
" 1 4.9 3.0 1.4 0.2 setosa\n", | |
" 2 4.7 3.2 1.3 0.2 setosa\n", | |
" 3 4.6 3.1 1.5 0.2 setosa\n", | |
" 4 5.0 3.6 1.4 0.2 setosa\n", | |
" 5 5.4 3.9 1.7 0.4 setosa\n", | |
" 6 4.6 3.4 1.4 0.3 setosa\n", | |
" 7 5.0 3.4 1.5 0.2 setosa\n", | |
" 8 4.4 2.9 1.4 0.2 setosa\n", | |
" 9 4.9 3.1 1.5 0.1 setosa\n", | |
" 10 5.4 3.7 1.5 0.2 setosa\n", | |
" 11 4.8 3.4 1.6 0.2 setosa\n", | |
" 12 4.8 3.0 1.4 0.1 setosa\n", | |
" 13 4.3 3.0 1.1 0.1 setosa\n", | |
" 14 5.8 4.0 1.2 0.2 setosa\n", | |
" ... ... ... ... ... ..." | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df = Daru::DataFrame.from_csv \"https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/639388c2cbc2120a14dcf466e85730eb8be498bb/iris.csv\"\n", | |
"df" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Shogun ML\n", | |
"\n", | |
"You need to compile shogun and the ruby interface first:\n", | |
"```\n", | |
"git clone https://github.com/shogun-toolbox/shogun.git\n", | |
"cd shogun\n", | |
"mkdir build\n", | |
"cd build\n", | |
"cmake -G\"Ninja\" -DINTERFACE_RUBY=ON ..\n", | |
"ninja\n", | |
"```\n", | |
"\n", | |
"once you've built it either you install the generated binaries with `ninja install` or simply just set `RUBYLIB` runtime environment before you start the jupyter notebook, for example while still in the `build` directory run the following command:\n", | |
"```\n", | |
"export RUBYLIB=$PWD/src/interfaces/ruby:$RUBYLIB\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"false" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"require 'shogun'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Prepare the data for the ShogunML model: `X` variables contain the features and `y` contains the labels." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/var/lib/gems/2.5.0/gems/nmatrix-0.2.4/lib/nmatrix/monkeys.rb:49: warning: constant ::Fixnum is deprecated\n", | |
"<main>: warning: already initialized constant X\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"#<Shogun::Labels:0x0000564db87f5178 @__swigtype__=\"_p_std__shared_ptrT_shogun__Labels_t\">" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X = Shogun::features(df['sepal_length','sepal_width', 'petal_length', 'petal_width'].to_nmatrix.transpose)\n", | |
"y = Shogun::labels(df.species.to_category.to_ints.to_ary)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create a OneVSOne multiclass classifier that uses LibLinear as a base binary classifier" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"classifier = Shogun::machine(\"MulticlassLibLinear\")\n", | |
"classifier.put(\"C\", 1.0)\n", | |
"classifier.put(\"labels\", y)\n", | |
"classifier.put(\"use_bias\", true)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Train the model using the `X` features" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"#<Shogun::MulticlassLabels:0x0000564db8822ab0 @__swigtype__=\"_p_std__shared_ptrT_shogun__MulticlassLabels_t\">" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"classifier.train(X)\n", | |
"y_pred = classifier.apply_multiclass(X)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Measure the model's performance on the train data (note plz create a train/test split to actually measure the real performance of your model!)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.98" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"eval = Shogun::evaluation(\"MulticlassAccuracy\")\n", | |
"accuracy = eval.evaluate(y, y_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"binary_classifier = Shogun::machine(\"LibLinear\")\n", | |
"strategy = Shogun::multiclass_strategy(\"MulticlassOneVsRestStrategy\")\n", | |
"mc_classifier = Shogun::machine(\"LinearMulticlassMachine\")\n", | |
"mc_classifier.put(\"multiclass_strategy\", strategy)\n", | |
"mc_classifier.put(\"machine\", binary_classifier)\n", | |
"mc_classifier.put(\"labels\", y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"true" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mc_classifier.train(X)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.94" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_mc_pred = mc_classifier.apply_multiclass(X)\n", | |
"accuracy = eval.evaluate(y, y_mc_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Ruby 2.5.5", | |
"language": "ruby", | |
"name": "ruby" | |
}, | |
"language_info": { | |
"file_extension": ".rb", | |
"mimetype": "application/x-ruby", | |
"name": "ruby", | |
"version": "2.5.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment