Skip to content

Instantly share code, notes, and snippets.

@jonathan-taylor
Last active March 7, 2020 07:54
Show Gist options
  • Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Save jonathan-taylor/a4311d4c0f662c4e97f99475f389ef90 to your computer and use it in GitHub Desktop.
Comparison to regular sparse group LASSO
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"import numpy as np\n",
"from regreg.smooth.cox import cox_loglike\n",
"import regreg.api as rr"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Simulate `ndisease` independent Cox survival outcomes on a shared Gaussian design `X`.\n",
"np.random.seed(0)  # fix the global seed so Restart & Run All is reproducible\n",
"ncase, nfeature, ndisease = 1000, 5000, 20\n",
"zero_fraction = 0.3  # fraction of zeros in each censoring-indicator vector\n",
"losses = []\n",
"for _ in range(ndisease):\n",
"    times = np.random.exponential(size=(ncase,))\n",
"    # Size the two parts so they always sum to `ncase`: int(0.3 * n) + int(0.7 * n)\n",
"    # can be smaller than n (e.g. n = 3 gives 0 + 2), which would mismatch `times`.\n",
"    n_zero = int(zero_fraction * ncase)\n",
"    censoring = np.array([0] * n_zero + [1] * (ncase - n_zero))\n",
"    np.random.shuffle(censoring)\n",
"    losses.append(cox_loglike((ncase,),\n",
"                              times,\n",
"                              censoring))\n",
"X = np.random.standard_normal((ncase, nfeature))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class cox_stacked(rr.smooth_atom):\n",
"    \"\"\"\n",
"    Sum of `ndisease` Cox losses that share one design matrix `X`.\n",
"\n",
"    The coefficients form a matrix of shape (nfeature, ndisease); column d\n",
"    gives the linear predictor X.dot(coef[:, d]) passed to `losses[d]`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 losses,\n",
"                 X,\n",
"                 quadratic=None,\n",
"                 initial=None,\n",
"                 offset=None):\n",
"\n",
"        self.losses = losses\n",
"        self.ndisease = len(losses)\n",
"        self.ncase, self.nfeature = X.shape\n",
"\n",
"        # keep X and its transpose: X_T forms the linear predictors, X the gradient\n",
"        self.X, self.X_T = X, X.T\n",
"\n",
"        # every loss must be defined on the same `ncase` rows as X\n",
"        loss_rows = np.array([loss.shape[0] for loss in losses])\n",
"        assert np.all(loss_rows == self.ncase), 'every loss needs shape[0] == X.shape[0]'\n",
"\n",
"        rr.smooth_atom.__init__(self,\n",
"                                (self.nfeature, self.ndisease),\n",
"                                offset=offset,\n",
"                                quadratic=quadratic,\n",
"                                initial=initial)\n",
"        # scratch buffer: row d receives the gradient of losses[d] w.r.t. its linear predictor\n",
"        self._gradient = np.zeros((self.ndisease, self.ncase))\n",
"\n",
"    def smooth_objective(self, arg, mode='both', check_feasibility=False):\n",
"        \"\"\"\n",
"        Evaluate the summed loss ('func'), its gradient ('grad') or both ('both').\n",
"\n",
"        `arg` has shape (nfeature, ndisease); the gradient is returned with the\n",
"        same shape. `check_feasibility` is accepted for API compatibility and unused.\n",
"        Raises ValueError for any other `mode`.\n",
"        \"\"\"\n",
"        arg = self.apply_offset(arg)   # (nfeature, ndisease)\n",
"        linpred = arg.T.dot(self.X_T)  # (ndisease, ncase)\n",
"        if mode == 'grad':\n",
"            for d in range(self.ndisease):\n",
"                self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n",
"            # chain rule: d/d coef[j, d] = sum_i grad[d, i] * X[i, j]\n",
"            return self.scale(self._gradient.dot(self.X).T)\n",
"        elif mode == 'func':\n",
"            value = 0\n",
"            for d in range(self.ndisease):\n",
"                value += self.losses[d].smooth_objective(linpred[d], 'func')\n",
"            return self.scale(value)\n",
"        elif mode == 'both':\n",
"            value = 0\n",
"            for d in range(self.ndisease):\n",
"                f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n",
"                self._gradient[d] = g\n",
"                value += f\n",
"            return self.scale(value), self.scale(self._gradient.dot(self.X).T)\n",
"        else:\n",
"            raise ValueError(\"mode incorrectly specified\")\n",
"\n",
"\n",
"loss = cox_stacked(losses, X)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check the loss can be computed\n",
"\n",
"- We'll use the gradient `G` at zero to compute $\\lambda_{\\max}$, the smallest penalty level at which the all-zero solution is optimal."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# value F and gradient G of the stacked loss at the all-zero coefficient matrix\n",
"zero_coefs = np.zeros((nfeature, ndisease))\n",
"F, G = loss.smooth_objective(zero_coefs, mode='both')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"25.649429321289062"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sparse group block penalty: l1 weight 1, l2 weight sqrt(ndisease).\n",
"# lambda_max is the dual seminorm of the gradient at zero, i.e. the smallest\n",
"# lagrange value at which the all-zero matrix satisfies the KKT conditions.\n",
"penalty = rr.sparse_group_block(loss.shape,\n",
"                                l1_weight=1.,\n",
"                                l2_weight=np.sqrt(ndisease),\n",
"                                lagrange=1.)\n",
"dual = penalty.conjugate\n",
"lambda_max = dual.seminorm(G, lagrange=1)\n",
"penalty.lagrange = lambda_max\n",
"lambda_max"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Timing at $\\lambda_{\\max}$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.67 s ± 392 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# Set the lagrange value and the zero start explicitly, matching the other\n",
"# timing cells, so the timing does not depend on which cells ran before.\n",
"penalty.lagrange = lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0 # start at 0\n",
"soln = problem.solve(tol=1.e-9, min_its=100)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.99987"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"penalty.lagrange = lambda_max  # set explicitly so re-running this cell is order-independent\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0 # start at 0\n",
"soln = problem.solve(tol=1.e-9, min_its=100)\n",
"np.mean(soln == 0)  # fraction of exactly-zero coefficients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Timing at $0.9 \\lambda_{\\max}$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.48 s ± 360 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# each timed run solves at 0.9 * lambda_max starting from all-zero coefficients\n",
"penalty.lagrange = 0.9 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5, 52)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"penalty.lagrange = 0.9 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0  # start at 0, as in the timing cell above\n",
"soln = problem.solve(tol=1.e-9, min_its=100)\n",
"# (number of features with a nonzero row, number of nonzero coefficients)\n",
"np.sum(np.sum(soln**2, 1) > 0), np.sum(soln != 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Timing at $0.8 \\lambda_{\\max}$"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.14 s ± 234 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# each timed run solves at 0.8 * lambda_max starting from all-zero coefficients\n",
"penalty.lagrange = 0.8 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(95, 1032, 20.519546508789062, 20.51954345703125)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"penalty.lagrange = 0.8 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0  # start at 0, as in the timing cell above\n",
"soln = problem.solve(tol=1.e-9, min_its=100)\n",
"n_active_rows = np.sum(np.sum(soln**2, 1) > 0)\n",
"n_nonzero = np.sum(soln != 0)\n",
"# KKT check: at a nonzero solution the dual seminorm of the gradient equals the lagrange value\n",
"grad_dual_norm = dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.)\n",
"n_active_rows, n_nonzero, grad_dual_norm, 0.8 * lambda_max"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Timing at $0.7 \\lambda_{\\max}$"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.04 s ± 148 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# each timed run solves at 0.7 * lambda_max starting from all-zero coefficients\n",
"penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(475, 5351, 17.954605102539062, 17.954600524902343)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=100, tol=1e-9)\n",
"n_active_rows = np.sum(np.sum(soln**2, 1) > 0)\n",
"n_nonzero = np.sum(soln != 0)\n",
"# KKT check: at a nonzero solution the dual seminorm of the gradient equals the lagrange value\n",
"grad_dual_norm = dual.seminorm(loss.smooth_objective(soln, 'grad'), lagrange=1.)\n",
"n_active_rows, n_nonzero, grad_dual_norm, 0.7 * lambda_max"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparison with sparse group LASSO (not block)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# One group per feature: the `ndisease` consecutive entries of the flattened\n",
"# (nfeature, ndisease) coefficient matrix share label j for feature j.\n",
"groups = np.repeat(np.arange(nfeature, dtype=float), ndisease)\n",
"group_penalty = rr.sparse_group_lasso(groups,\n",
"                                      np.ones(nfeature * ndisease),\n",
"                                      lagrange=0.7 * lambda_max)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class cox_stacked_flat(rr.smooth_atom):\n",
"    \"\"\"\n",
"    The stacked Cox loss of `cox_stacked`, parametrized by a flat coefficient vector.\n",
"\n",
"    The vector has length nfeature * ndisease and is read as a row-major\n",
"    (nfeature, ndisease) matrix, so the `ndisease` consecutive entries starting\n",
"    at j * ndisease belong to feature j. The gradient is flattened the same way,\n",
"    matching per-feature `groups` used with `rr.sparse_group_lasso`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 losses,\n",
"                 X,\n",
"                 quadratic=None,\n",
"                 initial=None,\n",
"                 offset=None):\n",
"\n",
"        self.losses = losses\n",
"        self.ndisease = len(losses)\n",
"        self.ncase, self.nfeature = X.shape\n",
"\n",
"        # keep X and its transpose: X_T forms the linear predictors, X the gradient\n",
"        self.X, self.X_T = X, X.T\n",
"\n",
"        # every loss must be defined on the same `ncase` rows as X\n",
"        loss_rows = np.array([loss.shape[0] for loss in losses])\n",
"        assert np.all(loss_rows == self.ncase), 'every loss needs shape[0] == X.shape[0]'\n",
"\n",
"        rr.smooth_atom.__init__(self,\n",
"                                (self.nfeature * self.ndisease,),\n",
"                                offset=offset,\n",
"                                quadratic=quadratic,\n",
"                                initial=initial)\n",
"        # scratch buffer: row d receives the gradient of losses[d] w.r.t. its linear predictor\n",
"        self._gradient = np.zeros((self.ndisease, self.ncase))\n",
"\n",
"    def smooth_objective(self, arg, mode='both', check_feasibility=False):\n",
"        \"\"\"\n",
"        Evaluate the summed loss ('func'), its flat gradient ('grad') or both ('both').\n",
"\n",
"        `arg` has length nfeature * ndisease; the gradient is returned flattened\n",
"        in the same row-major order. `check_feasibility` is accepted for API\n",
"        compatibility and unused. Raises ValueError for any other `mode`.\n",
"        \"\"\"\n",
"        arg = self.apply_offset(arg)                       # (nfeature * ndisease,)\n",
"        arg = arg.reshape((self.nfeature, self.ndisease))  # (nfeature, ndisease)\n",
"        linpred = arg.T.dot(self.X_T)                      # (ndisease, ncase)\n",
"        if mode == 'grad':\n",
"            for d in range(self.ndisease):\n",
"                self._gradient[d] = self.losses[d].smooth_objective(linpred[d], 'grad')\n",
"            # (nfeature, ndisease) gradient flattened in the same order as `arg`\n",
"            return self.scale(self._gradient.dot(self.X).T.reshape(-1))\n",
"        elif mode == 'func':\n",
"            value = 0\n",
"            for d in range(self.ndisease):\n",
"                value += self.losses[d].smooth_objective(linpred[d], 'func')\n",
"            return self.scale(value)\n",
"        elif mode == 'both':\n",
"            value = 0\n",
"            for d in range(self.ndisease):\n",
"                f, g = self.losses[d].smooth_objective(linpred[d], 'both')\n",
"                self._gradient[d] = g\n",
"                value += f\n",
"            return self.scale(value), self.scale(self._gradient.dot(self.X).T.reshape(-1))\n",
"        else:\n",
"            raise ValueError(\"mode incorrectly specified\")\n",
"\n",
"\n",
"loss_flat = cox_stacked_flat(losses, X)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.27 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# flat sparse group LASSO at 0.7 * lambda_max, each run from all-zero coefficients\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maybe we haven't solved accurately enough -- let's compare with a solve started from a random point\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# reference solve (100 iterations minimum) from all-zero coefficients\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.\n",
"soln_flat = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# same problem (100 iterations minimum) started from a small random perturbation\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.1 * np.random.standard_normal(group_penalty.shape)\n",
"soln_flat_r = problem.solve(min_its=100, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0022612382977817364"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff_norm = np.linalg.norm(soln_flat - soln_flat_r)\n",
"diff_norm / np.linalg.norm(soln_flat)  # relative gap between the zero-start and random-start solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's increase the number of iterations a bit"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# reference solve (200 iterations minimum) from all-zero coefficients\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.\n",
"soln_flat = problem.solve(min_its=200, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# same problem (200 iterations minimum) started from a small random perturbation\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.1 * np.random.standard_normal(group_penalty.shape)\n",
"soln_flat_r = problem.solve(min_its=200, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.217488899625311e-14"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff_norm = np.linalg.norm(soln_flat - soln_flat_r)\n",
"diff_norm / np.linalg.norm(soln_flat)  # now we've essentially found the same solution"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Time at 200 iterations\n",
"\n",
"This should take roughly twice as long, though possibly a bit less: the Lipschitz constant estimate (inverse step size) may have settled down, so less backtracking is needed."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.1 s ± 265 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# flat sparse group LASSO, 200 iterations minimum, each run from all-zero coefficients\n",
"group_penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss_flat, group_penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=200, tol=1e-9)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.46 s ± 55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"# block penalty, 200 iterations minimum, each run from all-zero coefficients\n",
"penalty.lagrange = 0.7 * lambda_max\n",
"problem = rr.simple_problem(loss, penalty)\n",
"problem.coefs[:] = 0.\n",
"soln = problem.solve(min_its=200, tol=1e-9)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "all,-slideshow",
"formats": "ipynb,Rmd"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment