Last active
March 7, 2020 07:54
-
-
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"lines_to_next_cell": 2 | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from regreg.smooth.cox import cox_loglike\n", | |
"import regreg.api as rr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"ncase, nfeature, ndisease = 1000, 5000, 20\n", | |
"losses = []\n", | |
"for _ in range(ndisease):\n", | |
" times = np.random.exponential(size=(ncase,))\n", | |
" censoring = np.array([0]*int(0.3*ncase) + [1]*int(0.7*ncase))\n", | |
" np.random.shuffle(censoring)\n", | |
" losses.append(cox_loglike((ncase,),\n", | |
" times,\n", | |
" censoring))\n", | |
"X = np.random.standard_normal((ncase, nfeature))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class cox_stacked(rr.smooth_atom): \n", | |
" \n", | |
" def __init__(self,\n", | |
" losses,\n", | |
" X,\n", | |
" quadratic=None, \n", | |
" initial=None,\n", | |
" offset=None):\n", | |
" \n", | |
" self.losses = losses\n", | |
" self.ndisease = len(losses)\n", | |
" self.ncase, self.nfeature = X.shape\n", | |
"\n", | |
" self.X, self.X_T = X, X.T\n", | |
"\n", | |
" assert(np.all(np.array([loss.shape[0] for loss in losses]) == self.ncase))\n", | |
" \n", | |
" rr.smooth_atom.__init__(self,\n", | |
" (self.nfeature, self.ndisease),\n", | |
" offset=offset,\n", | |
" quadratic=quadratic,\n", | |
" initial=initial)\n", | |
" self._gradient = np.zeros((self.ndisease, self.ncase))\n", | |
"\n", | |
" def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
"\n", | |
" arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
" linpred = arg.T.dot(self.X_T) # (ndisease, ncase)\n", | |
" if mode == 'grad':\n", | |
" for d in range(self.ndisease):\n", | |
" self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n", | |
" return self.scale(self._gradient.dot(self.X).T)\n", | |
" elif mode == 'func':\n", | |
" value = 0\n", | |
" for d in range(self.ndisease):\n", | |
" value += self.losses[d].smooth_objective(linpred[d], 'func')\n", | |
" return self.scale(value)\n", | |
" elif mode == 'both':\n", | |
" value = 0\n", | |
" for d in range(self.ndisease):\n", | |
" f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n", | |
" self._gradient[d] = g\n", | |
" value += f\n", | |
" return self.scale(value), self.scale(self._gradient.dot(self.X).T)\n", | |
" else:\n", | |
" raise ValueError(\"mode incorrectly specified\")\n", | |
"\n", | |
"loss = cox_stacked(losses, X)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Check the loss can be computed\n", | |
"\n", | |
"- We'll use `G` to compute $\\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"F, G = loss.smooth_objective(np.zeros((nfeature, ndisease)), 'both')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"25.649429321289062" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"penalty = rr.sparse_group_block(loss.shape, l1_weight=1., l2_weight=np.sqrt(ndisease), lagrange=1.)\n", | |
"dual = penalty.conjugate\n", | |
"lambda_max = dual.seminorm(G, lagrange=1)\n", | |
"penalty.lagrange = lambda_max\n", | |
"lambda_max" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Timing at $\\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.67 s ± 392 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.99987" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)\n", | |
"np.mean(soln == 0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Timing at $0.9 \\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.48 s ± 360 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"penalty.lagrange = 0.9 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(5, 52)" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"penalty.lagrange = 0.9 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)\n", | |
"np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Timing at $0.8 \\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.14 s ± 234 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"penalty.lagrange = 0.8 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(95, 1032, 20.519546508789062, 20.51954345703125)" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"penalty.lagrange = 0.8 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100) \n", | |
"np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.8 * lambda_max" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Timing at $0.7 \\lambda_{\\max}$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4.04 s ± 148 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(475, 5351, 17.954605102539062, 17.954600524902343)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"problem.coefs[:] = 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100) \n", | |
"np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0), dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.), 0.7 * lambda_max" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Comparison with sparse group LASSO (not block)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"groups = np.multiply.outer(np.arange(nfeature), np.ones(ndisease)).reshape(-1)\n", | |
"group_penalty = rr.sparse_group_lasso(groups, np.ones(nfeature * ndisease), lagrange=0.7 * lambda_max)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class cox_stacked_flat(rr.smooth_atom): \n", | |
" \n", | |
" def __init__(self,\n", | |
" losses,\n", | |
" X,\n", | |
" quadratic=None, \n", | |
" initial=None,\n", | |
" offset=None):\n", | |
" \n", | |
" self.losses = losses\n", | |
" self.ndisease = len(losses)\n", | |
" self.ncase, self.nfeature = X.shape\n", | |
"\n", | |
" self.X, self.X_T = X, X.T\n", | |
"\n", | |
" assert(np.all(np.array([loss.shape[0] for loss in losses]) == self.ncase))\n", | |
" \n", | |
" rr.smooth_atom.__init__(self,\n", | |
" (self.nfeature * self.ndisease,),\n", | |
" offset=offset,\n", | |
" quadratic=quadratic,\n", | |
" initial=initial)\n", | |
" self._gradient = np.zeros((self.ndisease, self.ncase))\n", | |
"\n", | |
" def smooth_objective(self, arg, mode='both', check_feasibility=False):\n", | |
"\n", | |
" arg = self.apply_offset(arg) # (nfeature, ndisease)\n", | |
" arg = arg.reshape((self.nfeature, self.ndisease))\n", | |
" linpred = arg.T.dot(self.X_T) # (ndisease, ncase)\n", | |
" if mode == 'grad':\n", | |
" for d in range(self.ndisease):\n", | |
" self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n", | |
" return self.scale(self._gradient.dot(self.X).T.reshape(-1))\n", | |
" elif mode == 'func':\n", | |
" value = 0\n", | |
" for d in range(self.ndisease):\n", | |
" value += self.losses[d].smooth_objective(linpred[d], 'func')\n", | |
" return self.scale(value)\n", | |
" elif mode == 'both':\n", | |
" value = 0\n", | |
" for d in range(self.ndisease):\n", | |
" f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n", | |
" self._gradient[d] = g\n", | |
" value += f\n", | |
" return self.scale(value), self.scale(self._gradient.dot(self.X).T.reshape(-1))\n", | |
" else:\n", | |
" raise ValueError(\"mode incorrectly specified\")\n", | |
"\n", | |
"loss_flat = cox_stacked_flat(losses, X)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5.27 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Maybe we haven't solved enough -- let's compare starting at a random point\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = 0\n", | |
"soln_flat = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = np.random.standard_normal(group_penalty.shape) * 0.1\n", | |
"soln_flat_r = problem.solve(tol=1.e-9, min_its=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.0022612382977817364" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.linalg.norm(soln_flat - soln_flat_r) / np.linalg.norm(soln_flat)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Let's up the number of iterations a bit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = 0\n", | |
"soln_flat = problem.solve(tol=1.e-9, min_its=200)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = np.random.standard_normal(group_penalty.shape) * 0.1\n", | |
"soln_flat_r = problem.solve(tol=1.e-9, min_its=200)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6.217488899625311e-14" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.linalg.norm(soln_flat - soln_flat_r) / np.linalg.norm(soln_flat) # now we've essentially found same solution" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Time at 200 iterations\n", | |
"\n", | |
"This should take roughly double the time, but it could be a bit less: the Lipschitz constant (inverse step size) may have settled down, so there is less need for backtracking." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10.1 s ± 265 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"group_penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss_flat, group_penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=200)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.46 s ± 55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"penalty.lagrange = 0.7 * lambda_max\n", | |
"problem = rr.simple_problem(loss, penalty)\n", | |
"problem.coefs[:] = 0 # start at 0\n", | |
"soln = problem.solve(tol=1.e-9, min_its=200)" | |
] | |
} | |
], | |
"metadata": { | |
"jupytext": { | |
"cell_metadata_filter": "all,-slideshow", | |
"formats": "ipynb,Rmd" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment