Last active
January 23, 2016 14:02
-
-
Save agramfort/68ce1a4a142afeb6f63e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
from scipy import linalg, io, sparse | |
import matplotlib.pyplot as plt | |
from sklearn.externals.joblib import Memory | |
from sklearn.linear_model import lasso_path | |
from sklearn.datasets.mldata import fetch_mldata | |
from sklearn import datasets | |
# Marker styles for the convergence curves; indexed by screening setting
# when the curves are plotted.
markers = ['s', 'd', '^', 'v', '<', 'p', '>']
# Default line colours.  The 'axes.color_cycle' rcParam was superseded by
# 'axes.prop_cycle', so try the current key first and only fall back to the
# legacy one on matplotlib versions that do not know it.
try:
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
except KeyError:
    colors = plt.rcParams['axes.color_cycle']
# Close any figure that is already open before plotting.
plt.close('all')
import sys | |
# Dataset selector: taken from the first command-line argument when one is
# given, otherwise 0.
dataset_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0

# Forwarded to lasso_path as its `eps` argument.
eps = 1e-2  # the smaller it is the longer is the path

# Coordinate selection rule forwarded to lasso_path.
# Alternative setting: selection = 'random'
selection = 'cyclic'
# Load the problem selected by dataset_id into X (features) and y (float
# targets), and record a short dataset_name used later in titles/filenames.
if dataset_id == 0:
    dataset_name = 'leukemia'
    # NOTE(review): fetch_mldata is not shipped by recent scikit-learn
    # releases -- confirm the installed version still provides it.
    data = fetch_mldata(dataset_name)
    X = data.data.astype(float)
    y = data.target.astype(float)
    y /= linalg.norm(y)
elif dataset_id == 1:
    dataset_name = 'news'

    def get_Xy_news():
        """Return TF-IDF features and float labels for two newsgroups.

        Only the 'comp.graphics' and 'talk.religion.misc' categories are
        fetched; label 0 is remapped to -1.
        """
        from sklearn.feature_extraction.text import TfidfVectorizer

        data = datasets.fetch_20newsgroups(categories=['comp.graphics',
                                                       'talk.religion.misc'])
        vect = TfidfVectorizer(max_df=0.95, min_df=2, stop_words='english')
        X = vect.fit_transform(data.data)
        # Builtin float instead of the np.float alias, which current NumPy
        # releases no longer provide.
        y = data.target.astype(float)
        y[y == 0] = -1.
        return X, y

    # Route the loader through a joblib Memory rooted at the current
    # directory so its result can be reused between runs.
    mem = Memory('.')
    X, y = mem.cache(get_Xy_news)()
elif dataset_id == 2:
    dataset_name = 'rcv1'
    data = io.loadmat('rcv1_train.binary.mat')
    X = data['X'].astype(float)
    y = data['y'].astype(float).ravel()
    eps = 1e-2  # the smaller it is the longer is the path
else:
    # Fail here with a clear message instead of a NameError on X later on.
    raise ValueError("Unknown dataset_id %d (expected 0, 1 or 2)"
                     % dataset_id)
# Sparse inputs are left untouched; dense inputs get unit-norm columns.
if not sparse.issparse(X):
    # Divide every column by its Euclidean norm, in place.
    column_norms = np.sqrt(np.sum(X ** 2, axis=0))
    X /= column_norms
    # Keep only the columns that contain no NaN after the division; if no
    # column qualifies, X is left as it is.
    nan_free = np.sum(np.isnan(X), axis=0) == 0
    if np.any(nan_free):
        X = X[:, nan_free]

print(X.shape)

y /= linalg.norm(y)  # to correct for tol scaling
# Compute paths
# Exponents t of the solver tolerances; the solver is later run with
# tol = 10 ** (-t).
tols = range(2, 11, 2)
# Screening settings to compare (0 is labelled 'No screening').
# Alternative setting: screenings = [0, 10]
screenings = [0, 5, 10, 20]
positive = False
# Human-readable legend label for each screening setting, in the same order.
screenings_names = [
    'No screening' if scr == 0 else "GAP SAFE (scr. iter %d)" % scr
    for scr in screenings
]
# times[i, j] holds the wall-clock seconds measured for screenings[i] at
# tolerance 10 ** (-tols[j]).
times = np.zeros((len(screenings), len(tols)))
plt.close('all')
for itol, tol in enumerate(tols):
    # One figure per tolerance, with one curve per screening setting.
    plt.figure()
    for iscreening, screening in enumerate(screenings):
        begin = time.time()
        print("Computing regularization path using the lasso "
              "with screening=%d..." % screening)
        # NOTE(review): the `screening` keyword and the third return value
        # are presumably provided by a development branch of scikit-learn
        # rather than a released lasso_path -- confirm before running.
        alphas, coefs, gaps = lasso_path(X, y, eps, n_alphas=100,
                                         precompute=False, tol=10 ** (-tol),
                                         verbose=0, max_iter=3000,
                                         screening=screening,
                                         selection=selection,
                                         positive=positive)
        duration = time.time() - begin
        print(duration)
        times[iscreening, itol] = duration

        # Display results
        # Floor |gap| at 1e-15 so that log10 never receives a zero or
        # negative value (which would plot as -inf / NaN).
        gap_lasso = np.maximum(np.abs(gaps), 1e-15)
        plt.plot(-np.log10(alphas / np.max(alphas)), np.log10(gap_lasso.T),
                 label=screenings_names[iscreening],
                 marker=markers[iscreening])

    # Dashed horizontal line at the requested tolerance, log10(10 ** -tol).
    plt.axhline(-tol, linestyle='--', color='k')
    plt.xlabel('-Log(alpha)')
    plt.ylabel('Log(gap)')
    plt.axis('tight')
    plt.legend(loc="lower right")
    # NOTE(review): the show / sleep / tight_layout / show sequence is kept
    # as in the original; presumably a workaround for an interactive
    # backend -- confirm whether the second show() is still needed.
    plt.show()
    time.sleep(0.1)
    plt.tight_layout()
    plt.show()
import pandas as pd | |
import os

# Timing summary: one row per tolerance, one column per screening setting.
df = pd.DataFrame(times.T, columns=screenings_names)
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
df.plot(kind='bar', ax=ax)
ax.set_xticks(list(range(len(tols))))
ax.set_xticklabels([str(t) for t in tols])
ax.set_xlabel("-log10(duality gap)")
ax.set_ylabel("Time (s)")
ax.set_title("Lasso on %s (selection=%s)" % (dataset_name, selection))
plt.show()
time.sleep(0.2)
fig.tight_layout()
plt.show()

# Work on the explicit fig/ax handles from here on: after a blocking
# plt.show() pyplot may no longer have this figure as its current one, and
# plt.legend()/plt.savefig() would then act on a fresh, empty figure.
leg = ax.legend(frameon=True, loc='upper left')
leg.get_frame().set_alpha(0.5)

# savefig does not create missing directories, so make sure 'img' exists.
out_dir = 'img'
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)
fig.savefig('img/lasso_bench_%s_%s.png' % (dataset_name, selection))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment