Skip to content

Instantly share code, notes, and snippets.

@tansey
Last active May 11, 2019 18:38
Show Gist options
  • Save tansey/29b7db2afd3f546ceee753fcad09aae7 to your computer and use it in GitHub Desktop.
Save tansey/29b7db2afd3f546ceee753fcad09aae7 to your computer and use it in GitHub Desktop.
Pool adjacent violators algorithm for monotone matrix factorization
'''Pool adjacent violators algorithm for (column-)monotone matrix factorization.
Applies the PAV algorithm to column factors of a matrix factorization:
Given: M = W.V'
Returns: V_proj, a projected version of V such that every row of
M_proj = W.V_proj' is monotone non-increasing.
Author: Wesley Tansey
Date: May 2019
'''
import numpy as np
def factor_pav(W, V, in_place=False):
    '''Applies the pool adjacent violators (PAV) algorithm to the V vectors,
    ensuring W_i . V is monotone non-increasing for all i.

    With M = W.V', adjacent rows j, j+1 of V violate the ordering when
    M[i, j] < M[i, j+1] for some row i. Violating neighbours are pooled:
    every row in the two (contiguous) pools is replaced by the size-weighted
    average of the two pool values. Left-to-right passes over V repeat until
    a full pass performs no merge.

    Args:
        W: 2-D array of row factors; W.dot(V.T) must be defined.
        V: 2-D array of column factors, one row per column of M.
        in_place: if True, V itself is modified and returned (an integer V
            cannot hold the averages and will be truncated -- pass floats).
            Otherwise a copy is projected, promoted to float when V does not
            already have a floating/complex dtype.

    Returns:
        The projected V (the same object as V when in_place=True).
    '''
    if not in_place:
        V = np.copy(V)
        # Pooled values are averages; an integer copy would truncate them.
        if not np.issubdtype(V.dtype, np.inexact):
            V = V.astype(float)

    ncols = V.shape[0]
    # q[j] is the pool label of row j; pools are always contiguous runs.
    q = np.arange(ncols)

    # Terminate on "no merge in the last pass" rather than re-checking the
    # full product W.V': the full product and the 2-column products below can
    # differ by round-off, and mixing the two tests can loop forever. Every
    # merge reduces the number of pools, so this loop is bounded by ncols.
    merged = True
    while merged:
        merged = False
        j = 0
        while j < ncols - 1:
            # Reconstruct the current 2 columns
            M_j = W.dot(V[j:j+2].T)
            # Check for any violations
            if np.any((M_j[:, 0] - M_j[:, 1]) < 0):
                # Merge the two pools together by doing a weighted average
                pool0 = q == q[j]
                pool1 = q == q[j+1]
                w0 = pool0.sum()
                w1 = pool1.sum()
                V[pool0 | pool1] = (w0*V[j] + w1*V[j+1]) / (w0 + w1)
                q[pool1] = q[j]
                # j was the last row of pool0, so j + w1 is the last row of
                # the merged pool; rows inside it are equal by construction.
                j += w1
                merged = True
            else:
                j += 1
    return V
if __name__ == '__main__':
    # Demo: project random column factors so each reconstructed row is
    # monotone, then plot original (gray) vs projected (blue) rows.
    import os

    import matplotlib.pyplot as plt

    nrows = 4
    ncols = 20
    nembed = 5
    W = np.random.gamma(1, 1, size=(nrows, nembed))
    # NOTE(review): cumsum runs along axis=1 (the embedding dimension) and
    # [::-1] then reverses the rows of V; confirm this is the intended way to
    # build the synthetic column factors (axis=0 would give a monotone trend).
    V = (np.random.gamma(1, 1, size=(ncols, nembed)).cumsum(axis=1)[::-1]
         + np.random.gamma(0.1, 0.1, size=(ncols, nembed)))

    # Project to monotone curves
    V_proj = factor_pav(W, V)

    # squeeze=False keeps axarr 2-D so the indexing below also works when
    # nrows == 1 (subplots would otherwise return a bare Axes).
    fig, axarr = plt.subplots(1, nrows, figsize=(nrows*5, 5), squeeze=False)
    x = np.arange(ncols)
    M = W.dot(V.T)
    M_proj = W.dot(V_proj.T)
    for i in range(nrows):
        ax = axarr[0, i]
        ax.scatter(x, M[i], color='gray', alpha=0.5)
        ax.plot(x, M_proj[i], color='blue')

    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/factor-pav.pdf', bbox_inches='tight')
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment