Last active
May 11, 2019 18:38
-
-
Save tansey/29b7db2afd3f546ceee753fcad09aae7 to your computer and use it in GitHub Desktop.
Pool adjacent violators algorithm for monotone matrix factorization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''Pool adjacent violators algorithm for (column-)monotone matrix factorization.
Applies the PAV algorithm to the column factors of a matrix factorization:
Given: M = W.V'
Returns: V_proj, a projected version of V such that every row of the
reconstruction W.V_proj' is monotone non-increasing.
Author: Wesley Tansey
Date: May 2019
'''
import numpy as np | |
def factor_pav(W, V, in_place=False):
    '''Apply the pool adjacent violators (PAV) algorithm to the rows of V.

    Adjacent rows of V (equivalently, adjacent columns of M = W.V') are
    pooled into size-weighted averages until every row of W.dot(V.T) is
    monotone non-increasing.

    Args:
        W: row-factor matrix; W.shape[1] must equal V.shape[1].
        V: column-factor matrix; one row per column of M.
        in_place: if True, V itself is modified and returned. V must then
            already be a floating-point ndarray, because the pooled averages
            cannot be stored in an integer array without truncation. If False
            (default), a copy is projected (promoted to float if needed) and
            the caller's V is left untouched.

    Returns:
        The projected V.
    '''
    if not in_place:
        V = np.copy(V)
        if not np.issubdtype(V.dtype, np.floating):
            # Pooled values are averages; an integer copy would truncate them.
            V = V.astype(float)

    n = V.shape[0]
    # q[k] is the pool label of row k. Pools are always contiguous runs.
    q = np.arange(n)

    # Keep sweeping until a full pass performs no merge. Deciding termination
    # with the same per-pair check that drives the merges guarantees the loop
    # ends: each merge removes a pool, and there are at most n - 1 merges.
    merged = True
    while merged:
        merged = False
        j = 0
        while j < n - 1:
            # Rows in the same pool hold identical vectors; nothing to do.
            if q[j] == q[j + 1]:
                j += 1
                continue
            # Reconstruct only the two columns being compared.
            M_j = W.dot(V[j:j + 2].T)
            if np.any(M_j[:, 0] < M_j[:, 1]):
                # Merge the two adjacent pools via a size-weighted average.
                pool0 = q == q[j]
                pool1 = q == q[j + 1]
                w0 = pool0.sum()
                w1 = pool1.sum()
                V[pool0 | pool1] = (w0 * V[j] + w1 * V[j + 1]) / (w0 + w1)
                q[pool1] = q[j]
                merged = True
                # Jump to the last row of the merged pool so that the next
                # comparison is against the following pool.
                j += w1
            else:
                j += 1
    return V
if __name__ == '__main__':
    # Demo: project random factors and plot original vs. projected rows.
    import os

    import matplotlib.pyplot as plt

    nrows = 4
    ncols = 20
    nembed = 5
    W = np.random.gamma(1, 1, size=(nrows, nembed))
    V = (np.random.gamma(1, 1, size=(ncols, nembed)).cumsum(axis=1)[::-1]
         + np.random.gamma(0.1, 0.1, size=(ncols, nembed)))

    # Project to monotone curves
    V_proj = factor_pav(W, V)

    # squeeze=False keeps axarr 2-D, so indexing still works when nrows == 1.
    fig, axarr = plt.subplots(1, nrows, figsize=(nrows * 5, 5), squeeze=False)
    x = np.arange(ncols)
    M = W.dot(V.T)
    M_proj = W.dot(V_proj.T)
    for i, ax in enumerate(axarr[0]):
        ax.scatter(x, M[i], color='gray', alpha=0.5)
        ax.plot(x, M_proj[i], color='blue')

    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/factor-pav.pdf', bbox_inches='tight')
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment