using LinearAlgebra
using StatsBase
using StatsFuns
using NaNMath

"Map x -> 2x-1 in place; NaN (missing) entries become `val`"
function twoxm1!(dat; val=0.0)
    @inbounds for (i, x) in enumerate(dat)
        dat[i] = ifelse(isnan(x), val, 2x - 1)
    end
    dat
end

"Perform soft shrinkage (the proximal operator of the l1 penalty) in place"
function softshrink!(dat, val=0)
    val = abs(val)
    @inbounds for (i, x) in enumerate(dat)
        y = abs(x)
        dat[i] = ifelse(y < val, 0, x - copysign(val, x))
    end
    dat
end

"""
Normalize the product A*B' such that the norms of the columns of B are unity
"""
function LinearAlgebra.normalize!(A::Matrix{T}, B::Matrix{T}) where T
    # Qualified name: this adds a method to LinearAlgebra.normalize!, which
    # avoids the clash with the exported binding of the same name.
    W = sqrt.(sum(x->x^2, B, dims=1))
    A .*= W
    B ./= W
end

"Subtract column means"
center(dat) = dat .- mean(dat, dims=1)

function _initialize!(A0, B0, μ0, dat, k, randstart)
    q = twoxm1!(dat) # forces x to be equal to θ when data is missing
    n, d = size(dat)
    # Initialize #
    ##################
    if !randstart
        μ = mean(q, dims=1)
        μ = reshape(μ, d, 1)
        F = svd(center(q))
        A = F.U[1:n, 1:k]
        B = F.V[1:d, 1:k] * Diagonal(F.S[1:k])
    else
        μ = randn(d, 1)
        A = 2rand(n, k) .- 1
        B = 2rand(d, k) .- 1
    end
    if A0 !== nothing
        W = sqrt.(sum(x->x^2, A0, dims=1))
        if B0 !== nothing
            B = B0 .* W
        end
        A = A0 ./ W
    end
    if μ0 !== nothing
        μ = μ0
    end
    A, B, μ, n, d, q
end

function computeX!(q, μ, A, B, X)
    # In-place version of X = 4q.*(1 .- logistic.(q.*(ones(n)*μ' + A*B')))
    mul!(X, A, B')
    n, m = size(X)
    @inbounds for j = 1:m, i = 1:n
        X[i, j] = 4q[i, j]*(1 - logistic(q[i, j]*(μ[j] + X[i, j])))
    end
    X
end

function compute_loglike!(q, μ, A, B, X)
    # In-place version of NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A*B'))));
    # X is used as scratch space.
    mul!(X, A, B')
    n, m = size(X)
    loglike = 0.0
    @inbounds for j = 1:m, i = 1:n
        c = log(logistic(q[i, j]*(μ[j] + X[i, j])))
        loglike += ifelse(isnan(c), 0, c)
    end
    loglike
end
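# Quick sanity sketch of softshrink! (the example vector is made up): entries
# with |x| < 0.5 collapse to zero and the rest shrink toward zero by 0.5. This
# is the proximal step applied to the loadings B under the lasso penalty below.
let v = [-1.2, -0.3, 0.0, 0.4, 2.0]
    @assert softshrink!(copy(v), 0.5) ≈ [-0.7, 0.0, 0.0, 0.0, 1.5]
end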
"""
From Lee, Huang, Hu (2010).
Uses the uniform bound for the log likelihood: the curvature of θ -> log(logistic(qθ))
is bounded by 1/4, which gives a quadratic majorizer whose minimizer works with the
data X = θ + 4q.*(1 .- logistic.(q.*θ)); the updates below are closed-form steps
on that majorizer.
Can only use lasso=true if λ is the same for all dimensions, which is how this
algorithm is coded.
"""
function splogitpca(dat::Matrix; λ=0, k=2, verbose=false, maxiters::Int=100, convcrit=1e-5,
                    randstart=false, procrustes=true, lasso=true, normalize=false,
                    A0=nothing, B0=nothing, μ0=nothing, eps=1e-10)
    A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)
    loss_trace = zeros(maxiters)
    loglike = 0.0
    iters = 0
    μ_prev = μ
    A_prev = A
    B_prev = B
    X = zeros(size(A, 1), size(B, 1))
    for m = 1:maxiters
        μ_prev = μ
        A_prev = A
        B_prev = B
        # θ = ones(n)*μ' + A*B'
        # X = θ + 4*q.*(1.0 .- logistic.(q.*θ))
        # Xcross = X - A*B'
        # μ = 1/n*Xcross'ones(n)
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ' + A*B')))
        X = computeX!(q, μ, A, B, X)
        μ += vec(sum(X, dims=1))/n
        # θ = ones(n)*μ' + A*B'
        # X = θ + 4*q.*(1.0 .- logistic.(q.*θ))
        # Xstar = X - ones(n)*μ'
        # if procrustes
        #     M = svd(Xstar*B)
        #     A = M.U*M.Vt
        # else
        #     A = Matrix(qr(Xstar*pinv(B)').Q)
        # end
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ' + A*B')))
        X = computeX!(q, μ, A, B, X)
        if procrustes
            M = svd(A*B'B + X*B)
            A[:] = M.U * M.Vt
        else
            A[:] = Matrix(qr(A + X/B').Q)
        end
        # θ = ones(n)*μ' + A*B'
        # X = θ + 4*q.*(1.0 .- logistic.(q.*θ))
        # Xstar = X - ones(n)*μ'
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ' + A*B')))
        X = computeX!(q, μ, A, B, X)
        # Xstar = X + A*B'
        # C = Xstar'A
        C = X'A + B*A'A
        if lasso
            B = softshrink!(C, 4*n*λ)
            # B = sign.(B_lse).*rectify.(abs.(B_lse) .- 4*n*λ)
        else
            B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
            # B = abs.(B)/(abs.(B) .+ 4*n*λ).*C
        end
        loglike = compute_loglike!(q, μ, A, B, X)
        # loglike = NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A*B'))))
        penalty = n*λ*sum(abs, B)
        loss_trace[m] = (-loglike + penalty)/count(x->!isnan(x), dat)
        iters = m
        if verbose
            println(m, " ", -loglike, " ", penalty, " ", -loglike + penalty,
                    " ", loss_trace[m], " ", convcrit)
        end
        # Converged?
        if (m > 4) && (loss_trace[m-1] - loss_trace[m]) < convcrit
            break
        end
    end
    # if iters > 1 && loss_trace[iters-1] < loss_trace[iters] # This iteration doesn't count
    #     μ = μ_prev
    #     A = A_prev
    #     B = B_prev
    #     loglike = NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A*B'))))
    #     iters -= 1
    # end
    if normalize
        normalize!(A, B)
    end
    nzeros = count(x->abs(x) < eps, B)
    # BIC = -2*loglike + log(n)*(d + n*k + count(x->abs(x) >= eps, B))
    BIC = -2*loglike + log(n)*(d + n*k + d*k - nzeros)
    return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end
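# A minimal usage sketch (random binary data made up for illustration, not from
# the reference): NaN marks a missing cell, λ sets the lasso penalty on the
# loadings B, and k is the target rank.
let Xdemo = Float64.(rand(0:1, 50, 8))
    Xdemo[3, 2] = NaN
    μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(Xdemo; λ=0.01, k=2)
    @assert size(A) == (50, 2) && size(B) == (8, 2)
end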
using IterativeSolvers

function _initializel!(A0, B0, μ0, dat, k, randstart)
    q = twoxm1!(dat) # forces x to be equal to θ when data is missing
    n, d = size(dat)
    # Initialize #
    ##################
    if !randstart
        μ = mean(q, dims=1)
        μ = reshape(μ, d, 1)
        F, _ = svdl(center(q), nsv=k, vecs=:both)
        A = F.U[1:n, 1:k]
        B = F.V[1:d, 1:k] * Diagonal(F.S[1:k])
    else
        μ = randn(d, 1)
        A = 2rand(n, k) .- 1
        B = 2rand(d, k) .- 1
    end
    if A0 !== nothing
        W = sqrt.(sum(x->x^2, A0, dims=1))
        if B0 !== nothing
            B = B0 .* W
        end
        A = A0 ./ W
    end
    if μ0 !== nothing
        μ = μ0
    end
    A, B, μ, n, d, q
end

"""
My version using iterative SVD (L for Lanczos!)
"""
function splogitpcal(dat::Matrix; λ=0, k=2, verbose=false, maxiters::Int=100, convcrit=1e-5,
                     randstart=false, procrustes=true, lasso=true, normalize=false,
                     A0=nothing, B0=nothing, μ0=nothing, eps=1e-10)
    A, B, μ, n, d, q = _initializel!(A0, B0, μ0, dat, k, randstart)
    loss_trace = zeros(maxiters)
    loglike = 0.0
    iters = 0
    μ_prev = μ
    A_prev = A
    B_prev = B
    X = zeros(size(A, 1), size(B, 1))
    for m = 1:maxiters
        μ_prev = μ
        A_prev = A
        B_prev = B
        X = computeX!(q, μ, A, B, X)
        μ += vec(sum(X, dims=1))/n
        X = computeX!(q, μ, A, B, X)
        if procrustes
            M = svd(A*B'B + X*B)
            A[:] = M.U * M.Vt
        else
            A[:] = Matrix(qr(A + X/B').Q)
        end
        X = computeX!(q, μ, A, B, X)
        # Xstar = X + A*B'
        # C = Xstar'A
        C = X'A + B*A'A
        if lasso
            B = softshrink!(C, 4*n*λ)
        else
            B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
        end
        loglike = compute_loglike!(q, μ, A, B, X)
        penalty = n*λ*sum(abs, B)
        loss_trace[m] = (-loglike + penalty)/count(x->!isnan(x), dat)
        iters = m
        if verbose
            println(m, " ", -loglike, " ", penalty, " ", -loglike + penalty,
                    " ", loss_trace[m], " ", convcrit)
        end
        # Converged?
        if (m > 4) && (loss_trace[m-1] - loss_trace[m]) < convcrit
            break
        end
    end
    # if iters > 1 && loss_trace[iters-1] < loss_trace[iters] # This iteration doesn't count
    #     μ = μ_prev
    #     A = A_prev
    #     B = B_prev
    #     loglike = NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A*B'))))
    #     iters -= 1
    # end
    if normalize
        normalize!(A, B)
    end
    nzeros = count(x->abs(x) < eps, B)
    # BIC = -2*loglike + log(n)*(d + n*k + count(x->abs(x) >= eps, B))
    BIC = -2*loglike + log(n)*(d + n*k + d*k - nzeros)
    return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end
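# Sketch of the intended equivalence, commented out because it assumes
# IterativeSolvers.svdl returns an SVD-like object with U/S/V fields when
# vecs=:both, as _initializel! does: the leading k singular values from the
# Lanczos routine should match those of the dense svd used by _initialize!.
# M = randn(200, 30)
# F, _ = svdl(M, nsv=2, vecs=:both)
# @assert F.S[1:2] ≈ svd(M).S[1:2]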
function splogitpcacoords(dat; λs=exp10.(range(-2, stop=2, length=10)), k=2, verbose=false,
                          maxiters=100, convcrit=1e-5, randstart=false, normalize=false,
                          A0=nothing, B0=nothing, μ0=nothing, eps=1e-10)
    # From Lee, Huang (2013)
    # Uses the uniform bound for the log likelihood
    # Initialize #
    ##################
    A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)
    # Rows index the λ grid (row i corresponds to λs[i]), columns the components
    BICs = fill(NaN, length(λs), k)
    zeros_mat = fill(NaN, length(λs), k)
    iters_mat = fill(NaN, length(λs), k)
    θ = ones(n)*μ' + A*B'
    X = θ + 4*q.*(1.0 .- logistic.(q.*θ))
    Xcross = X - A*B'
    μ = 1/n*Xcross'ones(n)
    loglike = 0.0
    for m in 1:k
        A_prev = A
        B_prev = B
        θ = ones(n)*μ' + A*B'
        X = θ + 4*q.*(1.0 .- logistic.(q.*θ))
        others = setdiff(1:k, m)  # every component except m
        Xm = X - (ones(n)*μ' + A[:, others]*B[:, others]')
        Bms = fill(NaN, d, length(λs))
        Ams = fill(NaN, n, length(λs))
        for (l, λ) in enumerate(λs)
            prev_loss = Inf
            i = 0
            for outer i in 1:maxiters
                if sum(x->x^2, B[:, m]) == 0
                    A[:, m] = Xm * B[:, m]
                    break
                end
                A[:, m] = Xm * B[:, m]/sum(x->x^2, B[:, m])
                A[:, m] = A[:, m]/sqrt(sum(x->x^2, A[:, m]))
                B_lse = Xm'A[:, m]
                B[:, m] = sign.(B_lse) .* max.(abs.(B_lse) .- λ, 0)  # soft threshold
                loglike = NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A*B'))))
                penalty = 0.25*λ*sum(abs, B[:, m])
                loss = (-loglike + penalty)/count(x->!isnan(x), dat)
                if verbose
                    println(m, " ", -loglike, " ", penalty, " ", -loglike + penalty)
                end
                # Converged?
                if (i > 4) && (prev_loss - loss)/prev_loss < convcrit
                    break
                end
                prev_loss = loss
            end
            nrm2 = sum(abs2, B[:, m])
            Bms[:, l] = B[:, m]/ifelse(nrm2 == 0, 1, sqrt(nrm2))
            Ams[:, l] = Xm * Bms[:, l]/ifelse(nrm2 == 0, 1, sum(abs2, Bms[:, l]))
            BICs[l, m] = -2*loglike + log(n*d)*count(x->abs(x) >= eps, B)
            zeros_mat[l, m] = count(x->abs(x) < eps, B[:, m])
            iters_mat[l, m] = i
        end
        best = argmin(BICs[:, m])
        B[:, m] = Bms[:, best]
        A[:, m] = Ams[:, best]
    end
    if normalize
        normalize!(A, B)
    end
    nzeros = count(x->abs(x) < eps, B)
    BIC = -2*loglike + log(n*d)*count(x->abs(x) >= eps, B)
    return μ, A, B, nzeros, zeros_mat, BICs, BIC, λs, iters_mat
end
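# A hedged usage sketch for the coordinate-wise version (commented out: unlike
# splogitpca above, this path has no reference output to check against; the
# data and the λ grid are made up).
# Xc = Float64.(rand(0:1, 50, 8))
# μ, A, B, nzeros, zeros_mat, BICs, BIC, λs, iters =
#     splogitpcacoords(Xc; k=2, λs=exp10.(range(-2, stop=1, length=5)))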
# Simple tests
let
    n, d, k = 10, 4, 2
    # A0 = rand(n, k)
    # B0 = rand(d, k)
    A0 = [0.004757656 0.5484724
          0.097842901 0.7151051
          0.847809061 0.9244518
          0.587115318 0.1226018
          0.645016418 0.9158401
          0.512536788 0.7956434
          0.630291718 0.7374858
          0.744668306 0.9878830
          0.824168577 0.9300954
          0.500897135 0.2217333]
    B0 = [0.65122711 0.5302900
          0.12274696 0.2805810
          0.09102886 0.1355389
          0.91832927 0.2721085]
    X = A0*B0'
    let
        A, B, μ, n, d, q = _initialize!(nothing, nothing, nothing, copy(X), k, false)
        A1 = [-0.5388835  0.333868208
              -0.3754175  0.428292562
               0.3857167 -0.028172460
              -0.2365052 -0.652655997
               0.2034485  0.146812728
               0.0281748  0.138057987
               0.1030301 -0.031392469
               0.3262791  0.133665953
               0.3677298 -0.000633053
              -0.2635728 -0.467843459]
        B1 = [1.7012514  0.2104329
              0.5569920  0.2680516
              0.3181553  0.1070674
              1.7766535 -0.3047110]
        @assert norm(A - A1) < 1e-6
        @assert norm(B - B1) < 1e-6
    end
    let
        μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X))
        μ1 = [0.11143104907230131, -6.564739360771476, -5.150479248661679, 3.1076465866696625]
        A1 = [-0.4711911569853924 -0.1489574609442475
              -0.6112679385689244 -0.21907436712904096
              0.17677392090006658 -0.32920587480319885
              -0.12646233029137596 -0.5057072366168655
              0.21416589014147108 -0.29440618782253014
              0.29999916840057045 -0.2510075989146518
              0.25869949074847665 -0.23679822898302308
              0.18476194498116696 -0.33697878594619274
              0.17935771304814505 -0.3272536782315915
              -0.29663199282213143 -0.3732276622116619]
        B1 = [29.334390862738818 3.0935769680639233
              -3.1451796495190014 17.718901626804065
              -0.8610989625744958 11.142872726019872
              20.653625916707483 -11.121500611519389]
        nzeros1 = 0
        BIC1 = 78.09400917615692
        iters1 = 100
        losses1 = [0.3192568800107156, 0.24977073285809892, 0.21330490458541732, 0.1899828633617256,
                   0.1735122256779609, 0.16114605957633704, 0.15146109632832494, 0.14363527557579114,
                   0.1371565015500379, 0.13168741770455047, 0.12699595883671305, 0.12291687282860156,
                   0.11932909880827051, 0.11614181739372556, 0.11328550369088539, 0.11070599638614712,
                   0.10836045395242884, 0.10621452962566721, 0.10424035540856966, 0.1024150761609641,
                   0.10071976572936119, 0.09913861346320799, 0.09765830536474736, 0.0962675475009829,
                   0.09495669485112829, 0.09371745929147622, 0.09254267767201296, 0.09142612601238616,
                   0.09036236944427692, 0.08934664011448008, 0.08837473714519413, 0.08744294413267219,
                   0.08654796069463733, 0.08568684534937958, 0.08485696759454661, 0.084055967500555,
                   0.08328172147765685, 0.08253231314263992, 0.08180600841968862, 0.08110123417396806,
                   0.08041655980631743, 0.07975068134082348, 0.07910240761983256, 0.07847064828761784,
                   0.07785440329787183, 0.07725275372407196, 0.07666485368762423, 0.07608992324812383,
                   0.07552724212433402, 0.07497614413457007, 0.07443601226186682, 0.07390627426323096,
                   0.07338639875394007, 0.07287589170764217, 0.0723742933212689, 0.07188117520075957,
                   0.07139613782951738, 0.07091880828656347, 0.07044883818565822, 0.06998590181034249,
                   0.06952969442301517, 0.06907993072888352, 0.06863634347797201, 0.06819868219040548,
                   0.0677667119919442, 0.06734021254827935, 0.0669189770879308, 0.06650281150475232,
                   0.06609153353207073, 0.0656849719813751, 0.06528296603925979, 0.06488536461701414,
                   0.06449202574786106, 0.06410281602738557, 0.0637176100931686, 0.06333629014006628,
                   0.0629587454679458, 0.0625848720590257, 0.06221457218226193, 0.06184775402248566,
                   0.061484331332236145, 0.06112422310443886, 0.060767353264271115, 0.06041365037872024,
                   0.060063047382495384, 0.059715481319083705, 0.05937089309586481, 0.059029227252303126,
                   0.05869043174033748, 0.058354457716168026, 0.05802125934272438, 0.057690793602161984,
                   0.0573630201177989, 0.05703790098495918, 0.05671540061023694, 0.05639548555874026,
                   0.05607812440891309, 0.05576328761456515, 0.05545094737377308, 0.055141077504343126]
        λ1 = 0
        @assert norm(μ - μ1) < 1e-10
        @assert norm(A - A1) < 1e-10
        @assert norm(B - B1) < 1e-10
        @assert nzeros == nzeros1
        @assert norm(BIC - BIC1) < 1e-10
        @assert iters == iters1
        @assert norm(losses - losses1) < 1e-10
        @assert norm(λ - λ1) < 1e-10
    end
    let
        μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X), procrustes=false)
        μ1 = [0.15487984167564228, -6.568819637068933, -5.153303852683945, 3.1354113472419947]
        A1 = [-0.4713085616574335 -0.1486499336486215
              -0.6116225488233055 -0.21892149365503077
              0.17668326805953308 -0.3291070254305201
              -0.12676563488667017 -0.5058256360958331
              0.2136841499552063 -0.2947788948414918
              0.29919833886531016 -0.25214788479509065
              0.2579814533003914 -0.2376067790175394
              0.18463978727899122 -0.3369513463746447
              0.17923695827317881 -0.3271853654971497
              -0.2975682504108814 -0.37187223776149625]
        B1 = [29.36371490195748 3.3473468390064522
              -3.3541459487802157 17.669448756051334
              -0.9910645451446817 11.124796865100537
              20.688053338772107 -10.876220678727316]
        nzeros1 = 0
        BIC1 = 78.09522219179412
        iters1 = 100
        losses1 = [0.31927078677106835, 0.24979181872115347, 0.2133183230695471, 0.18999383437616968,
                   0.1735228915110179, 0.16115692983943103, 0.1514721476812388, 0.14364636426924043,
                   0.13716749159847258, 0.1316982099026348, 0.12700648917113633, 0.12292710375757314,
                   0.11933901071427137, 0.11615140198242953, 0.11329475936292606, 0.11071492516618249,
                   0.10836905953592055, 0.10622281619007441, 0.10424832693190614, 0.10242273606938757,
                   0.1007271167532976, 0.09914565762194505, 0.09766504402696699, 0.09627398148718055,
                   0.09496282455615426, 0.09372328481354868, 0.09254819893988216, 0.09143134290580095,
                   0.09036728190515164, 0.08935124824758726, 0.08837904130810297, 0.08744694501519108,
                   0.08655165938832812, 0.08569024340786899, 0.08486006708542487, 0.08405877104965467,
                   0.08328423230647461, 0.08253453509960995, 0.08180794600595748, 0.08110289256425271,
                   0.08041794486536624, 0.07975179963593723, 0.0791032664298458, 0.07847125560869023,
                   0.07785476784639267, 0.0772528849369398, 0.07666476172012959, 0.07608961896963139,
                   0.07552673711193403, 0.07497545066484831, 0.07443514330092219, 0.07390524345505597,
                   0.07338522040726606, 0.07287458078134124, 0.07237286540839802, 0.07187964651132406,
                   0.0713945251720289, 0.07091712904846326, 0.07044711031267445, 0.06998414378485004,
                   0.06952792524146519, 0.06907816987837065, 0.06863461091200698, 0.06819699830396261,
                   0.0677650975958532, 0.06733868884303228, 0.06691756563697579, 0.06650153420734686,
                   0.06609041259576673, 0.06568402989420989, 0.06528222554172623, 0.06488484867388328,
                   0.06449175751993028, 0.06410281884322391, 0.06371790742093102, 0.06333690555944513,
                   0.06295970264232852, 0.06258619470792325, 0.06221628405407158, 0.06184987886764911,
                   0.061486892876849705, 0.061127245024371245, 0.060770859159838375, 0.060417663749967876,
                   0.06006759160512929, 0.059720579621091584, 0.05937656853486391, 0.059035502693647504,
                   0.05869732983601197, 0.05836200088449596, 0.05802946974890764, 0.05769969313967223,
                   0.057372630390632874, 0.057048243290767894, 0.056726495924335395, 0.05640735451900135,
                   0.05609078730154475, 0.05577676436076882, 0.055465257517278245, 0.05515624019980808]
        λ1 = 0
        @assert norm(μ - μ1) < 1e-10
        @assert norm(A - A1) < 1e-10
        @assert norm(B - B1) < 1e-10
        @assert nzeros == nzeros1
        @assert norm(BIC - BIC1) < 1e-10
        @assert iters == iters1
        @assert norm(losses - losses1) < 1e-10
        @assert norm(λ - λ1) < 1e-10
    end
    let
        A1 = copy(A0)
        B1 = copy(B0)
        normalize!(A1, B1)
        @assert norm(A0*B0' - A1*B1') < 1e-10
        @assert norm(norm(B1[:, 1]) - 1) < 1e-10
        @assert norm(norm(B1[:, 2]) - 1) < 1e-10
    end
end
# using Profile
# Profile.clear()
# X = copy(X0)
# @profile q = splogitpca(X) # 1.16 seconds
# Profile.print()
# # R: 7.65368 secs

n, m = 2000, 1600
X0 = rand(n, m)  # random values in (0,1), just for timing; real input should be 0/1 with NaN for missing
X = copy(X0)
@time q = splogitpca(X)
X = copy(X0)
@time q = splogitpcal(X)
# X = copy(X0)
# @time q = splogitpca(X, procrustes=false)
# X = copy(X0)
# @time q = splogitpca(X, procrustes=false, lasso=false)
# X = copy(X0)
# @time q = splogitpca(X, lasso=false) # 2.6s