Skip to content

Instantly share code, notes, and snippets.

@albertotb
Created March 17, 2016 11:07
Show Gist options
  • Save albertotb/73be447b6ee95913fa62 to your computer and use it in GitHub Desktop.
Save albertotb/73be447b6ee95913fa62 to your computer and use it in GitHub Desktop.
#!/usr/bin/env julia
using DelimitedFiles
using LinearAlgebra
using Statistics
"""
    soft_thresholding(x, α)

Apply the soft-thresholding operator `sign(x) * max(abs(x) - α, 0)`.

Every entry of `x` is shrunk towards zero by `α`; entries whose magnitude
does not exceed `α` become zero. `x` and `α` may each be a scalar or an
array: the expression is fully broadcast, so arrays are processed
element-wise and a scalar input yields a scalar result.
"""
soft_thresholding(x, α) = @. sign(x) * max(abs(x) - α, 0)
"""
    lasso(X, y, λ, maxiter=10000, tol=1e-9)

Fit lasso regression weights with an accelerated proximal-gradient (FISTA)
iteration.

The function minimised is

    1/2 * ‖X*w - y‖₂² + (λ * n) * ‖w‖₁

where `n = size(X, 1)`.

# Arguments
- `X`: `n × m` design matrix.
- `y`: length-`n` target vector.
- `λ`: regularisation strength; it is multiplied by `n` internally.
- `maxiter`: maximum number of iterations.
- `tol`: relative tolerance on the change of the objective between two
  consecutive iterations.

# Returns
A tuple `(w, iter)` with the weight vector (length `m`) and the number of
iterations that were run.
"""
function lasso(X, y, λ, maxiter=10000, tol=1e-9)
    n, m = size(X)
    w = zeros(m)
    if m == 0
        return (w, 0)
    end

    # Precompute the Gram matrix and X'y so that each gradient evaluation
    # costs O(m^2) instead of O(n*m).
    G = X' * X
    c = X' * y
    reg = λ * n

    # The step size is 1/L, where L (the largest eigenvalue of X'X) is the
    # Lipschitz constant of the gradient of the quadratic term.
    L = eigmax(Symmetric(G))
    if !(L > 0)
        # X is identically zero: the gradient vanishes and w = 0 is optimal.
        return (w, 0)
    end
    step = 1 / L

    # Objective whose gradient (of the smooth part) is G*w - c; the residual
    # norm must be squared to be consistent with that gradient.
    objective(β) = sum(abs2, X * β - y) / 2 + reg * norm(β, 1)

    t = 1.0
    f = objective(w)
    v = w
    iter = 0
    keep_going = true
    while keep_going && iter < maxiter
        w_prev, t_prev, f_prev = w, t, f

        # Proximal-gradient step taken from the extrapolated point v.
        grad = G * v - c
        w = soft_thresholding.(v .- step .* grad, reg * step)

        # Momentum update of the extrapolation point.
        t = (1 + sqrt(1 + 4 * t_prev^2)) / 2
        v = w .+ ((t_prev - 1) / t) .* (w .- w_prev)

        # Stop once the relative change of the objective drops below tol;
        # the 1e-10 term avoids a 0/0 comparison when f is exactly zero.
        f = objective(w)
        keep_going = 2 * abs(f - f_prev) > tol * (abs(f) + abs(f_prev) + 1e-10)
        iter += 1
    end
    return w, iter
end
# Command-line driver: fit a lasso model on the TRAIN file with
# regularisation strength LAMBDA and report the fit on the training data.
if length(ARGS) != 2
    println("usage: lasso TRAIN LAMBDA")
    exit(1)
end

# TRAIN is read as a space-delimited table.
@time train = readdlm(ARGS[1], ' ')

# Standardize every column (subtract its mean, divide by its standard
# deviation); the 1×m statistics are broadcast across the rows.
train = (train .- mean(train, dims=1)) ./ std(train, dims=1)

# Column 1 is the target, the remaining columns are the features.
X_train = train[:, 2:end]
y_train = train[:, 1]
λ = parse(Float64, ARGS[2])

println("Train file: $(ARGS[1])")
println("Lambda: $λ")

@time (w, iter) = lasso(X_train, y_train, λ)

# Mean absolute error of the fitted model on the training data.
mae = mean(abs, X_train * w - y_train)
println("MAE: $mae")
println("Active weights: $(count(!iszero, w))/$(length(w))")
println("Iter: $iter")
println("Weights: $w")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment