Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"lines_to_next_cell": 2 | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(['X_list', 'y_list', 'censor_list'], dtype='<U11')" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"from regreg.smooth.cox import cox_loglike\n", | |
"import regreg.api as rr\n", | |
"import regreg.affine as ra\n", | |
"%load_ext rpy2.ipython\n", | |
"%R load('instance.RData')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def load_data(idx):\n", | |
" %R -i idx -o X X = X_list[[idx]]\n", | |
" %R -o Y Y = y_list[[idx]]\n", | |
" %R -o C C = censor_list[[idx]]\n", | |
" return X, Y, C\n", | |
"datasets = [load_data(idx) for idx in range(1, 21)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"losses = [rr.cox_loglike(Y.shape[0], Y.reshape(-1), C.reshape(-1), coef=1./Y.shape[0]) for _, Y, C in datasets]\n", | |
"Xblock = ra.block_transform([X for X, _, _ in datasets])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((5000, 20), (22100,))" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Xblock.input_shape, Xblock.output_shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(22100,)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"class cox_stacked(rr.smooth_atom):\n", | |
"\n", | |
" def __init__(self,\n", | |
" losses,\n", | |
" X,\n", | |
" quadratic=None, \n", | |
" initial=None,\n", | |
" offset=None):\n", | |
" \n", | |
" self.losses = losses\n", | |
" self.ndisease = len(losses)\n", | |
" self.nfeature = X.shape[0]\n", | |
"\n", | |
" self.X, self.X_T = X, X.T\n", | |
" \n", | |
" rr.smooth_atom.__init__(self,\n", | |
" self.X.output_shape,\n", | |
" offset=offset,\n", | |
" quadratic=quadratic,\n", | |
" initial=initial)\n", | |
" self._gradient = np.zeros(X.output_shape)\n", | |
"\n", | |
" def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
"\n", | |
" arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
" linpred = self.X.dot(arg) # (ndisease, ncase)\n", | |
" if mode == 'grad':\n", | |
" for d, slice in enumerate(self.X._slices):\n", | |
" self._gradient[slice] = self.losses[d].smooth_objective(linpred[slice], 'grad')\n", | |
" return self.scale(self.X_T.dot(self._gradient))\n", | |
" elif mode == 'func':\n", | |
" value = 0\n", | |
" for d, slice in enumerate(self.X._slices):\n", | |
" value += self.losses[d].smooth_objective(linpred[slice], 'func')\n", | |
" return self.scale(value)\n", | |
" elif mode == 'both':\n", | |
" value = 0\n", | |
" for d, slice in enumerate(self.X._slices):\n", | |
" f, g = self.losses[d].smooth_objective(linpred[slice], 'both')\n", | |
" self._gradient[slice] = g\n", | |
" value += f\n", | |
" return self.scale(value), self.scale(self.X_T.dot(self._gradient))\n", | |
" else:\n", | |
" raise ValueError(\"mode incorrectly specified\")\n", | |
"\n", | |
"loss = cox_stacked(losses, Xblock)\n", | |
"loss.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Check that the loss can be computed\n", | |
"\n", | |
"- We'll use the gradient `G` of the loss at zero to compute $\\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(5000, 20)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"F, G = loss.smooth_objective(np.zeros(Xblock.input_shape), 'both')\n", | |
"G.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.009502477943897247" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"nfeature = Xblock.input_shape[0]\n", | |
"alpha = 0.95\n", | |
"penalty = rr.sparse_group_block(Xblock.input_shape, l1_weight=alpha, \n", | |
" l2_weight=(1-alpha)*np.sqrt(nfeature), lagrange=1.)\n", | |
"dual = penalty.conjugate\n", | |
"lambda_max = dual.seminorm(G, lagrange=1)\n", | |
"penalty.lagrange = lambda_max\n", | |
"lambda_max" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9)\n", | |
"np.mean(soln == 0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## First 10 values of a 100-point log-spaced grid from $\\lambda_{\\max}$ down to $0.01\\,\\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.00950248, 0.00907058, 0.0086583 , 0.00826477, 0.00788912,\n", | |
" 0.00753055, 0.00718828, 0.00686156, 0.00654969, 0.006252 ])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lagrange_vals = np.exp(np.linspace(0, np.log(0.01), 100))[:10] * lambda_max\n", | |
"lagrange_vals" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Timing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 0 0.009502477943897247 0.009502477943897247\n", | |
"5 33 0.00907132774591446 0.009070575655810235\n", | |
"17 119 0.008658461272716522 0.008658303993288064\n", | |
"34 248 0.008264981210231781 0.008264770714102115\n", | |
"76 587 0.007889382541179657 0.0078891241298101\n", | |
"149 1230 0.007531158626079559 0.0075305512625238645\n", | |
"235 2018 0.0071884579956531525 0.007188276085454979\n", | |
"362 3209 0.0068618617951869965 0.0068615578434302205\n", | |
"504 4537 0.006549973040819168 0.00654968944974222\n", | |
"710 6535 0.006252247840166092 0.006251995955865733\n", | |
"time: 43.948302\n" | |
] | |
} | |
], | |
"source": [ | |
"from time import time\n", | |
"toc = time()\n", | |
"solns = []\n", | |
"problem.coefs[:] = 0\n", | |
"for lagrange in lagrange_vals:\n", | |
" penalty.lagrange = lagrange\n", | |
" soln = problem.solve(tol=1.e-12)\n", | |
" solns.append(soln.copy())\n", | |
" print(np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), lagrange)\n", | |
"tic = time()\n", | |
"print('time: %f' % (tic-toc))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## What if $\\alpha=1$ (pure $\\ell_1$ penalty, zero group weight)?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.040050357580184937" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"nfeature = Xblock.input_shape[0]\n", | |
"alpha = 1\n", | |
"penalty = rr.sparse_group_block(Xblock.input_shape, l1_weight=alpha, \n", | |
" l2_weight=(1-alpha)*np.sqrt(nfeature), lagrange=1.)\n", | |
"dual = penalty.conjugate\n", | |
"lambda_max = dual.seminorm(G, lagrange=1)\n", | |
"penalty.lagrange = lambda_max\n", | |
"lambda_max" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9)\n", | |
"np.mean(soln == 0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.04005036, 0.03823001, 0.03649239, 0.03483376, 0.03325051,\n", | |
" 0.03173922, 0.03029663, 0.0289196 , 0.02760516, 0.02635046])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lagrange_vals = np.exp(np.linspace(0, np.log(0.01), 100))[:10] * lambda_max\n", | |
"lagrange_vals" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 0 0.040050357580184937 0.040050357580184937\n", | |
"3 3 0.03822973370552063 0.03823000701692012\n", | |
"5 5 0.03649678826332092 0.03649239419617218\n", | |
"11 11 0.034835249185562134 0.03483375855985143\n", | |
"15 15 0.033251434564590454 0.033250510473037134\n", | |
"25 26 0.03173968195915222 0.03173922345525575\n", | |
"57 58 0.03029760718345642 0.03029662676485945\n", | |
"101 103 0.028920933604240417 0.028919598320456204\n", | |
"163 165 0.02760659158229828 0.02760515794407164\n", | |
"237 241 0.02635185420513153 0.026350460911419748\n", | |
"time: 30.266288\n" | |
] | |
} | |
], | |
"source": [ | |
"from time import time\n", | |
"toc = time()\n", | |
"solns = []\n", | |
"problem.coefs[:] = 0\n", | |
"for lagrange in lagrange_vals:\n", | |
" penalty.lagrange = lagrange\n", | |
" soln = problem.solve(tol=1.e-12)\n", | |
" solns.append(soln.copy())\n", | |
" print(np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), lagrange)\n", | |
"tic = time()\n", | |
"print('time: %f' % (tic-toc))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"jupytext": { | |
"cell_metadata_filter": "all,-slideshow", | |
"formats": "ipynb,Rmd" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment