Created
January 31, 2016 14:33
-
-
Save agramfort/ac52a57dc6551138e89b to your computer and use it in GitHub Desktop.
Lasso with ISTA and FISTA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# Solve LASSO regression problem with ISTA and FISTA
# iterative solvers.
# Author : Alexandre Gramfort, [email protected]
# License BSD
import time | |
from math import sqrt | |
import numpy as np | |
from scipy import linalg | |
# Build a small, reproducible synthetic LASSO problem.
rng = np.random.RandomState(42)  # fixed seed so every run sees the same data
m, n = 15, 20  # A has m rows and n columns (fewer rows than columns)

# Random Gaussian design matrix.
A = rng.randn(m, n)

# Sparse ground truth: draw uniform values in [0, 1) and zero every entry
# below 0.9, so only the largest draws survive.
x0 = rng.rand(n)
x0[x0 < 0.9] = 0

# Noiseless observations generated from the sparse ground truth.
b = np.dot(A, x0)

l = 0.5  # regularization parameter
def soft_thresh(x, l):
    """Apply the soft-thresholding operator elementwise.

    Each entry of ``x`` is shrunk towards zero by ``l``; entries whose
    magnitude does not exceed ``l`` become zero.

    Parameters
    ----------
    x : array_like or scalar
        Values to shrink.
    l : float or array_like
        Threshold subtracted from the magnitudes of ``x``.

    Returns
    -------
    ndarray or scalar
        ``sign(x) * max(|x| - l, 0)``, computed elementwise.
    """
    return np.sign(x) * np.maximum(np.abs(x) - l, 0.)
def ista(A, b, l, maxit):
    """Solve the LASSO problem with the ISTA iterative solver.

    Minimizes ``0.5 * ||A x - b||_2 ** 2 + l * ||x||_1`` starting from the
    zero vector, taking one gradient step followed by soft-thresholding per
    iteration.

    Parameters
    ----------
    A : ndarray
        Design matrix; the solution has ``A.shape[1]`` entries.
    b : ndarray
        Target vector.
    l : float
        Regularization parameter weighting the l1 penalty.
    maxit : int
        Number of iterations to run. ``0`` returns the zero vector together
        with empty ``pobj`` and ``times`` arrays.

    Returns
    -------
    x : ndarray
        Final iterate.
    pobj : ndarray
        Primal objective value after each iteration.
    times : ndarray
        Seconds elapsed since the start of the loop, recorded after each
        iteration.
    """
    x = np.zeros(A.shape[1])
    pobj = []
    times = []
    # Squared Frobenius norm of A, used as the Lipschitz constant of the
    # gradient of the quadratic term: the step size is 1 / L.
    L = linalg.norm(A) ** 2
    time0 = time.time()
    for _ in range(maxit):
        x = soft_thresh(x + np.dot(A.T, b - A.dot(x)) / L, l / L)
        this_pobj = (0.5 * linalg.norm(A.dot(x) - b) ** 2
                     + l * linalg.norm(x, 1))
        times.append(time.time() - time0)
        pobj.append(this_pobj)
    # Converting the two lists directly (instead of unzipping a list of
    # pairs) also works when maxit == 0.
    return x, np.array(pobj), np.array(times)
def fista(A, b, l, maxit):
    """Solve the LASSO problem with the FISTA iterative solver.

    Minimizes ``0.5 * ||A x - b||_2 ** 2 + l * ||x||_1`` starting from the
    zero vector. Each iteration takes a gradient step and soft-thresholds
    at an auxiliary point ``z``, then moves ``z`` past the new iterate along
    the direction of the last update (momentum).

    Parameters
    ----------
    A : ndarray
        Design matrix; the solution has ``A.shape[1]`` entries.
    b : ndarray
        Target vector.
    l : float
        Regularization parameter weighting the l1 penalty.
    maxit : int
        Number of iterations to run. ``0`` returns the zero vector together
        with empty ``pobj`` and ``times`` arrays.

    Returns
    -------
    x : ndarray
        Final iterate.
    pobj : ndarray
        Primal objective value after each iteration.
    times : ndarray
        Seconds elapsed since the start of the loop, recorded after each
        iteration.
    """
    x = np.zeros(A.shape[1])
    pobj = []
    times = []
    t = 1.
    z = x.copy()
    # Squared Frobenius norm of A, used as the Lipschitz constant of the
    # gradient of the quadratic term: the step size is 1 / L.
    L = linalg.norm(A) ** 2
    time0 = time.time()
    for _ in range(maxit):
        # x is rebound (never modified in place) below, so keeping a
        # reference to the previous iterate is enough.
        xold = x
        z = z + A.T.dot(b - A.dot(z)) / L
        x = soft_thresh(z, l / L)
        t0 = t
        t = (1. + sqrt(1. + 4. * t ** 2)) / 2.
        # Momentum step: extrapolate beyond x along the last displacement.
        z = x + ((t0 - 1.) / t) * (x - xold)
        this_pobj = (0.5 * linalg.norm(A.dot(x) - b) ** 2
                     + l * linalg.norm(x, 1))
        times.append(time.time() - time0)
        pobj.append(this_pobj)
    # Converting the two lists directly (instead of unzipping a list of
    # pairs) also works when maxit == 0.
    return x, np.array(pobj), np.array(times)
# Run both solvers on the same problem with the same iteration budget.
maxit = 3000
x_ista, pobj_ista, times_ista = ista(A, b, l, maxit)
x_fista, pobj_fista, times_fista = fista(A, b, l, maxit)

import matplotlib.pyplot as plt

plt.close('all')

# Figure 1: ground-truth coefficients (green) against the ISTA (blue) and
# FISTA (red) estimates. Labels and a legend identify the overlaid stems.
plt.figure()
plt.stem(x0, markerfmt='go', label='x0')
plt.stem(x_ista, markerfmt='bo', label='ista')
plt.stem(x_fista, markerfmt='ro', label='fista')
plt.legend()

# Figure 2: primal objective against elapsed time for both solvers.
plt.figure()
plt.plot(times_ista, pobj_ista, label='ista')
plt.plot(times_fista, pobj_fista, label='fista')
plt.xlabel('Time')
plt.ylabel('Primal')
plt.legend()
plt.show()
@Mullahz SpaMS (Sparse Modeling Software) implements dictionary learning and sparse decomposition algorithms; it is written in C++ with a MATLAB interface.
http://spams-devel.gforge.inria.fr/
In order to have the fastest possible convergence, one has to take the smallest possible Lipschitz constant L in the FISTA algorithm. The squared Frobenius norm of D (the design matrix, called A in the code above) is a valid Lipschitz constant, but a much smaller one is the largest eigenvalue of D^T D, i.e. the square of the largest singular value of D.
According to the NumPy docs, that would be:
L = linalg.norm(A, ord=2) **2
I've made the change and it converges significantly faster.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Is there a similar version of this code in MATLAB?
If anyone has a MATLAB version of this program, kindly share it here.