@agramfort
Created January 31, 2016 14:33
Lasso with ISTA and FISTA
#!/usr/bin/env python
#
# Solve LASSO regression problem with ISTA and FISTA
# iterative solvers.
# Author : Alexandre Gramfort, first.last@telecom-paristech.fr
# License BSD
import time
from math import sqrt
import numpy as np
from scipy import linalg
rng = np.random.RandomState(42)
m, n = 15, 20
# random design
A = rng.randn(m, n) # random design
x0 = rng.rand(n)
x0[x0 < 0.9] = 0
b = np.dot(A, x0)
l = 0.5 # regularization parameter
def soft_thresh(x, l):
    # Soft-thresholding: proximal operator of l * ||x||_1
    return np.sign(x) * np.maximum(np.abs(x) - l, 0.)
def ista(A, b, l, maxit):
    x = np.zeros(A.shape[1])
    pobj = []
    L = linalg.norm(A) ** 2  # Lipschitz constant (squared Frobenius norm, an upper bound)
    time0 = time.time()
    for _ in range(maxit):
        # Gradient step on the least-squares term, then soft-thresholding
        x = soft_thresh(x + np.dot(A.T, b - A.dot(x)) / L, l / L)
        this_pobj = 0.5 * linalg.norm(A.dot(x) - b) ** 2 + l * linalg.norm(x, 1)
        pobj.append((time.time() - time0, this_pobj))
    times, pobj = map(np.array, zip(*pobj))
    return x, pobj, times
def fista(A, b, l, maxit):
    x = np.zeros(A.shape[1])
    pobj = []
    t = 1
    z = x.copy()
    L = linalg.norm(A) ** 2  # Lipschitz constant (squared Frobenius norm, an upper bound)
    time0 = time.time()
    for _ in range(maxit):
        xold = x.copy()
        # Proximal gradient step taken from the extrapolated point z
        z = z + A.T.dot(b - A.dot(z)) / L
        x = soft_thresh(z, l / L)
        # Momentum update
        t0 = t
        t = (1. + sqrt(1. + 4. * t ** 2)) / 2.
        z = x + ((t0 - 1.) / t) * (x - xold)
        this_pobj = 0.5 * linalg.norm(A.dot(x) - b) ** 2 + l * linalg.norm(x, 1)
        pobj.append((time.time() - time0, this_pobj))
    times, pobj = map(np.array, zip(*pobj))
    return x, pobj, times
maxit = 3000
x_ista, pobj_ista, times_ista = ista(A, b, l, maxit)
x_fista, pobj_fista, times_fista = fista(A, b, l, maxit)
import matplotlib.pyplot as plt
plt.close('all')
plt.figure()
plt.stem(x0, markerfmt='go')
plt.stem(x_ista, markerfmt='bo')
plt.stem(x_fista, markerfmt='ro')
plt.figure()
plt.plot(times_ista, pobj_ista, label='ista')
plt.plot(times_fista, pobj_fista, label='fista')
plt.xlabel('Time')
plt.ylabel('Primal')
plt.legend()
plt.show()
@Mullahz

Mullahz commented Mar 9, 2018

Is there a similar version of this code in MATLAB?
If anyone has a MATLAB version of this program, kindly share it here.

@okbalefthanded

@Mullahz SPAMS (SPArse Modeling Software) implements dictionary learning and sparse decomposition algorithms. It is written in C++ and has a MATLAB interface.
http://spams-devel.gforge.inria.fr/

@louity

louity commented Apr 1, 2019

For the fastest possible convergence, one should use the smallest valid Lipschitz constant L in the FISTA algorithm. The squared Frobenius norm of A is a valid Lipschitz constant, but a smaller (often much smaller) one is the largest eigenvalue of A^T A, i.e. the square of the largest singular value of A.
According to the NumPy docs, this would be:
L = linalg.norm(A, ord=2) ** 2
I've made the change and it converges significantly faster.
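
The two choices of L can be compared directly. The following is a minimal sketch, not the gist author's code: it rebuilds the same random problem, and the helper names `objective` and `fista_objective` and the variable `lam` (standing in for the gist's `l`) are introduced here purely for illustration. It prints both constants and the objective reached by FISTA after a fixed number of iterations with each; how large the gap is will depend on the matrix.

```python
import numpy as np

# Same random problem as in the gist (assumed setup, repeated so this runs alone).
rng = np.random.RandomState(42)
m, n = 15, 20
A = rng.randn(m, n)
x0 = rng.rand(n)
x0[x0 < 0.9] = 0
b = A.dot(x0)
lam = 0.5  # regularization parameter (called `l` in the gist)


def soft_thresh(x, t):
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.)


def objective(A, b, lam, x):
    # Lasso primal objective: 0.5 * ||Ax - b||^2 + lam * ||x||_1
    return 0.5 * np.linalg.norm(A.dot(x) - b) ** 2 + lam * np.linalg.norm(x, 1)


def fista_objective(A, b, lam, L, maxit):
    """Run FISTA with step size 1/L and return the final objective value."""
    x = np.zeros(A.shape[1])
    z = x.copy()
    t = 1.
    for _ in range(maxit):
        xold = x
        x = soft_thresh(z + A.T.dot(b - A.dot(z)) / L, lam / L)
        t0 = t
        t = (1. + np.sqrt(1. + 4. * t ** 2)) / 2.
        z = x + ((t0 - 1.) / t) * (x - xold)
    return objective(A, b, lam, x)


L_fro = np.linalg.norm(A) ** 2          # squared Frobenius norm (default for matrices)
L_spec = np.linalg.norm(A, ord=2) ** 2  # squared largest singular value

print("L (Frobenius): %.2f, L (spectral): %.2f" % (L_fro, L_spec))
for maxit in (10, 100, 1000):
    print(maxit,
          fista_objective(A, b, lam, L_fro, maxit),
          fista_objective(A, b, lam, L_spec, maxit))
```

The squared Frobenius norm is the sum of all squared singular values, so it is never smaller than the squared spectral norm; the two coincide only when A has rank one.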
