Skip to content

Instantly share code, notes, and snippets.

@agramfort
Created January 31, 2016 14:33
Show Gist options
  • Save agramfort/ac52a57dc6551138e89b to your computer and use it in GitHub Desktop.
Save agramfort/ac52a57dc6551138e89b to your computer and use it in GitHub Desktop.
Lasso with ISTA and FISTA
#!/usr/bin/env python
#
# Solve LASSO regression problem with ISTA and FISTA
# iterative solvers.
# Author : Alexandre Gramfort, [email protected]
# License BSD
import time
from math import sqrt
import numpy as np
from scipy import linalg
# Fixed seed so the synthetic problem is identical from run to run.
rng = np.random.RandomState(42)
m = 15
n = 20
# Gaussian design matrix with m rows and n columns (fewer rows than columns).
A = rng.randn(m, n)
# Sparse ground truth: uniform draws in [0, 1), keeping only entries >= 0.9.
x0 = rng.rand(n)
x0 = np.where(x0 < 0.9, 0., x0)
# Noise-free observations produced by the sparse ground truth.
b = A.dot(x0)
# Weight of the l1 penalty (regularization parameter).
l = 0.5
def soft_thresh(x, l):
    """Apply the soft-thresholding operator to ``x`` with threshold ``l``.

    Every entry keeps its sign while its magnitude is reduced by ``l``
    and clipped below at zero; for a non-negative ``l`` this sets every
    entry with ``|x| <= l`` to zero.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - l, 0.)
    return np.sign(x) * shrunk_magnitude
def ista(A, b, l, maxit):
    """Solve the LASSO problem with ISTA (proximal gradient iterations).

    Minimizes ``0.5 * ||A x - b||_2 ** 2 + l * ||x||_1`` starting from
    ``x = 0``.

    Parameters
    ----------
    A : 2-D array
        Design matrix.
    b : 1-D array
        Target vector, one entry per row of ``A``.
    l : float
        Weight of the l1 penalty.
    maxit : int
        Number of iterations to run. ``0`` returns the zero vector and
        empty history arrays.

    Returns
    -------
    x : array
        Final iterate, one entry per column of ``A``.
    pobj : array
        Objective value after each iteration.
    times : array
        Seconds elapsed since the start of the loop, recorded after each
        iteration's objective has been evaluated.
    """
    x = np.zeros(A.shape[1])
    # Squared spectral norm = largest eigenvalue of A^T A: the tightest
    # Lipschitz constant of the gradient, hence the largest safe step 1 / L.
    # (The squared Frobenius norm is also valid but gives smaller steps.)
    L = linalg.norm(A, ord=2) ** 2
    pobj = []
    times = []
    time0 = time.time()
    for _ in range(maxit):
        # Gradient step on the quadratic term, then the l1 proximal step.
        x = soft_thresh(x + A.T.dot(b - A.dot(x)) / L, l / L)
        pobj.append(0.5 * linalg.norm(A.dot(x) - b) ** 2
                    + l * linalg.norm(x, 1))
        times.append(time.time() - time0)
    # Built from two plain lists so that maxit == 0 yields empty arrays
    # instead of failing on tuple unpacking.
    return x, np.array(pobj), np.array(times)
def fista(A, b, l, maxit):
    """Solve the LASSO problem with FISTA (accelerated proximal gradient).

    Minimizes ``0.5 * ||A x - b||_2 ** 2 + l * ||x||_1`` starting from
    ``x = 0``. Each iteration takes a proximal gradient step from an
    extrapolated point ``z`` and then updates the extrapolation.

    Parameters
    ----------
    A : 2-D array
        Design matrix.
    b : 1-D array
        Target vector, one entry per row of ``A``.
    l : float
        Weight of the l1 penalty.
    maxit : int
        Number of iterations to run. ``0`` returns the zero vector and
        empty history arrays.

    Returns
    -------
    x : array
        Final iterate, one entry per column of ``A``.
    pobj : array
        Objective value after each iteration.
    times : array
        Seconds elapsed since the start of the loop, recorded after each
        iteration's objective has been evaluated.
    """
    x = np.zeros(A.shape[1])
    z = x.copy()
    t = 1.
    # Squared spectral norm = largest eigenvalue of A^T A: the tightest
    # Lipschitz constant of the gradient, hence the largest safe step 1 / L.
    # (The squared Frobenius norm is also valid but gives smaller steps.)
    L = linalg.norm(A, ord=2) ** 2
    pobj = []
    times = []
    time0 = time.time()
    for _ in range(maxit):
        # x is rebound below, never modified in place, so no copy is needed.
        xold = x
        # Proximal gradient step taken from the extrapolated point z.
        z = z + A.T.dot(b - A.dot(z)) / L
        x = soft_thresh(z, l / L)
        # Momentum sequence: t_next = (1 + sqrt(1 + 4 * t ** 2)) / 2.
        t0 = t
        t = (1. + sqrt(1. + 4. * t ** 2)) / 2.
        # Extrapolate along the direction of the latest change in x.
        z = x + ((t0 - 1.) / t) * (x - xold)
        pobj.append(0.5 * linalg.norm(A.dot(x) - b) ** 2
                    + l * linalg.norm(x, 1))
        times.append(time.time() - time0)
    # Built from two plain lists so that maxit == 0 yields empty arrays
    # instead of failing on tuple unpacking.
    return x, np.array(pobj), np.array(times)
# Run both solvers on the same problem for the same number of iterations.
maxit = 3000
x_ista, pobj_ista, times_ista = ista(A, b, l, maxit)
x_fista, pobj_fista, times_fista = fista(A, b, l, maxit)

import matplotlib.pyplot as plt

plt.close('all')

# Figure 1: ground-truth coefficients (green) against the ISTA (blue)
# and FISTA (red) estimates.
plt.figure()
for coefs, marker in ((x0, 'go'), (x_ista, 'bo'), (x_fista, 'ro')):
    plt.stem(coefs, markerfmt=marker)

# Figure 2: objective value against elapsed time for each solver.
plt.figure()
for solver_label, solver_times, solver_pobj in (
        ('ista', times_ista, pobj_ista),
        ('fista', times_fista, pobj_fista)):
    plt.plot(solver_times, solver_pobj, label=solver_label)
plt.xlabel('Time')
plt.ylabel('Primal')
plt.legend()
plt.show()
@louity
Copy link

louity commented Apr 1, 2019

To get the fastest possible convergence, one should use the smallest valid Lipschitz constant L in the FISTA algorithm. The squared Frobenius norm of A is a valid Lipschitz constant, but a much smaller one is the largest eigenvalue of A^T A, i.e. the square of the largest singular value of A.
In numpy according to the docs it would be :
L = linalg.norm(A, ord=2) **2
I've made the change and it converges significantly faster.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment