Created
June 10, 2016 06:43
-
-
Save fabianp/8b62d905ae254fadf763cf1909904e3a to your computer and use it in GitHub Desktop.
1D total variation (also known as fused lasso) proximal operator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try:
    from numba import njit
except ImportError:
    # numba is only used for speed; without it the very same code runs as
    # plain (much slower) Python instead of failing at import time.
    def njit(*args, **kwargs):
        """No-op stand-in for ``numba.njit`` (supports ``@njit`` and ``@njit(...)``)."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def prox_tv1d(w, stepsize):
    """Proximal operator of the 1D total variation (fused lasso), computed in place.

    Overwrites ``w`` with the solution of

        argmin_x  0.5 * ||x - w||^2 + stepsize * sum_i |x[i+1] - x[i]|

    using Condat's direct (non-iterative) taut-string-like algorithm.

    Parameters
    ----------
    w : array
        1D vector of coefficients. It is overwritten with the result, so it
        should have a floating-point dtype (an integer array cannot hold the
        fractional outputs).
    stepsize : float
        Step size (sometimes denoted gamma, or lambda in the reference) in the
        proximal objective function. If it is negative or NaN, ``w`` is left
        unchanged.

    Returns
    -------
    None
        The result is written into ``w``.

    References
    ----------
    Condat, Laurent. "A direct algorithm for 1D total variation denoising."
    IEEE Signal Processing Letters (2013)
    """
    width = w.size
    # Guard clause: an empty input has nothing to denoise (and reading w[0]
    # would be invalid); a negative or NaN step size is not a valid weight.
    # Written as ``not (stepsize >= 0)`` so that NaN is also rejected.
    if width <= 0 or not (stepsize >= 0):
        return

    # Working in place is safe: values are only ever written at indices
    # below k0, while reads happen at k0 or at k + 1 with k >= k0.
    k = 0  # current sample location
    k0 = 0  # beginning of the current segment
    umin = stepsize  # u is the dual variable
    umax = -stepsize
    vmin = w[0] - stepsize  # lower bound for the current segment's value
    vmax = w[0] + stepsize  # upper bound for the current segment's value
    kplus = 0  # last position where umax == -stepsize
    kminus = 0  # last position where umin == stepsize
    twolambda = 2.0 * stepsize  # auxiliary constant
    minlambda = -stepsize  # auxiliary constant

    while True:  # simple loop; the only exit is in the right-boundary branch
        while k == width - 1:  # right boundary condition
            if umin < 0.0:  # vmin is too high -> negative jump necessary
                # ``while True ... break`` emulates C's do/while: the body
                # always runs at least once, exactly as in the reference code.
                while True:
                    w[k0] = vmin
                    k0 += 1
                    if k0 > kminus:
                        break
                k = k0
                kminus = k
                vmin = w[kminus]
                umin = stepsize
                umax = vmin + umin - vmax
            elif umax > 0.0:  # vmax is too low -> positive jump necessary
                while True:
                    w[k0] = vmax
                    k0 += 1
                    if k0 > kplus:
                        break
                k = k0
                kplus = k
                vmax = w[kplus]
                umax = minlambda
                umin = vmax + umax - vmin
            else:
                # Bounds are consistent: fix the value of the last segment,
                # write it out and finish.
                vmin += umin / (k - k0 + 1)
                while True:
                    w[k0] = vmin
                    k0 += 1
                    if k0 > k:
                        break
                return

        umin += w[k + 1] - vmin
        if umin < minlambda:  # negative jump necessary
            while True:
                w[k0] = vmin
                k0 += 1
                if k0 > kminus:
                    break
            k = k0
            kminus = k
            kplus = kminus
            vmin = w[kplus]
            vmax = vmin + twolambda
            umin = stepsize
            umax = minlambda
        else:
            umax += w[k + 1] - vmax
            if umax > stepsize:  # positive jump necessary
                while True:
                    w[k0] = vmax
                    k0 += 1
                    if k0 > kplus:
                        break
                k = k0
                kminus = k
                kplus = kminus
                vmax = w[kplus]
                vmin = vmax - twolambda
                umin = stepsize
                umax = minlambda
            else:  # no jump necessary, we continue
                k += 1
                if umin >= stepsize:  # update of vmin
                    kminus = k
                    vmin += (umin - stepsize) / (kminus - k0 + 1)
                    umin = stepsize
                if umax <= minlambda:  # update of vmax
                    kplus = k
                    vmax += (umax + stepsize) / (kplus - k0 + 1)
                    umax = minlambda
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment