using LinearAlgebra
using StatsBase
using StatsFuns
using NaNMath

"Map `x -> 2x - 1` elementwise in place, replacing NaN entries with `val` (default 0.0). Returns `dat`."
function twoxm1!(dat; val=0.0)
    @inbounds for idx in eachindex(dat)
        v = dat[idx]
        dat[idx] = isnan(v) ? val : 2v - 1
    end
    return dat
end

"""
    softshrink!(dat, val=0)

Apply the soft-thresholding operator elementwise in place:
each element `x` becomes `sign(x) * max(abs(x) - abs(val), 0)`.
Returns the mutated `dat`.
"""
function softshrink!(dat, val=0)
    val = abs(val)  # threshold is used by magnitude only
    @inbounds for i in eachindex(dat)
        x = dat[i]
        # zero(x) (instead of literal 0) keeps the assignment type-stable
        # for any element type of `dat`
        dat[i] = abs(x) < val ? zero(x) : x - copysign(val, x)
    end
    return dat
end

"""
Normalize the product A*B such that the norms of the columns of B are unity

The rescaling preserves `A*B'`: each column of `B` is divided by its norm
and the corresponding column of `A` is multiplied by it.

NOTE(review): this defines a 2-argument `normalize!` while `using LinearAlgebra`
also exports a `normalize!`; depending on the Julia version this may require an
explicit `import` (or a different name) to avoid a method-definition clash — confirm.
NOTE(review): a zero-norm column of `B` makes `W` zero and the division below
produces NaN/Inf — assumes callers never pass such a `B`; verify.
"""
function normalize!(A::Matrix{T}, B::Matrix{T}) where T
    # W[1,j] = Euclidean norm of column j of B
    W = sqrt.(sum(x->x^2, B, dims=1))
    A .*= W
    B ./= W
end

"Return a copy of `dat` with every column shifted to have zero mean."
function center(dat)
    return dat .- mean(dat, dims=1)
end

"""
    _initialize!(A0, B0, μ0, dat, k, randstart)

Prepare the starting point for the sparse logistic PCA iterations.
Mutates `dat` in place (recoded to ±1, NaNs → 0) and returns
`(A, B, μ, n, d, q)` where `q` aliases the recoded `dat`.
User-supplied warm starts `A0`/`B0`/`μ0` (if not `nothing`) override the
computed initialization.
"""
function _initialize!(A0, B0, μ0, dat, k, randstart)
    q = twoxm1!(dat) # forces x to be equal to θ when data is missing
    n, d = size(dat)

    if randstart
        # Random start: gaussian intercept, factors uniform in (-1, 1)
        μ = randn(d, 1)
        A = 2rand(n, k) .- 1
        B = 2rand(d, k) .- 1
    else
        # Deterministic start: column means plus a rank-k SVD of centered q
        μ = reshape(mean(q, dims=1), d, 1)
        F = svd(center(q))
        A = F.U[1:n, 1:k]
        B = F.V[1:d, 1:k] * Diagonal(F.S[1:k])
    end

    if A0 !== nothing
        # Rescale A0's columns to unit norm; B absorbs the inverse scaling
        # so the product A*B' is preserved
        W = sqrt.(sum(abs2, A0, dims=1))
        if B0 !== nothing
            B = B0 .* W
        end
        A = A0 ./ W
    end
    if μ0 !== nothing
        μ = μ0
    end
    return A, B, μ, n, d, q
end

"""
    computeX!(q, μ, A, B, X)

Overwrite `X` with the MM working response
`X[i,j] = 4 q[i,j] * (1 - logistic(q[i,j] * (μ[j] + (A*B')[i,j])))`,
reusing `X` itself as the buffer for `A*B'`. Returns `X`.
"""
function computeX!(q, μ, A, B, X)
    mul!(X, A, B')              # X ← A*B' (the linear predictor minus intercept)
    nrow, ncol = size(X)
    @inbounds for j in 1:ncol
        for i in 1:nrow         # inner loop over rows: column-major friendly
            θ = μ[j] + X[i, j]
            X[i, j] = 4q[i, j] * (1 - logistic(q[i, j] * θ))
        end
    end
    return X
end


"""
    compute_loglike!(q, μ, A, B, X)

Compute the Bernoulli log-likelihood
`Σᵢⱼ log(logistic(q[i,j] * (μ[j] + (A*B')[i,j])))`, treating NaN terms
(which arise from missing data) as zero contributions.

`X` is clobbered: it is used as the buffer for `A*B'`.
"""
function compute_loglike!(q, μ, A, B, X)
    mul!(X, A, B')
    n, m = size(X)
    loglike = 0.0
    @inbounds for j = 1:m, i = 1:n
        c = log(logistic(q[i,j]*(μ[j] + X[i,j])))
        # 0.0 (not the Int literal 0) keeps the ifelse/accumulation type-stable
        loglike += ifelse(isnan(c), 0.0, c)
    end
    return loglike
end

"""
    splogitpca(dat::Matrix; λ=0, k=2, ...)

Sparse logistic PCA via MM (majorize–minimize) updates.

From Lee, Huang, Hu (2010)
Uses the uniform bound for the log likelihood
Can only use lasso=true if λ is the same for all dimensions,
    which is how this algorithm is coded

`dat` is modified in place (recoded to ±1 with NaNs replaced by 0, via
`_initialize!`). Returns `(μ, A, B, nzeros, BIC, iters, loss_trace, λ)`.
"""
function splogitpca(dat::Matrix; λ=0,k=2,verbose=false,maxiters::Int=100,convcrit=1e-5,
                                randstart=false,procrustes=true,lasso=true,normalize=false,
                                A0=nothing, B0=nothing, μ0=nothing, eps = 1e-10)

    A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)

    loss_trace=zeros(maxiters)
    loglike = 0.0
    iters = 0
    # _prev copies are kept for the (currently commented-out) backtracking step below
    μ_prev=μ
    A_prev=A
    B_prev=B
    # Shared n×d work buffer, overwritten by computeX!/compute_loglike! each call
    X = zeros(size(A,1), size(B,1))
    for m = 1:maxiters
        μ_prev=μ
        A_prev=A
        B_prev=B

        # θ=ones(n)*μ'+A * B'
        # X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
        # Xcross=X-A * B'
        # μ=(1/n*Xcross'ones(n))

        # Intercept update: add the column means of the working response
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
        X = computeX!(q, μ, A, B, X)
        μ += vec(sum(X, dims=1))/n

        # θ=ones(n)*μ'+A * B'
        # X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
        # Xstar=X-ones(n)*μ'
        # if (procrustes)
        #     M=svd(Xstar * B)
        #     A=M.U * M.Vt
        # else 
        #     A = Matrix(qr(Xstar * pinv(B)').Q)
        # end
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
        # Score (A) update — Procrustes (orthogonal polar factor) or QR variant
        X = computeX!(q, μ, A, B, X)
        if (procrustes)
            M = svd(A*B'B + X*B)
            A[:] = M.U * M.Vt
        else
            A[:] = Matrix(qr(A + X/B').Q)
        end
        
        # θ=(ones(n)*μ') + A * B'
        # X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
        # Xstar= X - ones(n)*μ'
        # X = 4q.*(1 .- logistic.(q.*(ones(n)*μ'+A*B')))
        # Loading (B) update target C, then lasso soft-thresholding or the
        # multiplicative shrinkage variant
        X = computeX!(q, μ, A, B, X)
        # Xstar = X + A*B'
        # C = Xstar'A
        C = X'A + B*A'A
        if lasso
            B = softshrink!(C, 4*n*λ)
#           B = sign.(B_lse).*rectify.(abs.(B_lse).-4*n*λ)
        else
            B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
            # B=abs.(B)/(abs.(B).+4*n*λ).*C
        end
        
        loglike=compute_loglike!(q, μ, A, B, X)
        # loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))

        # Penalized loss averaged over the observed (non-NaN) entries
        penalty=n*λ*sum(abs, B)
        loss_trace[m]=(-loglike+penalty)/count(x->!isnan(x), dat)

        iters = m
        if verbose
            println(m,"  ",(-loglike),"   ",(penalty),"     ",-loglike+penalty, "     ", loss_trace[m], "      ", convcrit)
        end

        #Converged? (only checked after a short burn-in of 4 iterations)
        if (m>4) && (loss_trace[m-1]-loss_trace[m])<convcrit
            break
        end
    end

    # if iters > 1 && loss_trace[iters-1]<loss_trace[iters] #This iteration doesn't count
    #     μ=μ_prev
    #     A=A_prev
    #     B=B_prev
    #     loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))
    #     iters -= 1
    # end

    if (normalize) 
        normalize!(A, B)
    end
    
    # BIC counts intercepts (d), scores (n*k) and nonzero loadings (d*k - nzeros)
    nzeros=count(x->abs(x)<eps, B)
    # BIC=-2*loglike+log(n)*(d+n*k+count(x->abs(x)>=eps, B))
    BIC=-2*loglike+log(n)*(d+n*k+d*k-nzeros)
    return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end

using IterativeSolvers

"""
    _initializel!(A0, B0, μ0, dat, k, randstart)

Same as `_initialize!`, except the rank-`k` starting factorization comes from
the iterative (Lanczos) SVD `svdl` of IterativeSolvers instead of a dense
`svd`. Mutates `dat` in place (recoded to ±1, NaNs → 0).
"""
function _initializel!(A0, B0, μ0, dat, k, randstart)
    q = twoxm1!(dat) # forces x to be equal to θ when data is missing
    n, d = size(dat)

    # Initialize #
    ##################
    if (!randstart)
        μ = mean(q, dims=1)
        μ = reshape(μ,d,1)
        # NOTE(review): svdl is assumed to return (factorization, history) with
        # F carrying U, S, V when vecs=:both — confirm against the installed
        # IterativeSolvers version.
        F, _ = svdl(center(q), nsv=k, vecs=:both)
        A=F.U[1:n,1:k]
        B=F.V[1:d,1:k] * Diagonal(F.S[1:k])
    else
        # Random start: gaussian intercept, factors uniform in (-1, 1)
        μ=randn(d,1)
        A=2rand(n,k).-1
        B=2rand(d,k).-1
    end
    if A0 !== nothing
        # Warm start: rescale A0 to unit-norm columns, B absorbs the scaling
        W = sqrt.(sum(x->x^2, A0, dims=1))
        if B0 !== nothing
            B = B0 .* W
        end
        A = A0 ./ W
    end
    if μ0 !== nothing
        μ = μ0
    end
    A, B, μ, n, d, q
end

"""
    splogitpcal(dat::Matrix; λ=0, k=2, ...)

My version using Iterative SVD (L for Lanczos!)

Identical to `splogitpca` except the starting factorization is computed with
the iterative Lanczos SVD (`_initializel!`). `dat` is modified in place.
Returns `(μ, A, B, nzeros, BIC, iters, loss_trace, λ)`.
"""
function splogitpcal(dat::Matrix; λ=0,k=2,verbose=false,maxiters::Int=100,convcrit=1e-5,
    randstart=false,procrustes=true,lasso=true,normalize=false,
    A0=nothing, B0=nothing, μ0=nothing, eps = 1e-10)

    # BUGFIX(review): previously called _initialize! (dense SVD), which made
    # this variant identical to splogitpca and left _initializel! unused.
    A, B, μ, n, d, q = _initializel!(A0, B0, μ0, dat, k, randstart)

    loss_trace = zeros(maxiters)
    loglike = 0.0
    iters = 0
    # _prev copies retained to mirror splogitpca (used by its commented-out backtracking)
    μ_prev = μ
    A_prev = A
    B_prev = B
    # Shared n×d work buffer, overwritten by computeX!/compute_loglike!
    X = zeros(size(A,1), size(B,1))
    for m = 1:maxiters
        μ_prev = μ
        A_prev = A
        B_prev = B

        # Intercept update: add the column means of the working response
        X = computeX!(q, μ, A, B, X)
        μ += vec(sum(X, dims=1))/n

        # Score (A) update — Procrustes (orthogonal polar factor) or QR variant
        X = computeX!(q, μ, A, B, X)
        if (procrustes)
            M = svd(A*B'B + X*B)
            A[:] = M.U * M.Vt
        else
            # BUGFIX(review): this QR step was accidentally duplicated, which
            # applied the update twice per iteration (unlike splogitpca)
            A[:] = Matrix(qr(A + X/B').Q)
        end

        # Loading (B) update: lasso soft-thresholding or multiplicative shrinkage
        X = computeX!(q, μ, A, B, X)
        C = X'A + B*A'A
        if lasso
            B = softshrink!(C, 4*n*λ)
        else
            B[:] = C ./ (1 .+ 4*n*λ*inv.(abs.(B)))
        end

        loglike = compute_loglike!(q, μ, A, B, X)

        # Penalized loss averaged over the observed (non-NaN) entries
        penalty = n*λ*sum(abs, B)
        loss_trace[m] = (-loglike+penalty)/count(x->!isnan(x), dat)

        iters = m
        if verbose
            println(m,"  ",(-loglike),"   ",(penalty),"     ",-loglike+penalty, "     ", loss_trace[m], "      ", convcrit)
        end

        #Converged? (only checked after a short burn-in of 4 iterations)
        if (m>4) && (loss_trace[m-1]-loss_trace[m])<convcrit
            break
        end
    end

    if (normalize)
        normalize!(A, B)
    end

    # BIC counts intercepts (d), scores (n*k) and nonzero loadings (d*k - nzeros)
    nzeros = count(x->abs(x)<eps, B)
    BIC = -2*loglike+log(n)*(d+n*k+d*k-nzeros)
    return μ, A, B, nzeros, BIC, iters, loss_trace[1:iters], λ
end


"""
    splogitpcacoords(dat; λs, k=2, ...)

Coordinate-wise sparse logistic PCA over a grid of penalties `λs`.

WARNING(review): this function appears to be a partially translated R
implementation and will not run as written — see the FIXME notes inline.
"""
function splogitpcacoords(dat; λs=exp10.(range(-2,stop=2,length=10)),k=2,verbose=false,maxiters=100,convcrit=1e-5,
                                randstart=false,normalize=false,
                                A0=nothing,B0=nothing,μ0=nothing,eps = 1e-10) 
  # From Lee, Huang (2013)
  # Uses the uniform bound for the log likelihood

  # Initialize #
  ##################
  A, B, μ, n, d, q = _initialize!(A0, B0, μ0, dat, k, randstart)
  
  # FIXME(review): `dimnames`, `list`, `paste0` and 2-arg `round(x, 2)` are R,
  # not Julia — these three lines will not run
  BICs=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
  zeros_mat=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
  iters=fill(NaN,length(λs),k,dimnames=list(paste0("10^",round(log10(λs),2)),1:k))
  
  θ=ones(n)*μ'+A * B'
  X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
  Xcross=X-A * B'
  μ=(1/n*(Xcross)' * ones(n))
  loglike = 0.0
  iters = 0  # FIXME(review): shadows the `iters` matrix allocated above
  for m in 1:k
    A_prev=A
    B_prev=B
    
    θ=ones(n)*μ'+A * B'
    X=(θ+4*q.*(1.0.-logistic.(q.*θ)))
    # FIXME(review): `A[:,-m]` is R's drop-column indexing; in Julia this needs
    # something like `A[:, setdiff(1:k, m)]`
    Xm=X-(ones(n)*μ')+A[:,-m] * B[:,-m]'
    
    Bms=fill(NaN,d,length(λs))
    Ams=fill(NaN,n,length(λs))
    for λ in λs 
      for i in 1:maxiters 
        if (sum(x->x^2, B[:,m])==0) 
          A[:,m]=Xm * B[:,m]
          break
        end
        A[:,m]=Xm * B[:,m]/sum(x->x^2, B[:,m])
        A[:,m]=A[:,m]/sqrt(sum(x->x^2, A[:,m]))
        
        B_lse=Xm'*A[:,m]
        # FIXME(review): `rectify` is not defined anywhere in this file
        B[:,m]=sign.(B_lse).*rectify.(abs.(B_lse).-λ)
        
        loglike=NaNMath.sum(log.(logistic.(q.*(ones(n)*μ' + A * B'))))
        penalty=0.25*λ*sum(abs, B[:,m])
        loss=(-loglike+penalty)/count(x->!isnan(x), dat)
        iters = m
        
        if verbose
            println(m,"  ",(-loglike),"   ",(penalty),"     ",-loglike+penalty)
        end
    
        # FIXME(review): `prev_loss` is read before it is ever assigned on the
        # first pass through this loop
        #Converged?
        if (i>4) && (prev_loss-loss)/prev_loss<convcrit
            break
        end
        
        prev_loss=loss
      end
      # FIXME(review): `λ==λs` is a scalar/vector comparison, not a mask, and
      # `B[:,m]^2` needs broadcasting (`.^`) — these index expressions are R-isms
      Bms[:,λ==λs]=B[:,m]/ifelse(sum(B[:,m]^2)==0,1,sqrt(sum(B[:,m]^2)))
      Ams[:,λ==λs]=Xm * Bms[:,λ==λs]/ifelse(sum(Bms[:,λ==λs]^2)==0,1,sum(Bms[:,λ==λs]^2))
      
      BICs[λ==λs,m]=-2*loglike+log(n*d)*(sum(abs.(B).>=eps))
      zeros_mat[λ==λs,m]=sum(abs.(B[:,m]).<eps)
      # FIXME(review): `i` is out of scope here (it was the inner-loop variable)
      iters[λ==λs,m]=i
    end
    # FIXME(review): `which.min` is R; Julia would use `argmin`
    B[:,m]=Bms[:,which.min(BICs[:,m])]
    A[:,m]=Ams[:,which.min(BICs[:,m])]
  end
  
  if (normalize)
        normalize!(A, B)
  end
  
  nzeros=sum(abs.(B).<eps)
  BIC=-2*loglike+log(n*d)*(sum(abs.(B).>=eps))
  return μ, A, B, nzeros, zeros_mat, BICs, BIC,λs, iters
end

# Simple tests — fixed fixtures with hard-coded reference outputs
# (reference values presumably generated by the original R implementation — TODO confirm)
let 
    n, d, k = 10, 4, 2
    # A0 = rand(n, k)
    # B0 = rand(d, k)
    A0 = [0.004757656 0.5484724
    0.097842901 0.7151051
    0.847809061 0.9244518
    0.587115318 0.1226018
    0.645016418 0.9158401
    0.512536788 0.7956434
    0.630291718 0.7374858
    0.744668306 0.9878830
    0.824168577 0.9300954
    0.500897135 0.2217333]
    B0= [
    0.65122711 0.5302900
    0.12274696 0.2805810
    0.09102886 0.1355389
    0.91832927 0.2721085]

    X = A0*B0'

    # _initialize! with the deterministic (SVD) start against reference factors
    let
        A, B, μ, n, d, q = _initialize!(nothing, nothing, nothing, copy(X), k, false)
        
        A1 = [
            -0.5388835  0.333868208
            -0.3754175  0.428292562
            0.3857167 -0.028172460
            -0.2365052 -0.652655997
            0.2034485  0.146812728
            0.0281748  0.138057987
            0.1030301 -0.031392469
            0.3262791  0.133665953
            0.3677298 -0.000633053
            -0.2635728 -0.467843459
        ]
        B1 = [
            1.7012514  0.2104329
            0.5569920  0.2680516
            0.3181553  0.1070674
            1.7766535 -0.3047110
        ]

        @assert norm(A - A1) < 1e-6
        @assert norm(B - B1) < 1e-6
    end

    # splogitpca with the default Procrustes A-update against a full reference run
    let 
        μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X))

        μ1 = [0.11143104907230131, -6.564739360771476, -5.150479248661679, 3.1076465866696625]
        A1 = [-0.4711911569853924 -0.1489574609442475; -0.6112679385689244 -0.21907436712904096; 0.17677392090006658 -0.32920587480319885; -0.12646233029137596 -0.5057072366168655; 0.21416589014147108 -0.29440618782253014; 0.29999916840057045 -0.2510075989146518; 0.25869949074847665 -0.23679822898302308; 0.18476194498116696 -0.33697878594619274; 0.17935771304814505 -0.3272536782315915; -0.29663199282213143 -0.3732276622116619]
        B1 = [29.334390862738818 3.0935769680639233; -3.1451796495190014 17.718901626804065; -0.8610989625744958 11.142872726019872; 20.653625916707483 -11.121500611519389]
        nzeros1 = 0
        BIC1 = 78.09400917615692
        iters1 = 100
        losses1 = [0.3192568800107156, 0.24977073285809892, 0.21330490458541732, 0.1899828633617256, 0.1735122256779609, 0.16114605957633704, 0.15146109632832494, 0.14363527557579114, 0.1371565015500379, 0.13168741770455047, 0.12699595883671305, 0.12291687282860156, 0.11932909880827051, 0.11614181739372556, 0.11328550369088539, 0.11070599638614712, 0.10836045395242884, 0.10621452962566721, 0.10424035540856966, 0.1024150761609641, 0.10071976572936119, 0.09913861346320799, 0.09765830536474736, 0.0962675475009829, 0.09495669485112829, 0.09371745929147622, 0.09254267767201296, 0.09142612601238616, 0.09036236944427692, 0.08934664011448008, 0.08837473714519413, 0.08744294413267219, 0.08654796069463733, 0.08568684534937958, 0.08485696759454661, 0.084055967500555, 0.08328172147765685, 0.08253231314263992, 0.08180600841968862, 0.08110123417396806, 0.08041655980631743, 0.07975068134082348, 0.07910240761983256, 0.07847064828761784, 0.07785440329787183, 0.07725275372407196, 0.07666485368762423, 0.07608992324812383, 0.07552724212433402, 0.07497614413457007, 0.07443601226186682, 0.07390627426323096, 0.07338639875394007, 0.07287589170764217, 0.0723742933212689, 0.07188117520075957, 0.07139613782951738, 0.07091880828656347, 0.07044883818565822, 0.06998590181034249, 0.06952969442301517, 0.06907993072888352, 0.06863634347797201, 0.06819868219040548, 0.0677667119919442, 0.06734021254827935, 0.0669189770879308, 0.06650281150475232, 0.06609153353207073, 0.0656849719813751, 0.06528296603925979, 0.06488536461701414, 0.06449202574786106, 0.06410281602738557, 0.0637176100931686, 0.06333629014006628, 0.0629587454679458, 0.0625848720590257, 0.06221457218226193, 0.06184775402248566, 0.061484331332236145, 0.06112422310443886, 0.060767353264271115, 0.06041365037872024, 0.060063047382495384, 0.059715481319083705, 0.05937089309586481, 0.059029227252303126, 0.05869043174033748, 0.058354457716168026, 0.05802125934272438, 0.057690793602161984, 0.0573630201177989, 0.05703790098495918, 
0.05671540061023694, 0.05639548555874026, 0.05607812440891309, 0.05576328761456515, 0.05545094737377308, 0.055141077504343126]
        λ1 = 0

        @assert norm(μ - μ1) < 1e-10
        @assert norm(A - A1) < 1e-10
        @assert norm(B - B1) < 1e-10
        @assert nzeros == nzeros1
        @assert norm(BIC - BIC1) < 1e-10
        @assert iters == iters1
        @assert norm(losses - losses1) < 1e-10
        @assert norm(λ - λ1) < 1e-10
    end 

    # splogitpca with the QR A-update (procrustes=false) against its reference run
    let 
        μ, A, B, nzeros, BIC, iters, losses, λ = splogitpca(copy(X), procrustes = false)

        μ1 = [0.15487984167564228, -6.568819637068933, -5.153303852683945, 3.1354113472419947]
        A1 = [-0.4713085616574335 -0.1486499336486215; -0.6116225488233055 -0.21892149365503077; 0.17668326805953308 -0.3291070254305201; -0.12676563488667017 -0.5058256360958331; 0.2136841499552063 -0.2947788948414918; 0.29919833886531016 -0.25214788479509065; 0.2579814533003914 -0.2376067790175394; 0.18463978727899122 -0.3369513463746447; 0.17923695827317881 -0.3271853654971497; -0.2975682504108814 -0.37187223776149625]
        B1 = [29.36371490195748 3.3473468390064522; -3.3541459487802157 17.669448756051334; -0.9910645451446817 11.124796865100537; 20.688053338772107 -10.876220678727316]
        nzeros1 = 0
        BIC1 = 78.09522219179412
        iters1 = 100
        losses1 = [0.31927078677106835, 0.24979181872115347, 0.2133183230695471, 0.18999383437616968, 0.1735228915110179, 0.16115692983943103, 0.1514721476812388, 0.14364636426924043, 0.13716749159847258, 0.1316982099026348, 0.12700648917113633, 0.12292710375757314, 0.11933901071427137, 0.11615140198242953, 0.11329475936292606, 0.11071492516618249, 0.10836905953592055, 0.10622281619007441, 0.10424832693190614, 0.10242273606938757, 0.1007271167532976, 0.09914565762194505, 0.09766504402696699, 0.09627398148718055, 0.09496282455615426, 0.09372328481354868, 0.09254819893988216, 0.09143134290580095, 0.09036728190515164, 0.08935124824758726, 0.08837904130810297, 0.08744694501519108, 0.08655165938832812, 0.08569024340786899, 0.08486006708542487, 0.08405877104965467, 0.08328423230647461, 0.08253453509960995, 0.08180794600595748, 0.08110289256425271, 0.08041794486536624, 0.07975179963593723, 0.0791032664298458, 0.07847125560869023, 0.07785476784639267, 0.0772528849369398, 0.07666476172012959, 0.07608961896963139, 0.07552673711193403, 0.07497545066484831, 0.07443514330092219, 0.07390524345505597, 0.07338522040726606, 0.07287458078134124, 0.07237286540839802, 0.07187964651132406, 0.0713945251720289, 0.07091712904846326, 0.07044711031267445, 0.06998414378485004, 0.06952792524146519, 0.06907816987837065, 0.06863461091200698, 0.06819699830396261, 0.0677650975958532, 0.06733868884303228, 0.06691756563697579, 0.06650153420734686, 0.06609041259576673, 0.06568402989420989, 0.06528222554172623, 0.06488484867388328, 0.06449175751993028, 0.06410281884322391, 0.06371790742093102, 0.06333690555944513, 0.06295970264232852, 0.06258619470792325, 0.06221628405407158, 0.06184987886764911, 0.061486892876849705, 0.061127245024371245, 0.060770859159838375, 0.060417663749967876, 0.06006759160512929, 0.059720579621091584, 0.05937656853486391, 0.059035502693647504, 0.05869732983601197, 0.05836200088449596, 0.05802946974890764, 0.05769969313967223, 0.057372630390632874, 0.057048243290767894, 
0.056726495924335395, 0.05640735451900135, 0.05609078730154475, 0.05577676436076882, 0.055465257517278245, 0.05515624019980808]
        λ1 = 0

        @assert norm(μ - μ1) < 1e-10
        @assert norm(A - A1) < 1e-10
        @assert norm(B - B1) < 1e-10
        @assert nzeros == nzeros1
        @assert norm(BIC - BIC1) < 1e-10
        @assert iters == iters1
        @assert norm(losses - losses1) < 1e-10
        @assert norm(λ - λ1) < 1e-10
    end

    # normalize! must preserve A*B' while making B's columns unit-norm
    let
        A1 = copy(A0)
        B1 = copy(B0)
        normalize!(A1, B1)
        @assert norm(A0*B0' - A1*B1') < 1e-10
        @assert norm(norm(B1[:,1]) - 1) < 1e-10
        @assert norm(norm(B1[:,2]) - 1) < 1e-10
    end
end

# using Profile
# Profile.clear()
# X=copy(X0)
# @profile q=splogitpca(X) #1.16 seconds
# Profile.print()

# 
#R: 7.65368 secs
n,m=2000,1600   # benchmark size (compare against the R timing noted above)
X0=rand(n,m)

# splogitpca mutates its input in place (via _initialize!), so time it on a copy
X=copy(X0)
@time q=splogitpca(X)

# same data, the `l` variant
X=copy(X0)
@time q=splogitpcal(X)

# X=copy(X0)
# @time q=splogitpca(X, procrustes=false)

# X=copy(X0)
# @time q=splogitpca(X, procrustes=false, lasso=false)

# X=copy(X0)
# @time q=splogitpca(X, lasso=false)

#2.6s