Skip to content

Instantly share code, notes, and snippets.

@tansey
Created April 5, 2022 16:26
Show Gist options
  • Save tansey/e61f6194bd4212f163a373b3df38511d to your computer and use it in GitHub Desktop.
Save tansey/e61f6194bd4212f163a373b3df38511d to your computer and use it in GitHub Desktop.
1d and 2d pool adjacent violators (PAV)
import numpy as np
def pav(y):
    """
    Pool adjacent violators (PAV): monotone (non-decreasing) least-squares
    smoothing of a 1-d sequence.

    Adjacent runs that violate monotonicity are pooled into blocks, and each
    block is replaced by its mean. The result is the non-decreasing sequence
    closest to ``y`` in squared error (isotonic regression, unit weights).

    Originally translated from matlab by Sean Collins (2006) as part of the
    EMAP toolbox.
    Author : Alexandre Gramfort
    license : BSD

    Parameters
    ----------
    y : array_like, 1-d
        Sequence to smooth. It is not modified.

    Returns
    -------
    numpy.ndarray
        New array with the same length as ``y``. Floating-point inputs keep
        their dtype. Integer and boolean inputs are promoted to float64, so
        block means are not truncated.

    Raises
    ------
    AssertionError
        If ``y`` is not one-dimensional. This is checked with ``assert``, so
        the check is skipped under ``python -O``.
    """
    y = np.asarray(y)
    assert y.ndim == 1, "pav expects a 1-d array"
    # Block means are fractional. Storing them in an integer array would
    # truncate them, e.g. [1, 0] -> [0, 0] instead of [0.5, 0.5].
    out_dtype = y.dtype if np.issubdtype(y.dtype, np.floating) else np.float64
    if y.size == 0:
        return y.astype(out_dtype)

    # Stack of pooled blocks, each stored as (sum of values, number of values).
    # Each element is pushed once and merged away at most once, so this is
    # O(n), instead of rescanning the whole array after every merge.
    sums = []
    counts = []
    for value in y.tolist():
        total, count = value, 1
        # Merge into the previous block while its mean is larger than the
        # current block's mean, i.e. while the two blocks violate monotonicity.
        while sums and sums[-1] / counts[-1] > total / count:
            total += sums.pop()
            count += counts.pop()
        sums.append(total)
        counts.append(count)

    means = np.asarray(sums, dtype=np.float64) / np.asarray(counts)
    return np.repeat(means, counts).astype(out_dtype, copy=False)
def pav2d(Y, tol=1e-6, max_iter=200):
    '''
    Two dimensional pool adjacent violators algorithm.

    Projects ``Y`` (least squares) onto the set of 2-d arrays that are
    non-decreasing along both axes. Based on Lin and Dunson (2014).

    The projection onto the intersection of "every row is monotone" and
    "every column is monotone" is computed with Dykstra's alternating
    projection algorithm. The 1-d ``pav`` is applied to each row and to each
    column as the two individual projections.

    Parameters
    ----------
    Y : array_like, 2-d
        Array to smooth. It is not modified.
    tol : float
        Allowed monotonicity violation. It is also the largest change in the
        Dykstra correction terms over one sweep at which the iteration is
        treated as converged.
    max_iter : int
        Maximum number of row/column sweeps.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``Y``. If ``Y`` is already monotone
        within ``tol``, a copy of ``Y`` is returned. If ``max_iter`` sweeps
        finish without convergence, the last iterate is returned.

    Raises
    ------
    ValueError
        If ``Y`` is not two-dimensional.
    '''
    Y = np.asarray(Y)
    if Y.ndim != 2:
        raise ValueError(f"pav2d expects a 2-d array, got ndim={Y.ndim}")

    def is_monotone(A):
        # Non-decreasing down the columns and along the rows, within tol.
        return bool(np.all(np.diff(A, axis=0) >= -tol)
                    and np.all(np.diff(A, axis=1) >= -tol))

    # A feasible input is its own projection.
    V = Y.copy()
    if is_monotone(V):
        return V

    # Dykstra correction terms for the row and column projections.
    row_resid = np.zeros(Y.shape)
    col_resid = np.zeros(Y.shape)
    for _ in range(max_iter):
        # Project onto row-monotone arrays (pav along axis 1).
        W = Y + col_resid
        V = np.apply_along_axis(pav, 1, W)
        new_row_resid = V - W
        # Project onto column-monotone arrays (pav along axis 0).
        W = Y + new_row_resid
        V = np.apply_along_axis(pav, 0, W)
        new_col_resid = V - W

        change = max(np.max(np.abs(new_row_resid - row_resid)),
                     np.max(np.abs(new_col_resid - col_resid)))
        row_resid, col_resid = new_row_resid, new_col_resid
        # Monotonicity alone is not convergence. Dykstra's iterates can become
        # feasible before they reach the least-squares projection, e.g.
        # [[1, 0], [0, 2]] is feasible after one sweep but is not projected.
        # So also require the correction terms to have stopped changing.
        if change <= tol and is_monotone(V):
            break
    return V
def test_pav2d():
    '''
    Smoke test and visual check for pav2d.

    Adds unit Gaussian noise to a monotone 4x4 surface and smooths it with
    pav2d. It then checks that the result has the same shape and is
    non-decreasing along both axes, and saves a side-by-side plot (truth,
    noisy, smoothed) to 'pav2d.pdf' in the working directory.

    matplotlib is imported inside the function, so the module itself does not
    depend on it.
    '''
    import matplotlib.pyplot as plt

    Mu = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 2],
                   [1, 1, 1, 2],
                   [1, 2, 2, 2]])
    # Fixed seed, so failures and the saved figure are reproducible.
    rng = np.random.default_rng(0)
    X = rng.normal(Mu, scale=1)
    Z = pav2d(X)

    assert Z.shape == Mu.shape
    # pav2d's default tolerance is 1e-6; allow a little slack here.
    assert np.all(np.diff(Z, axis=0) >= -1e-4)
    assert np.all(np.diff(Z, axis=1) >= -1e-4)

    fig, axarr = plt.subplots(1, 3, figsize=(15, 5), sharey=True, sharex=True)
    for ax, image in zip(axarr, (Mu, X, Z)):
        ax.imshow(image, interpolation='none', vmin=-1, vmax=3, cmap='gray_r')
    fig.savefig('pav2d.pdf', bbox_inches='tight')
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment