Last active
January 22, 2019 16:08
-
-
Save alexlenail/7a31a0e5e8bb64256b59493d5add5e85 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Test SPAMS.fistaGraph\n", | |
"\n", | |
"Docs http://spams-devel.gforge.inria.fr/doc-R/html/doc_spams006.html#sec39" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import networkx as nx\n", | |
"import scipy\n", | |
"import scipy.sparse as ssp\n", | |
"\n", | |
"import time\n", | |
"import spams\n", | |
"\n", | |
"def flatten(list_of_lists):\n", | |
"    \"\"\"Return one flat list holding the items of every sub-list, in order.\"\"\"\n", | |
"    flat = []\n", | |
"    for sublist in list_of_lists:\n", | |
"        flat.extend(sublist)\n", | |
"    return flat\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## I. Generate synthetic \"easy\" graph and graph signals" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Build two independent G(n, p) random communities and join them with three bridge edges.\n", | |
"num_nodes = 15   # nodes per community\n", | |
"p_edge = 0.3     # edge probability inside each community\n", | |
"GRAPH_SEED = 0   # fixed so that Restart & Run All regenerates the same graph\n", | |
"\n", | |
"g1 = nx.fast_gnp_random_graph(num_nodes, p_edge, seed=GRAPH_SEED)\n", | |
"g2_unshifted = nx.fast_gnp_random_graph(num_nodes, p_edge, seed=GRAPH_SEED + 1)\n", | |
"\n", | |
"# Shift the second community's labels to num_nodes .. 2*num_nodes - 1 so that nx.compose keeps\n", | |
"# the two node sets disjoint. The offset is num_nodes rather than a literal 15, so changing\n", | |
"# num_nodes cannot make the communities overlap.\n", | |
"g2 = nx.relabel_nodes(g2_unshifted, {node: node + num_nodes for node in g2_unshifted.nodes})\n", | |
"\n", | |
"g = nx.compose(g1, g2)\n", | |
"\n", | |
"# Bridge edges between the communities (nodes 25, 20 and 15 when num_nodes == 15).\n", | |
"# NOTE: the +10 / +5 offsets assume num_nodes > 10; smaller values would create extra nodes.\n", | |
"g.add_edge(1, num_nodes + 10)\n", | |
"g.add_edge(2, num_nodes + 5)\n", | |
"g.add_edge(3, num_nodes)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/alex/miniconda3/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:611: MatplotlibDeprecationWarning: isinstance(..., numbers.Number)\n", | |
" if cb.is_numlike(alpha):\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Draw the combined graph using networkx's spring layout.\n", | |
"nx.draw_spring(g)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Edge list of g as a NumPy array (the .values of the pandas edge-list frame), one row per edge.\n", | |
"edgelist = nx.to_pandas_edgelist(g).values" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### True signal comes from 0 and its neighbors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0, 2, 11, 12, 14]" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Node 0 together with each of its direct neighbours in g.\n", | |
"signal_nodes = [0] + list(g.neighbors(0))\n", | |
"signal_nodes" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### First half is class 1, second half is class 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | |
" 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0.])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"NUM_POSITIVES = 20\n", | |
"NUM_NEGATIVES = 20\n", | |
"# Label vector: NUM_POSITIVES ones followed by NUM_NEGATIVES zeros (float64).\n", | |
"y = np.repeat([1.0, 0.0], [NUM_POSITIVES, NUM_NEGATIVES])\n", | |
"y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" <th>...</th>\n", | |
" <th>20</th>\n", | |
" <th>21</th>\n", | |
" <th>22</th>\n", | |
" <th>23</th>\n", | |
" <th>24</th>\n", | |
" <th>25</th>\n", | |
" <th>26</th>\n", | |
" <th>27</th>\n", | |
" <th>28</th>\n", | |
" <th>29</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1.0</td>\n", | |
" <td>-1.626836</td>\n", | |
" <td>1.0</td>\n", | |
" <td>-1.667516</td>\n", | |
" <td>-0.108261</td>\n", | |
" <td>0.390096</td>\n", | |
" <td>-0.861851</td>\n", | |
" <td>-0.156386</td>\n", | |
" <td>1.577740</td>\n", | |
" <td>1.355505</td>\n", | |
" <td>...</td>\n", | |
" <td>0.494552</td>\n", | |
" <td>0.357644</td>\n", | |
" <td>-0.999935</td>\n", | |
" <td>0.391030</td>\n", | |
" <td>-0.382606</td>\n", | |
" <td>1.747134</td>\n", | |
" <td>-0.019657</td>\n", | |
" <td>2.119312</td>\n", | |
" <td>0.032667</td>\n", | |
" <td>-0.146469</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1.0</td>\n", | |
" <td>0.362366</td>\n", | |
" <td>1.0</td>\n", | |
" <td>-0.635774</td>\n", | |
" <td>-0.836827</td>\n", | |
" <td>-0.276659</td>\n", | |
" <td>-0.809957</td>\n", | |
" <td>0.002832</td>\n", | |
" <td>0.183393</td>\n", | |
" <td>1.224518</td>\n", | |
" <td>...</td>\n", | |
" <td>0.983904</td>\n", | |
" <td>0.622035</td>\n", | |
" <td>-0.619685</td>\n", | |
" <td>0.268317</td>\n", | |
" <td>-1.240670</td>\n", | |
" <td>-1.471808</td>\n", | |
" <td>0.637272</td>\n", | |
" <td>0.492749</td>\n", | |
" <td>0.539597</td>\n", | |
" <td>-0.780348</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1.0</td>\n", | |
" <td>0.126977</td>\n", | |
" <td>1.0</td>\n", | |
" <td>-0.112256</td>\n", | |
" <td>1.780466</td>\n", | |
" <td>0.342744</td>\n", | |
" <td>-0.631568</td>\n", | |
" <td>-1.072373</td>\n", | |
" <td>-2.127742</td>\n", | |
" <td>3.137458</td>\n", | |
" <td>...</td>\n", | |
" <td>0.010640</td>\n", | |
" <td>0.457412</td>\n", | |
" <td>-0.346469</td>\n", | |
" <td>0.837266</td>\n", | |
" <td>0.216169</td>\n", | |
" <td>-1.255527</td>\n", | |
" <td>2.048838</td>\n", | |
" <td>-1.903665</td>\n", | |
" <td>1.030514</td>\n", | |
" <td>1.808216</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1.0</td>\n", | |
" <td>-1.144054</td>\n", | |
" <td>1.0</td>\n", | |
" <td>-0.673975</td>\n", | |
" <td>-0.488836</td>\n", | |
" <td>-0.597826</td>\n", | |
" <td>-0.467641</td>\n", | |
" <td>-0.079044</td>\n", | |
" <td>-0.620556</td>\n", | |
" <td>-0.000212</td>\n", | |
" <td>...</td>\n", | |
" <td>-1.718976</td>\n", | |
" <td>-1.400485</td>\n", | |
" <td>0.870043</td>\n", | |
" <td>-0.137372</td>\n", | |
" <td>-0.695848</td>\n", | |
" <td>-0.011468</td>\n", | |
" <td>-0.990155</td>\n", | |
" <td>0.890150</td>\n", | |
" <td>-0.059695</td>\n", | |
" <td>2.337334</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1.0</td>\n", | |
" <td>-1.383579</td>\n", | |
" <td>1.0</td>\n", | |
" <td>-0.772876</td>\n", | |
" <td>-0.020679</td>\n", | |
" <td>1.385638</td>\n", | |
" <td>-0.299431</td>\n", | |
" <td>-0.074756</td>\n", | |
" <td>-0.465050</td>\n", | |
" <td>0.641146</td>\n", | |
" <td>...</td>\n", | |
" <td>-0.083365</td>\n", | |
" <td>0.821267</td>\n", | |
" <td>0.327011</td>\n", | |
" <td>1.879729</td>\n", | |
" <td>0.859166</td>\n", | |
" <td>-1.927820</td>\n", | |
" <td>-0.632995</td>\n", | |
" <td>-0.974812</td>\n", | |
" <td>1.312887</td>\n", | |
" <td>0.643763</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 30 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 7 \\\n", | |
"0 1.0 -1.626836 1.0 -1.667516 -0.108261 0.390096 -0.861851 -0.156386 \n", | |
"1 1.0 0.362366 1.0 -0.635774 -0.836827 -0.276659 -0.809957 0.002832 \n", | |
"2 1.0 0.126977 1.0 -0.112256 1.780466 0.342744 -0.631568 -1.072373 \n", | |
"3 1.0 -1.144054 1.0 -0.673975 -0.488836 -0.597826 -0.467641 -0.079044 \n", | |
"4 1.0 -1.383579 1.0 -0.772876 -0.020679 1.385638 -0.299431 -0.074756 \n", | |
"\n", | |
" 8 9 ... 20 21 22 23 \\\n", | |
"0 1.577740 1.355505 ... 0.494552 0.357644 -0.999935 0.391030 \n", | |
"1 0.183393 1.224518 ... 0.983904 0.622035 -0.619685 0.268317 \n", | |
"2 -2.127742 3.137458 ... 0.010640 0.457412 -0.346469 0.837266 \n", | |
"3 -0.620556 -0.000212 ... -1.718976 -1.400485 0.870043 -0.137372 \n", | |
"4 -0.465050 0.641146 ... -0.083365 0.821267 0.327011 1.879729 \n", | |
"\n", | |
" 24 25 26 27 28 29 \n", | |
"0 -0.382606 1.747134 -0.019657 2.119312 0.032667 -0.146469 \n", | |
"1 -1.240670 -1.471808 0.637272 0.492749 0.539597 -0.780348 \n", | |
"2 0.216169 -1.255527 2.048838 -1.903665 1.030514 1.808216 \n", | |
"3 -0.695848 -0.011468 -0.990155 0.890150 -0.059695 2.337334 \n", | |
"4 0.859166 -1.927820 -0.632995 -0.974812 1.312887 0.643763 \n", | |
"\n", | |
"[5 rows x 30 columns]" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Gaussian noise features, one column per graph node; seeded so reruns give the same table.\n", | |
"rng = np.random.RandomState(0)\n", | |
"data = pd.DataFrame(rng.normal(size=(NUM_POSITIVES + NUM_NEGATIVES, g.number_of_nodes())))\n", | |
"\n", | |
"# Plant the signal: +1 on the signal columns for the positive rows, -1 for the negative rows.\n", | |
"# .loc slices include BOTH endpoints, so the positive block must stop at NUM_POSITIVES - 1.\n", | |
"# The former 0:NUM_POSITIVES slice also wrote +1 into the first negative row and only ended up\n", | |
"# correct because the following statement overwrote that row.\n", | |
"data.loc[:NUM_POSITIVES - 1, signal_nodes] = 1\n", | |
"data.loc[NUM_POSITIVES:, signal_nodes] = -1\n", | |
"\n", | |
"data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" <th>...</th>\n", | |
" <th>20</th>\n", | |
" <th>21</th>\n", | |
" <th>22</th>\n", | |
" <th>23</th>\n", | |
" <th>24</th>\n", | |
" <th>25</th>\n", | |
" <th>26</th>\n", | |
" <th>27</th>\n", | |
" <th>28</th>\n", | |
" <th>29</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>35</th>\n", | |
" <td>-1.0</td>\n", | |
" <td>0.711782</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>-0.300740</td>\n", | |
" <td>0.771417</td>\n", | |
" <td>0.419459</td>\n", | |
" <td>0.134785</td>\n", | |
" <td>-1.810631</td>\n", | |
" <td>1.698634</td>\n", | |
" <td>0.089898</td>\n", | |
" <td>...</td>\n", | |
" <td>0.161639</td>\n", | |
" <td>0.903178</td>\n", | |
" <td>-0.773467</td>\n", | |
" <td>0.145566</td>\n", | |
" <td>0.344990</td>\n", | |
" <td>0.548841</td>\n", | |
" <td>0.444323</td>\n", | |
" <td>0.594573</td>\n", | |
" <td>0.014439</td>\n", | |
" <td>0.111183</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>36</th>\n", | |
" <td>-1.0</td>\n", | |
" <td>-0.363750</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>0.319550</td>\n", | |
" <td>0.582871</td>\n", | |
" <td>-0.366024</td>\n", | |
" <td>2.009703</td>\n", | |
" <td>-0.849986</td>\n", | |
" <td>0.252170</td>\n", | |
" <td>-0.349139</td>\n", | |
" <td>...</td>\n", | |
" <td>-1.003231</td>\n", | |
" <td>0.109738</td>\n", | |
" <td>-0.425068</td>\n", | |
" <td>-2.085400</td>\n", | |
" <td>-0.354644</td>\n", | |
" <td>-0.923705</td>\n", | |
" <td>1.521942</td>\n", | |
" <td>-0.776429</td>\n", | |
" <td>0.755865</td>\n", | |
" <td>-1.172837</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>37</th>\n", | |
" <td>-1.0</td>\n", | |
" <td>-0.212570</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>0.809884</td>\n", | |
" <td>-1.436927</td>\n", | |
" <td>0.728994</td>\n", | |
" <td>-0.141139</td>\n", | |
" <td>-0.658784</td>\n", | |
" <td>0.074903</td>\n", | |
" <td>1.463954</td>\n", | |
" <td>...</td>\n", | |
" <td>-0.643148</td>\n", | |
" <td>-0.576813</td>\n", | |
" <td>-1.599052</td>\n", | |
" <td>1.233322</td>\n", | |
" <td>-1.907276</td>\n", | |
" <td>1.503759</td>\n", | |
" <td>0.632590</td>\n", | |
" <td>-0.616824</td>\n", | |
" <td>1.147559</td>\n", | |
" <td>1.531823</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>38</th>\n", | |
" <td>-1.0</td>\n", | |
" <td>-0.561549</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>0.625856</td>\n", | |
" <td>-0.822934</td>\n", | |
" <td>-0.979575</td>\n", | |
" <td>0.199540</td>\n", | |
" <td>-1.276706</td>\n", | |
" <td>0.575109</td>\n", | |
" <td>-0.314934</td>\n", | |
" <td>...</td>\n", | |
" <td>-1.487596</td>\n", | |
" <td>0.377102</td>\n", | |
" <td>1.008312</td>\n", | |
" <td>0.622991</td>\n", | |
" <td>0.523273</td>\n", | |
" <td>-1.566366</td>\n", | |
" <td>0.030276</td>\n", | |
" <td>0.092158</td>\n", | |
" <td>1.426438</td>\n", | |
" <td>1.271735</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>39</th>\n", | |
" <td>-1.0</td>\n", | |
" <td>-1.453606</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>1.025963</td>\n", | |
" <td>-0.199703</td>\n", | |
" <td>0.196850</td>\n", | |
" <td>-0.033403</td>\n", | |
" <td>-0.131592</td>\n", | |
" <td>0.952464</td>\n", | |
" <td>-1.219901</td>\n", | |
" <td>...</td>\n", | |
" <td>-0.476985</td>\n", | |
" <td>1.392822</td>\n", | |
" <td>-1.000271</td>\n", | |
" <td>-0.647108</td>\n", | |
" <td>0.612353</td>\n", | |
" <td>1.975893</td>\n", | |
" <td>-0.744764</td>\n", | |
" <td>-1.440025</td>\n", | |
" <td>-0.248124</td>\n", | |
" <td>1.060998</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 30 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 7 \\\n", | |
"35 -1.0 0.711782 -1.0 -0.300740 0.771417 0.419459 0.134785 -1.810631 \n", | |
"36 -1.0 -0.363750 -1.0 0.319550 0.582871 -0.366024 2.009703 -0.849986 \n", | |
"37 -1.0 -0.212570 -1.0 0.809884 -1.436927 0.728994 -0.141139 -0.658784 \n", | |
"38 -1.0 -0.561549 -1.0 0.625856 -0.822934 -0.979575 0.199540 -1.276706 \n", | |
"39 -1.0 -1.453606 -1.0 1.025963 -0.199703 0.196850 -0.033403 -0.131592 \n", | |
"\n", | |
" 8 9 ... 20 21 22 23 \\\n", | |
"35 1.698634 0.089898 ... 0.161639 0.903178 -0.773467 0.145566 \n", | |
"36 0.252170 -0.349139 ... -1.003231 0.109738 -0.425068 -2.085400 \n", | |
"37 0.074903 1.463954 ... -0.643148 -0.576813 -1.599052 1.233322 \n", | |
"38 0.575109 -0.314934 ... -1.487596 0.377102 1.008312 0.622991 \n", | |
"39 0.952464 -1.219901 ... -0.476985 1.392822 -1.000271 -0.647108 \n", | |
"\n", | |
" 24 25 26 27 28 29 \n", | |
"35 0.344990 0.548841 0.444323 0.594573 0.014439 0.111183 \n", | |
"36 -0.354644 -0.923705 1.521942 -0.776429 0.755865 -1.172837 \n", | |
"37 -1.907276 1.503759 0.632590 -0.616824 1.147559 1.531823 \n", | |
"38 0.523273 -1.566366 0.030276 0.092158 1.426438 1.271735 \n", | |
"39 0.612353 1.975893 -0.744764 -1.440025 -0.248124 1.060998 \n", | |
"\n", | |
"[5 rows x 30 columns]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Last rows of the feature table (negative-class samples, signal columns set to -1).\n", | |
"data.tail()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Unit weight for every edge in edgelist.\n", | |
"edge_weights = [1 for _ in range(len(edgelist))]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Groups are assigned as the one-hop neighborhood of every node" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"30" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# One group per node: the node itself followed by its one-hop neighbours.\n", | |
"# Group members are variable (column) indices, i.e. node labels, so the group's own member has\n", | |
"# to be the node label too - not its position in g.nodes, which only coincides with the label\n", | |
"# while the nodes happen to be stored in order 0..n-1.\n", | |
"neighborhoods = [[node] + list(g.neighbors(node)) for node in g.nodes]\n", | |
"num_groups = len(neighborhoods)\n", | |
"num_groups" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## II. Test SPAMS.fistaGraph on \"easy\" graph" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Name: spams.fistaGraph\n", | |
"#\n", | |
"# Description:\n", | |
"# spams.fistaGraph solves sparse regularized problems.\n", | |
"# X is a design matrix of size m x p\n", | |
"# X=[x^1,...,x^n]', where the x_i's are the rows of X\n", | |
"# Y=[y^1,...,y^n] is a matrix of size m x n\n", | |
"# \n", | |
"# It implements the algorithms FISTA, ISTA and subgradient descent for solving\n", | |
"# \n", | |
"# min_W loss(W) + lambda1 psi(W)\n", | |
"# \n", | |
"# The function psi are those used by spams.proximalGraph (see documentation)\n", | |
"# for the loss functions, see the documentation of spams.fistaFlat\n", | |
"# \n", | |
"# This function can also handle intercepts (last row of W is not regularized),\n", | |
"# and/or non-negativity constraints on W.\n", | |
"#" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# graph: struct\n", | |
"# with three fields, eta_g, groups, and groups_var\n", | |
"# \n", | |
"# The first field sets the weights for every group\n", | |
"# graph.eta_g double N vector \n", | |
"\n", | |
"eta_g = np.ones(num_groups)  # every group gets weight 1\n", | |
" \n", | |
"# The next field sets inclusion relations between groups (but not between groups and variables):\n", | |
"# graph.groups sparse (double or boolean) N x N matrix \n", | |
"# the (i,j) entry is non-zero if and only if i is different than j and \n", | |
"# gi is included in gj.\n", | |
"\n", | |
"# No group-in-group relations are declared, so this is an all-zero N x N matrix. Building it\n", | |
"# from a shape tuple avoids allocating a dense array first, and the builtin bool replaces the\n", | |
"# np.bool alias, which was deprecated in NumPy 1.20 and removed in 1.24.\n", | |
"groups = scipy.sparse.csc_matrix((num_groups, num_groups), dtype=bool)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# (group position, member) pairs for every group, unzipped into two parallel tuples:\n", | |
"# i holds the group positions, j the corresponding members.\n", | |
"i, j = zip(*((group, member) for group, members in enumerate(neighborhoods) for member in members))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The next field sets inclusion relations between groups and variables\n", | |
"# graph.groups_var sparse (double or boolean) p x N matrix\n", | |
"# the (i,j) entry is non-zero if and only if the variable i is included \n", | |
"# in gj, but not in any children of gj.\n", | |
"\n", | |
"# scipy.sparse.csc_matrix((data, (row_ind, col_ind)), [shape=(M, N)])\n", | |
"# where data, row_ind and col_ind satisfy the relationship a[row_ind[k], col_ind[k]] = data[k].\n", | |
"\n", | |
"# In this notebook the tuple i holds group positions and j holds member-variable indices, so\n", | |
"# the rows (variables) must come from j and the columns (groups) from i to obtain the\n", | |
"# documented p x N layout. Passing (i, j) built the transpose, which went unnoticed only\n", | |
"# because one-hop neighbourhoods of an undirected graph give a symmetric matrix.\n", | |
"# The shape is given explicitly instead of being inferred from the largest index, and the\n", | |
"# builtin bool replaces the np.bool alias that was removed in NumPy 1.24.\n", | |
"groups_var = scipy.sparse.csc_matrix(\n", | |
"    (np.ones(len(i)), (j, i)),\n", | |
"    shape=(data.shape[1], num_groups),\n", | |
"    dtype=bool,\n", | |
")\n", | |
"\n", | |
"# graph: struct\n", | |
"# with three fields, eta_g, groups, and groups_var\n", | |
"# \n", | |
"graph = {'eta_g':eta_g,'groups':groups,'groups_var':groups_var}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | |
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Inspect the per-group weight vector stored in the graph struct.\n", | |
"graph['eta_g']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<30x30 sparse matrix of type '<class 'numpy.bool_'>'\n", | |
"\twith 0 stored elements in Compressed Sparse Column format>" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Inspect the group-to-group inclusion matrix stored in the graph struct.\n", | |
"graph['groups']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<30x30 sparse matrix of type '<class 'numpy.bool_'>'\n", | |
"\twith 152 stored elements in Compressed Sparse Column format>" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Inspect the variable-to-group membership matrix stored in the graph struct.\n", | |
"graph['groups_var']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Usage: spams.fistaGraph( Y,\n", | |
"# X,\n", | |
"# W0,\n", | |
"# graph,\n", | |
"# return_optim_info=False,\n", | |
"# numThreads=-1,\n", | |
"# max_it=1000,\n", | |
"# L0=1.0,\n", | |
"# fixed_step=False,\n", | |
"# gamma=1.5,\n", | |
"# lambda1=1.0,\n", | |
"# lambda2=0.,\n", | |
"# lambda3=0.,\n", | |
"# a=1.0,\n", | |
"# b=0.,\n", | |
"# tol=0.000001,\n", | |
"# it0=100,\n", | |
"# compute_gram=False,\n", | |
"# intercept=False,\n", | |
"# regul=\"\",\n", | |
"# loss=\"\",\n", | |
"# verbose=False,\n", | |
"# pos=False,\n", | |
"# ista=False,\n", | |
"# subgrad=False,\n", | |
"# linesearch_mode=0) \n", | |
"#\n", | |
"# Inputs:\n", | |
"# Y : double dense m x n matrix\n", | |
"\n", | |
"# The label vector y is turned into a single column (m x 1), made Fortran-ordered, cast to\n", | |
"# float, and then passed through spams.normalize.\n", | |
"# NOTE(review): the Fortran ordering is presumably required by the SPAMS bindings - confirm.\n", | |
"Y = np.asfortranarray(np.expand_dims(y, axis=1)).astype(float)\n", | |
"Y = spams.normalize(Y)\n", | |
"\n", | |
"# X : double dense or sparse m x p matrix\n", | |
"\n", | |
"# Feature table as a Fortran-ordered float array, also passed through spams.normalize.\n", | |
"X = np.asfortranarray(data.values).astype(float)\n", | |
"X = spams.normalize(X)\n", | |
"\n", | |
"# W0 : double dense p x n matrix or p x Nn matrix for multi-logistic loss initial guess\n", | |
"\n", | |
"# All-zero initial guess with one row per column of X and one column per column of Y.\n", | |
"W0 = np.zeros((X.shape[1],Y.shape[1]),dtype=np.float64,order=\"F\")\n", | |
"\n", | |
"# graph : struct see documentation of proximalGraph\n", | |
"# return_optim_info : if true the function will return a tuple of matrices.\n", | |
"# loss : choice of loss, see above\n", | |
"# regul : choice of regularization, see below\n", | |
"# lambda1 : regularization parameter\n", | |
"# lambda2 : regularization parameter, 0 by default\n", | |
"# lambda3 : regularization parameter, 0 by default\n", | |
"# verbose : verbosity level, false by default\n", | |
"# pos : adds positivity constraints on the coefficients, false by default\n", | |
"# numThreads : number of threads for exploiting multi-core / multi-cpus. By default, it takes the value -1, which automatically selects all the available CPUs/cores.\n", | |
"# max_it : maximum number of iterations (this text says 100 by default, but the usage signature above shows max_it=1000)\n", | |
"# it0 : frequency for computing duality gap (this text says every 10 iterations by default, but the usage signature above shows it0=100)\n", | |
"# tol : tolerance for stopping criterion, which is a relative duality gap if it is available, or a relative change of parameters.\n", | |
"# gamma : multiplier for increasing the parameter L in fista, 1.5 by default\n", | |
"# L0 : initial parameter L in fista, should be small enough (this text says 0.1 by default, but the usage signature above shows L0=1.0)\n", | |
"# fixed_step : deactivate the line search for L in fista and use L0 instead\n", | |
"# compute_gram : pre-compute X^TX, false by default.\n", | |
"# intercept : do not regularize last row of W, false by default.\n", | |
"# ista : use ista instead of fista, false by default.\n", | |
"# subgrad : if not ista, use subgradient descent instead of fista, false by default.\n", | |
"# a :\n", | |
"# b : if subgrad, the gradient step is a/(t+b) also similar options as proximalTree\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Regularizers: \n", | |
"# Given an input matrix U=[u^1,\\ldots,u^n], it computes a matrix V=[v^1,\\ldots,v^n] such that\n", | |
"#\n", | |
"# if one chooses a regularization functions on vectors, it computes for each column u of U, a column v of V solving\n", | |
"# if regul='l0' argmin 0.5||u-v||_2^2 + lambda1||v||_0\n", | |
"# if regul='l1' argmin 0.5||u-v||_2^2 + lambda1||v||_1\n", | |
"# if regul='l2' argmin 0.5||u-v||_2^2 + 0.5lambda1||v||_2^2\n", | |
"# if regul='elastic-net' argmin 0.5||u-v||_2^2 + lambda1||v||_1 + lambda1_2||v||_2^2\n", | |
"# if regul='fused-lasso' argmin 0.5||u-v||_2^2 + lambda1 FL(v) + lambda1_2||v||_1 + lambda1_3||v||_2^2\n", | |
"# if regul='linf' argmin 0.5||u-v||_2^2 + lambda1||v||_inf\n", | |
"# if regul='l1-constraint' argmin 0.5||u-v||_2^2 s.t. ||v||_1 <= lambda1\n", | |
"# if regul='l2-not-squared' argmin 0.5||u-v||_2^2 + lambda1||v||_2\n", | |
"# if regul='group-lasso-l2' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_2 where the groups are either defined by groups or by size_group,\n", | |
"# if regul='group-lasso-linf' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_inf\n", | |
"# if regul='sparse-group-lasso-l2' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_2 + lambda1_2 ||v||_1 where the groups are either defined by groups or by size_group,\n", | |
"# if regul='sparse-group-lasso-linf' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_inf + lambda1_2 ||v||_1\n", | |
"# if regul='trace-norm-vec' argmin 0.5||u-v||_2^2 + lambda1 ||mat(v)||_* where mat(v) has size_group rows\n", | |
"#\n", | |
"# if regul='graph' argmin 0.5||u-v||_2^2 + lambda1\\sum_{g \\in G} \\eta_g||v_g||_inf\n", | |
"# if regul='graph+ridge' argmin 0.5||u-v||_2^2 + lambda1\\sum_{g \\in G} \\eta_g||v_g||_inf + lambda1_2||v||_2^2\n", | |
"#\n", | |
"# if one chooses a regularization function on matrices\n", | |
"# if regul='l1l2', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/2}\n", | |
"# if regul='l1linf', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf}\n", | |
"# if regul='l1l2+l1', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/2} + lambda1_2||V||_{1/1}\n", | |
"# if regul='l1linf+l1', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf} + lambda1_2||V||_{1/1}\n", | |
"# if regul='l1linf+row-column', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf} + lambda1_2||V'||_{1/inf}\n", | |
"# if regul='trace-norm', V= argmin 0.5||U-V||_F^2 + lambda1||V||_*\n", | |
"# if regul='rank', V= argmin 0.5||U-V||_F^2 + lambda1 rank(V)\n", | |
"# if regul='none', V= argmin 0.5||U-V||_F^2\n", | |
"#\n", | |
"# if regul='multi-task-graph' V=argmin 0.5||U-V||_F^2 + lambda1 \\sum_{i=1}^n\\sum_{g \\in G} \\eta_g||v^i_g||_inf + lambda1_2 \\sum_{g \\in G} \\eta_g max_{j in g}||V_j||_{inf}\n", | |
"#\n", | |
"# for all these regularizations, it is possible to enforce non-negativity constraints\n", | |
"# with the option pos, and to prevent the last row of U to be regularized, with\n", | |
"# the option intercept\n", | |
"\n", | |
"# Note:\n", | |
"# Valid values for the regularization parameter (regul) for fistaGraph (beyond those listed above) are:\n", | |
"# \"tree-l0\"\n", | |
"# \"tree-l2\"\n", | |
"# \"tree-linf\"\n", | |
"# \"graph-l2\",\n", | |
"# \"multi-task-tree\"\n", | |
"# \"rank-vec\"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Loss: \n", | |
"# - if loss='square' and regul is a regularization function for vectors,\n", | |
"# the entries of Y are real-valued, W = [w^1,...,w^n] is a matrix of size p x n\n", | |
"# For all column y of Y, it computes a column w of W such that\n", | |
"# w = argmin 0.5||y- X w||_2^2 + lambda1 psi(w)\n", | |
"# \n", | |
"# - if loss='square' and regul is a regularization function for matrices\n", | |
"# the entries of Y are real-valued, W is a matrix of size p x n. \n", | |
"# It computes the matrix W such that\n", | |
"# W = argmin 0.5||Y- X W||_F^2 + lambda1 psi(W)\n", | |
"# \n", | |
"# - loss='square-missing' same as loss='square', but handles missing data\n", | |
"# represented by NaN (not a number) in the matrix Y\n", | |
"# \n", | |
"# - if loss='logistic' and regul is a regularization function for vectors,\n", | |
"# the entries of Y are either -1 or +1, W = [w^1,...,w^n] is a matrix of size p x n\n", | |
"# For all column y of Y, it computes a column w of W such that\n", | |
"# w = argmin (1/m)sum_{j=1}^m log(1+e^(-y_j x^j' w)) + lambda1 psi(w),\n", | |
"# where x^j is the j-th row of X.\n", | |
"# \n", | |
"# - if loss='logistic' and regul is a regularization function for matrices\n", | |
"# the entries of Y are either -1 or +1, W is a matrix of size p x n\n", | |
"# W = argmin sum_{i=1}^n(1/m)sum_{j=1}^m log(1+e^(-y^i_j x^j' w^i)) + lambda1 psi(W)\n", | |
"# \n", | |
"# - if loss='multi-logistic' and regul is a regularization function for vectors,\n", | |
"# the entries of Y are in {0,1,...,N} where N is the total number of classes\n", | |
"# W = [W^1,...,W^n] is a matrix of size p x Nn, each submatrix W^i is of size p x N\n", | |
"# for all submatrix WW of W, and column y of Y, it computes\n", | |
"# WW = argmin (1/m)sum_{j=1}^m log(sum_{j=1}^r e^(x^j'(ww^j-ww^{y_j}))) + lambda1 sum_{j=1}^N psi(ww^j),\n", | |
"# where ww^j is the j-th column of WW.\n", | |
"# \n", | |
"# - if loss='multi-logistic' and regul is a regularization function for matrices,\n", | |
"# the entries of Y are in {0,1,...,N} where N is the total number of classes\n", | |
"# W is a matrix of size p x N, it computes\n", | |
"# W = argmin (1/m)sum_{j=1}^m log(sum_{j=1}^r e^(x^j'(w^j-w^{y_j}))) + lambda1 psi(W)\n", | |
"# where ww^j is the j-th column of WW.\n", | |
"# \n", | |
"# - loss='cur' useful to perform sparse CUR matrix decompositions, \n", | |
"# W = argmin 0.5||Y-X*W*X||_F^2 + lambda1 psi(W)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean loss: 0.094094, mean relative duality_gap: -1.000000, time: 0.002162, number of iterations: 482.000000\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline run: square loss without any penalty (regul='none').\n", | |
"verbose = True\n", | |
"lambda1 = 0.0 # regularization term (no regularization)\n", | |
"max_it = 100 # maximum number of iterations\n", | |
"L0 = 0.1\n", | |
"tol = 1e-5\n", | |
"intercept = False\n", | |
"pos = False\n", | |
"compute_gram = True\n", | |
"\n", | |
"loss = 'square'\n", | |
"regul = 'none'\n", | |
"tic = time.time()\n", | |
"\n", | |
"# Forward every setting declared above. They used to be assigned but never passed, so the\n", | |
"# solver silently ran with the defaults from its signature (lambda1=1.0, max_it=1000, L0=1.0,\n", | |
"# tol=1e-6, compute_gram=False) - which is how a run with max_it = 100 could report 482\n", | |
"# iterations.\n", | |
"(W, optim_info) = spams.fistaGraph(\n", | |
"    Y, X, W0, graph,\n", | |
"    return_optim_info=True,\n", | |
"    loss=loss, regul=regul, lambda1=lambda1,\n", | |
"    max_it=max_it, L0=L0, tol=tol,\n", | |
"    intercept=intercept, pos=pos, compute_gram=compute_gram,\n", | |
"    verbose=verbose,\n", | |
")\n", | |
"\n", | |
"tac = time.time()\n", | |
"t = tac - tic\n", | |
"\n", | |
"# optim_info rows used here: 0 = objective value, 2 = relative duality gap, 3 = iterations.\n", | |
"print('mean loss: %f, mean relative duality_gap: %f, time: %f, number of iterations: %f' %(np.mean(optim_info[0,:]),np.mean(optim_info[2,:]),t,np.mean(optim_info[3,:])))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mean loss: 0.500000, mean relative duality_gap: 0.000000, time: 0.003627, number of iterations: 100.000000\n" | |
] | |
} | |
], | |
"source": [ | |
"# Run with the graph-structured penalty (regul='graph').\n", | |
"verbose = True\n", | |
"# NOTE(review): lambda1 = 0 switches the penalty off even with regul='graph'; raise it to\n", | |
"# actually exercise the graph regularizer.\n", | |
"lambda1 = 0.0 # regularization term (no regularization)\n", | |
"max_it = 100 # maximum number of iterations\n", | |
"L0 = 0.1\n", | |
"tol = 1e-5\n", | |
"intercept = False\n", | |
"pos = False\n", | |
"compute_gram = True\n", | |
"\n", | |
"loss = 'square'\n", | |
"regul = 'graph'\n", | |
"tic = time.time()\n", | |
"\n", | |
"# Forward every setting declared above. They used to be assigned but never passed, so the\n", | |
"# solver silently ran with the defaults from its signature - in particular lambda1=1.0 rather\n", | |
"# than the value set here, which is presumably why that run returned an all-zero W.\n", | |
"(W, optim_info) = spams.fistaGraph(\n", | |
"    Y, X, W0, graph,\n", | |
"    return_optim_info=True,\n", | |
"    loss=loss, regul=regul, lambda1=lambda1,\n", | |
"    max_it=max_it, L0=L0, tol=tol,\n", | |
"    intercept=intercept, pos=pos, compute_gram=compute_gram,\n", | |
"    verbose=verbose,\n", | |
")\n", | |
"\n", | |
"tac = time.time()\n", | |
"t = tac - tic\n", | |
"\n", | |
"# optim_info rows used here: 0 = objective value, 2 = relative duality gap, 3 = iterations.\n", | |
"print('mean loss: %f, mean relative duality_gap: %f, time: %f, number of iterations: %f' %(np.mean(optim_info[0,:]),np.mean(optim_info[2,:]),t,np.mean(optim_info[3,:])))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Output:\n", | |
"# W: double dense p x n matrix or p x Nn matrix (for multi-logistic loss)\n", | |
"# optim: optional, double dense 4 x n matrix.\n", | |
"# first row: values of the objective functions.\n", | |
"# third row: values of the relative duality gap (if available)\n", | |
"# fourth row: number of iterations\n", | |
"# optim_info: vector of size 4, containing information of the optimization.\n", | |
"# W = spams.fistaGraph(Y,X,W0,graph,return_optim_info = False,...)\n", | |
"# (W,optim_info) = spams.fistaGraph(Y,X,W0,graph,return_optim_info = True,...)\n", | |
"#" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0, 2, 11, 12, 14]" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Re-display the planted signal nodes for comparison with the fitted coefficients.\n", | |
"signal_nodes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" <th>...</th>\n", | |
" <th>20</th>\n", | |
" <th>21</th>\n", | |
" <th>22</th>\n", | |
" <th>23</th>\n", | |
" <th>24</th>\n", | |
" <th>25</th>\n", | |
" <th>26</th>\n", | |
" <th>27</th>\n", | |
" <th>28</th>\n", | |
" <th>29</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>...</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>-0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>1 rows × 30 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 7 8 9 ... 20 21 22 23 \\\n", | |
"0 0.0 -0.0 0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 0.0 ... 0.0 -0.0 0.0 0.0 \n", | |
"\n", | |
" 24 25 26 27 28 29 \n", | |
"0 -0.0 -0.0 0.0 0.0 0.0 0.0 \n", | |
"\n", | |
"[1 rows x 30 columns]" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Wrap the fitted coefficients in a DataFrame and display them transposed (one wide row).\n", | |
"W = pd.DataFrame(data=W)\n", | |
"W.transpose()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0 False\n", | |
"dtype: bool" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Boolean-mask indexing turns the zero entries into NaN, which .any() skips, so this reports\n", | |
"# for each column of W whether it holds any non-zero coefficient.\n", | |
"W[W != 0].any()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment