Last active
March 7, 2017 03:26
-
-
Save jverzani/035114b20c011dfc23ea31d95e788fba to your computer and use it in GitHub Desktop.
Julia implementation of AMVW algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module AMVW | |
## Julia implementation of | |
## Fast and backward stable computation of roots of polynomials | |
## https://lirias.kuleuven.be/bitstream/123456789/461961/1/TW654.pdf | |
## Derived from fortran code https://people.cs.kuleuven.be/~raf.vandebril/homepage/software/companion_qr.php?menu=5 | |
## License is unclear, but hopefully can be MIT licensed | |
## TODO
## * handle the case of non-convergence
## * API work
## * Get faster! Seems to be 6x slower -- or more -- than just roots(p), and
##   about as accurate
## Done:
## * check norm in vals! so that rotators have norm \approx 1
## * implement speed-up for C_i = B_i in initial bulge chasing
## Utils | |
"""
    reverse_poly(ps::Vector{T}) -> (qs, k)

Take coefficients `[p0, p1, ..., pn]` (lowest degree first), drop trailing
zero coefficients, factor out the `k` roots at zero (the leading zero
coefficients), make the remainder monic, and return its non-leading
coefficients in reversed order, `[q_{m-1}, q_{m-2}, ..., q0]`, together with `k`.

If `ps` has no nonzero entry, returns `(T[], length(ps))`.
"""
function reverse_poly(ps::Vector{T}) where {T}
    isnonzero(p) = p != zero(T)
    ## trim any 0s from the end of ps
    N = findlast(isnonzero, ps)
    # "not found" is reported as 0 on older Julia and as `nothing` on newer
    (N === nothing || N == 0) && return (zeros(T, 0), length(ps))
    ## leading 0s: each one is a root at zero
    K = findfirst(isnonzero, ps)
    trimmed = ps[K:N]
    # monic, highest degree first, with the leading 1 dropped
    qs = reverse(trimmed ./ trimmed[end])[2:end]
    return qs, K - 1
end
using Compat | |
## Types
# Supertype of everything in this file that is declared `<: CoreTransform`.
@compat abstract type CoreTransform{T} end
# Supertype of the rotator types (see `RealRotator`).
@compat abstract type Rotator{T} <: CoreTransform{T} end
# Conjugate transpose of a rotator: same cosine, negated sine, same index.
# The index vector is copied so that a later `idx!` on the result cannot
# silently move `r` as well.
function Base.ctranspose(r::Rotator)
    c, s = r.xs
    return RealRotator([c, -s], copy(r.i))
end
# The index is superfluous for now, and a bit of a hassle to keep with an
# immutable type, but might be of help later if twisting is approached.
# It shouldn't affect speed, but does mean 9N storage, not 6N.
struct RealRotator{T} <: Rotator{T}
    xs::Vector{T}   # read as (c, s) by `vals`
    i::Vector{Int}  # position, held in a vector so `idx!` can change it in place
end
# Identity rotator: c = 1, s = 0, with its index initialised to 0.
Base.one(::Type{RealRotator{T}}) where {T} = RealRotator([one(T), zero(T)], zeros(Int, 1))
# Vector of `N` identity rotators, each built by its own call to `one`.
Base.ones(S::Type{RealRotator{T}}, N) where {T} = [one(S) for k in 1:N]
## get/set values
# Return the rotator's backing `[c, s]` vector (the vector itself, not a copy).
vals(r::RealRotator{T}) where {T} = r.xs

# Store `(c, s)` in `r`. To guard against roundoff the pair is rescaled to
# unit norm, but only when c^2 + s^2 differs from 1 by at least 100*eps(T);
# otherwise the square root is skipped.
# NOTE(review): the original comment attributes this heuristic to "6.3" --
# presumably a section of the AMVW paper; confirm.
function vals!(r::RealRotator, c::T, s::T) where {T}
    nrm2 = c^2 + s^2
    scale = abs(nrm2 - one(T)) >= 1e2 * eps(T) ? inv(sqrt(nrm2)) : one(T)
    r.xs[1] = c * scale
    r.xs[2] = s * scale
end
# Current position of the rotator.
function idx(r::RealRotator)
    return r.i[1]
end
# Move the rotator to position `i`, in place; evaluates to `i`.
function idx!(r::RealRotator, i::Int)
    return (r.i[1] = i)
end
# Independent copy: values AND index vector are duplicated, so `idx!` or
# `vals!` on the copy never touches `a`.
Base.copy(a::RealRotator) = RealRotator(copy(a.xs), copy(a.i))

# Overwrite `a` in place with the values (via `vals!`) and the index of `b`.
function Base.copy!(a::RealRotator, b::RealRotator)
    c, s = vals(b)
    vals!(a, c, s)
    idx!(a, idx(b))
end
# Core transform is a 2x2 matrix [a b; c d], stored flat in `xs`.
mutable struct RealTransform{T} <: CoreTransform{T}
    xs::Vector{T} # [a b; c d]
    i::Int
end

# Transpose: exchange the 2nd and 3rd stored entries, keep the index.
function Base.ctranspose(r::RealTransform)
    return RealTransform(r.xs[[1, 3, 2, 4]], r.i)
end
using Compat | |
# Mutable bookkeeping for the iteration.
mutable struct DoubleShiftCounter
    zero_index::Int   # set by `deflate` to the index of the deflated Q rotator
    start_index::Int  # first index of the active window (`deflate` sets it to k + 1)
    stop_index::Int   # last index of the active window
    it_count::Int     # 0 makes `create_bulge` use a random shift
end
@compat abstract type ShiftType{T} end

# All state of one real double-shift run: the polynomial, the three rotator
# sequences, the computed eigenvalues, and scratch storage reused between steps.
struct RealDoubleShift{T} <: ShiftType{T}
    N::Int
    POLY::Vector{T}
    Q::Vector{RealRotator{T}}
    Ct::Vector{RealRotator{T}}  # We use C', not C here
    B::Vector{RealRotator{T}}
    REIGS::Vector{T}            # real parts of the eigenvalues
    IEIGS::Vector{T}            # imaginary parts of the eigenvalues
    ## reusable storage
    U::RealRotator{T}
    V::RealRotator{T}
    W::RealRotator{T}
    A::Matrix{T}   # for parts of A = QR
    Bk::Matrix{T}  # for diagonal block
    R::Matrix{T}   # temp storage, sometimes R part of QR
    e1::Vector{T}  # eigenvalues e1, e2, each as (real, imaginary)
    e2::Vector{T}
    ctrs::DoubleShiftCounter
end
# Build a fresh state for the coefficient vector `ps` (stored as `POLY`, not
# copied): identity rotators everywhere, zeroed eigenvalue and scratch
# storage, and counters (zero_index, start_index, stop_index, it_count) =
# (0, 1, N-1, 0). `init_state` must still be called to load the polynomial
# into the rotators.
function Base.convert(::Type{RealDoubleShift}, ps::Vector{T}) where {T}
    N = length(ps)
    rotators() = ones(RealRotator{T}, N)
    scratch() = zeros(T, 3, 2)
    return RealDoubleShift(N, ps,
                           rotators(), rotators(), rotators(),        # Q, Ct, B
                           zeros(T, N), zeros(T, N),                  # REIGS, IEIGS
                           one(RealRotator{T}), one(RealRotator{T}),
                           one(RealRotator{T}),                       # U, V, W
                           scratch(), scratch(), scratch(),           # A, Bk, R
                           zeros(T, 2), zeros(T, 2),                  # e1, e2
                           DoubleShiftCounter(0, 1, N - 1, 0))
end
## need to compute by hand in case we use big values
# Eigenvalues of the 2x2 block state.A[1:2, 1:2], written into state.e1 and
# state.e2 as (real part, imaginary part).
#
# Real case: the root of larger magnitude is computed from the quadratic
# formula and the other one as det / root, which avoids cancellation. Both
# roots are zero only when BOTH b + sqrt(discr) and b - sqrt(discr) vanish;
# when just one of them vanishes (det == 0, trace != 0) the eigenvalues are
# `trace` and 0, and the det / root branch yields exactly that.
function Base.eigvals(state::RealDoubleShift{T}) where {T}
    M = state.A
    b = M[1,1] + M[2,2]                   # trace
    c = M[1,1] * M[2,2] - M[1,2] * M[2,1] # determinant
    # integer literals keep the arithmetic in T
    discr = b^2 - 4 * c
    if discr < 0
        # complex-conjugate pair
        state.e1[1], state.e1[2] = b / 2, sqrt(-discr) / 2
        state.e2[1], state.e2[2] = state.e1[1], -state.e1[2]
    else
        state.e1[2], state.e2[2] = zero(T), zero(T) # real
        sdiscr = sqrt(discr)
        u, v = b + sdiscr, b - sdiscr
        if iszero(u) && iszero(v)
            u, v = zero(T), zero(T)
        elseif abs(u) > abs(v)
            u = u / 2
            v = c / u
        else
            # abs(v) >= abs(u) and not both zero, so v != 0
            v = v / 2
            u = c / v
        end
        state.e1[1], state.e2[1] = u, v
    end
end
################################################## | |
# | |
# | |
## Diagnostic code
# Embed rotator `a` into an N x N identity matrix: the block [c -s; s c] is
# placed at rows/columns i:i+1, where i = idx(a). Errors unless i < N.
function as_full(a::RealRotator{T}, N::Int) where {T}
    c, s = vals(a)
    i = idx(a)
    i < N || error("too big")
    A = eye(T, N)
    A[i:i+1, i:i+1] = [c -s; s c]
    return A
end
# Set every entry of `A` with absolute value <= `tol` to zero, in place, and
# return `A`. Written as a loop so no mask array is allocated and no
# scalar-to-mask assignment is needed.
function zero_out!(A::Array{T}, tol=1e-12) where {T}
    for k in eachindex(A)
        if abs(A[k]) <= tol
            A[k] = zero(T)
        end
    end
    return A
end
## diagnostic
## create Full matrix from state object. For diagnostic purposes.
# what == :W returns W = B + e1 * yt', what == :R returns Ct * W, anything
# else returns Q * Ct * (B + e1 * yt'). All matrices are (N+1) x (N+1) and
# entries of magnitude <= 1e-12 are zeroed before returning.
function Base.full(state::RealDoubleShift{T}, what=:A) where {T}
    N = state.N
    M = N + 1
    # Q and B are accumulated left to right, Ct right to left
    Q = as_full(state.Q[1], M)
    for i in 2:N
        Q = Q * as_full(state.Q[i], M)
    end
    Ct = as_full(state.Ct[1], M)
    for i in 2:N
        Ct = as_full(state.Ct[i], M) * Ct
    end
    B = as_full(state.B[1], M)
    for i in 2:N
        B = B * as_full(state.B[i], M)
    end
    e1 = zeros(T, M); e1[1] = one(T)
    en1 = zeros(T, M); en1[M] = one(T)
    rho = en1' * Ct * e1
    yt = vec(-1/rho * en1' * Ct * B)
    for k in eachindex(yt)  # tidy
        if abs(yt[k]) < 1e-12
            yt[k] = zero(T)
        end
    end
    # A = Q * Ct * (B + e1 * y')
    W = (B + e1 * yt')
    zero_out!(W)
    what == :W && return W
    R = Ct * W
    zero_out!(R)
    what == :R && return R
    A = Q * Ct * (B + e1 * yt')
    zero_out!(A)
    return A
end
# Simple graphic to show the march of the algorithm: prints one line of
# N + 2 dots with "α" at zero_index + 1, "x" at start_index + 1 and "Δ" at
# stop_index + 2 (a later marker overwrites an earlier one at the same spot).
function show_status(state::RealDoubleShift{T}) where {T}
    ctrs = state.ctrs
    marks = fill(".", state.N + 2)
    marks[ctrs.zero_index + 1] = "α"
    marks[ctrs.start_index + 1] = "x"
    marks[ctrs.stop_index + 2] = "Δ"
    println(join(marks, ""))
end
## | |
# 2x2 matrix [a -b; b a] built from the pair (a, b).
function rotm(a, b)
    return [a -b; b a]
end
## | |
################################################## | |
# rotations; find values
# Real Givens
# Computes (c, s, r) such that, with the convention rotm(c, s) = [c -s; s c],
#
#   [c -s; s c] * [a, b] = [r, 0];   c^2 + s^2 = 1
#
# and
#
#   r = sqrt(|a|^2 + |b|^2).
#
# In the general case this is c = a/r, s = -b/r. The special cases follow the
# same sign convention; for a == b == 0 the identity rotation (1, 0) is
# returned with r = 0 rather than the non-rotation (0, 0).
# `donorm` is accepted for call compatibility and is not used.
#
# XXX seems faster to just return r, then not
function givensrot(a::T, b::T, donorm=Val{false}) where {T <: Real}
    if iszero(b)
        iszero(a) && return (one(T), zero(T), zero(T))
        return (sign(a) * one(T), zero(T), abs(a))
    end
    # a == 0: limit of (a/r, -b/r, r)
    iszero(a) && return (zero(T), -sign(b) * one(T), abs(b))
    r = hypot(a, b)
    c, s = a/r, b/r
    return (c, -s, r)
end
#### Operations on [,[ terms
## The zero_index and stop_index+1 point at "P" matrices; RealRotator(p,0) with p^2 = 1 (better name for these matrices?)
##
## We have pflip for moving P_i * R_{i+1} -> R'_{i+1} * P_i (solving R'_{i+1} = P_i * R_{i+1} * P_i)
## basically just R(a, p*b)
##
# Multiplies the sine of `a` by sign(p), in place (through `vals!`).
function pflip(a::RealRotator{T}, p=one(T)) where {T}
    c, s = vals(a)
    vals!(a, c, sign(p) * s)
end
# get p from rotator which is RR(1,0) or RR(-1, 0)
# Returns sign(c); errors when |s| exceeds 4*eps(T), i.e. when `a` is not
# (numerically) such a "P" matrix.
function getp(a::RealRotator{T}) where {T}
    c, s = vals(a)
    abs(s) <= 4eps(T) || error("a is not a 'P' matrix")
    return sign(c)
end
## fuse combines two rotations into one, :left updates a, :right updates b
# Both rotators must sit at the same index. The combined (c, s) is
# (ac*bc - as*bs, ac*bs + as*bc), i.e. the angles add. Only the rotator
# selected by `dir` is overwritten; the other one is left unchanged.
function fuse(a::RealRotator{T}, b::RealRotator{T}, dir=Val{:right}) where {T}
    idx(a) == idx(b) || error("can't fuse")
    ac, as = vals(a)
    bc, bs = vals(b)
    u, v = ac * bc - as * bs, ac * bs + as * bc
    if dir == Val{:left}
        vals!(a, u, v)
    else
        vals!(b, u, v)
    end
end
# Turnover: Q1    Q3   | x x x |      Q1
#              Q2    = | x x x | = Q3    Q2  <-- misfit=3 Q1, Q2 shift;
#                      | x x x |
#
# misfit is Val{:right} for <-- (right to left turnover), Val{:left} for -->
#
# Requires idx(Q1) == idx(Q3) and |idx(Q2) - idx(Q1)| == 1. Three new rotators
# (c4,s4), (c5,s5), (c6,s6) are computed with `givensrot` and written back in
# place:
#   Val{:left}  : Q2 <- (c4,-s4), Q3 <- (c5,-s5), Q1 <- (c6,-s6), idx(Q1) <- idx(Q2)
#   otherwise   : Q3 <- (c4,-s4), Q1 <- (c5,-s5), Q2 <- (c6,-s6), idx(Q3) <- idx(Q2)
function turnover(Q1::RealRotator{T}, Q2::RealRotator{T}, Q3::RealRotator{T}, misfit=Val{:right}) where {T}
    i, j, k = idx(Q1), idx(Q2), idx(Q3)
    (i == k) || error("Need to have a turnover up down up or down up down: have i=$i, j=$j, k=$k")
    abs(j-i) == 1 || error("Need to have |i-j| == 1")

    c1, s1 = vals(Q1)
    c2, s2 = vals(Q2)
    c3, s3 = vals(Q3)

    # initialize c4 and s4
    a = c1*c2*s3 + s1*c3
    b = s2*s3
    # check norm([a,b]) \approx 1
    c4, s4, nrm = givensrot(a, b)

    # initialize c5 and s5
    a = c1*c3 - s1*c2*s3
    b = nrm
    # check norm([a,b]) \approx 1
    c5, s5, tmp = givensrot(a, b)

    # second column
    u = -c1*s3 - s1*c2*c3
    v = c1*c2*c3 - s1*s3
    w = s2 * c3
    a = c4*c5*v - s4*c5*w + s5*u
    b = c4*w + s4*v
    c6, s6, tmp = givensrot(a, b)

    if misfit == Val{:left}
        # --> : misfit starts on the left
        vals!(Q2, c4, -s4)
        vals!(Q3, c5, -s5)
        vals!(Q1, c6, -s6)
        idx!(Q1, j) # misfit gets shifted
    else
        # <-- : misfit starts on the right
        vals!(Q3, c4, -s4)
        vals!(Q1, c5, -s5)
        vals!(Q2, c6, -s6)
        idx!(Q3, j) # misfit
    end
end
### Related to decompostion QR into QC(B + ...) | |
## fill A[k:k+2, k:k+1] k in 2:N | |
## updates A | |
## | |
# We look for r_j,k. Depending on |j-k| there are different amounts of work | |
# we have wk = (B + e1 y^t) * ek = B*ek + e1 yk; we consider B * ek only B1 ... Bk ek applies | |
# | |
# julia> @vars bk1 bk2 bj1 bj2 bi1 bi2 | |
# julia> rotm(bi1, bi2, 1, 4) * rotm(bj1, bj2, 2, 4) * rotm(bk1, bk2, 3, 4) * [0, 0, 1, 0] # B_{k-2} * B_{k-1} * B_k * ek = W | |
# 4-element Array{SymPy.Sym,1} | |
# ⎡bi₂⋅bj₂⋅bk₁ ⎤ | |
# ⎢ ⎥ | |
# ⎢-bi₁⋅bj₂⋅bk₁⎥ | |
# ⎢ ⎥ | |
# ⎢ bj₁⋅bk₁ ⎥ | |
# ⎢ ⎥ | |
# ⎣ bk₂ ⎦ | |
# which gives W = [what_{k-2} w_{k-1} w_k w_{k+1}] | |
# For rkk, we have Ck * W = [rkk, 0] | |
# @vars ck1 ck2 what w1 | |
# u = rotm(ck1, ck2, 1,2) * [what, w1] | |
# u[1](what => solve(u[2], what)[1]) |> simplify | |
# ⎛ 2 2⎞ | |
# -w₁⋅⎝c₁ + c₂ ⎠ | |
# ──────────────── # this is rkk = -w1/c2 = -bk2/ck2 | |
# c₂ | |
# | |
# For r[k-1, k] we need to do more work. We need [what_{k-1}, w_k, w_{k+1}], where w_k, w_{k+1} found from B values as above. | |
# | |
# julia> @vars ck1 ck2 cj1 cj2 what w w1 | |
# (ck1, ck2, cj1, cj2, what, w, w1) | |
# julia> u = rotm(ck1, ck2, 2, 3) * rotm(cj1, cj2, 1, 3) * [what, w, w1] # C^*_{k} * C^*{k-1} * W = [r_{k-1,k}, r_{k,k}, 0] | |
# 3-element Array{SymPy.Sym,1} | |
# ⎡ cj₁⋅ŵ - cj₂⋅w ⎤ | |
# ⎢ ⎥ | |
# ⎢cj₁⋅ck₁⋅w + cj₂⋅ck₁⋅ŵ - ck₂⋅w₁ ⎥ | |
# ⎢ ⎥ | |
# ⎣cj₁⋅ck₂⋅w + cj₂⋅ck₂⋅ŵ + ck₁⋅w₁ ⎦ | |
# julia> u[1](what => solve(u[3], what)[1]) |> simplify | |
# 2 | |
# cj₁ ⋅w cj₁⋅ck₁⋅w₁ | |
# - ────── - ────────── - cj₂⋅w | |
# cj₂ cj₂⋅ck₂ | |
# | |
# For r_{k-2,k} we need to reach back one more step | |
# C^*_{k} * C^*{k-1} * C^*_{k-2} W = [r_{k-2,k} r_{k-1,k}, r_{k,k}, 0] | |
# | |
# julia> @vars ck1 ck2 cj1 cj2 ci1 ci2 what wm1 w w1 | |
# julia> u = rotm(ck1, ck2, 3, 4) * rotm(cj1, cj2, 2, 4) * rotm(ci1, ci2, 1, 4) * [what, wm1, w, w1] | |
# julia> u[1](what => solve(u[4], what)[1]) |> simplify | |
# 2 | |
# ci₁ ⋅wm₁ ci₁⋅cj₁⋅w ci₁⋅ck₁⋅w₁ | |
# - ──────── - ───────── - ─────────── - ci₂⋅wm₁ | |
# ci₂ ci₂⋅cj₂ ci₂⋅cj₂⋅ck₂ | |
# Fill state.A[1:2, 1:2] with the 2x2 diagonal block of A = Q*R for rows and
# columns k-1:k, using only the rotators near k (the needed entries of R are
# written into the scratch matrix state.R). Requires 2 <= k <= state.N.
# The closed-form expressions are the ones derived in the comment block
# preceding this function.
function diagonal_block(state::RealDoubleShift{T}, k) where {T}
    k >= 2 && k <= state.N || error("$k not in [2,n]")

    A = state.A
    R = state.R

    if k == 2
        # here we only need [r11 r12; 0 r22], so only use top part of R
        for j in 1:2
            ck1, ck2 = vals(state.Ct[k - (2-j)])
            w1 = vals(state.B[k - (2-j)])[2]
            R[j,j] = - w1 / ck2
        end

        cj1, cj2 = vals(state.Ct[k-1])
        ck1, ck2 = vals(state.Ct[k])
        w = vals(state.B[k-1])[1] * vals(state.B[k])[1]
        w1 = vals(state.B[k])[2]
        R[1,2] = -(cj1^2*w)/cj2 - (cj1 * ck1 * w1) / (cj2 * ck2) - cj2 * w

        q11, q12 = vals(state.Q[k-1]); q21, q22 = vals(state.Q[k])
        A[1,1] = q11 * R[1,1]
        A[1,2] = q11 * R[1,2] - q12 * q21 * R[2,2]
        A[2,1] = q12 * R[1,1]
        A[2,2] = q11 * q21 * R[2,2] + q12 * R[1,2]
    else ## Need condition on N, as Bn is Bn*Zn
        ## the scratch R here holds R[k-2:k, k-1:k]
        Qk_2, Qk_1, Qk = state.Q[k-2].xs, state.Q[k-1].xs, state.Q[k].xs

        # diagonal entries r_{k-1,k-1}, r_{k,k} -> R[2,1], R[3,2]
        for j in 1:2
            K = k - (2-j)
            ck1, ck2 = vals(state.Ct[K])
            w1 = vals(state.B[K])[2]
            R[j+1,j] = - w1 / ck2
        end

        # first superdiagonal entries r_{k-2,k-1}, r_{k-1,k} -> R[1,1], R[2,2]
        for j in 1:2
            K = k - (2-j)
            cj1, cj2 = vals(state.Ct[K-1])
            ck1, ck2 = vals(state.Ct[K])
            w = vals(state.B[K-1])[1] * vals(state.B[K])[1]
            w1 = vals(state.B[K])[2]
            R[j,j] = -(cj1^2*w)/cj2 - (cj1 * ck1 * w1) / (cj2 * ck2) - cj2 * w
        end

        # last is R[1,2], i.e. r_{k-2,k}
        wm1 = -vals(state.B[k-2])[1] * vals(state.B[k-1])[2] * vals(state.B[k])[1]
        w = vals(state.B[k-1])[1] * vals(state.B[k])[1]
        w1 = vals(state.B[k])[2]
        ci1, ci2 = vals(state.Ct[k-2])
        cj1, cj2 = vals(state.Ct[k-1])
        ck1, ck2 = vals(state.Ct[k])
        R[1,2] = -(ci1^2 * wm1/ci2) - (ci1 * cj1 * w) / (ci2 * cj2) - (ci1 * ck1 * w1) / (ci2 * cj2 * ck2) - ci2*wm1

        # A is Q*R, but not all Q contribute
        # This is for k = 5
        #   A[4,4] = q₃₁⋅q₄₁⋅r₄₄ + q₃₂⋅r₃₄
        #   A[4,5] = q₃₁⋅q₄₁⋅r₄₅ - q₃₁⋅q₄₂⋅q₅₁⋅r₅₅ + q₃₂⋅r₃₅
        #   A[5,4] = q₄₂⋅r₄₄
        #   A[5,5] = q₄₁⋅q₅₁⋅r₅₅ + q₄₂⋅r₄₅
        A[1,1] = Qk_2[1] * Qk_1[1] * R[2,1] + Qk_2[2] * R[1,1]
        A[1,2] = Qk_2[1] * Qk_1[1] * R[2,2] - Qk_2[1] * Qk_1[2] * Qk[1] * R[3,2] + Qk_2[2] * R[1,2]
        A[2,1] = Qk_1[2] * R[2,1]
        A[2,2] = Qk_1[1] * Qk[1] * R[3,2] + Qk_1[2] * R[2,2]
    end
end
## Deflation
# Scan Q[stop_index], ..., Q[start_index] (downwards) and deflate at the first
# rotator whose sine is at most `tol` in absolute value. At most one
# deflation is performed per call.
function check_deflation(state::RealDoubleShift{T}, tol=eps(T)) where {T}
    ctrs = state.ctrs
    for k in ctrs.stop_index:-1:ctrs.start_index
        if abs(vals(state.Q[k])[2]) <= tol
            deflate(state, k)
            return
        end
    end
end
# deflate a term
# Replace Q[k] by the "P" matrix (sign(c), 0), move the active window to start
# just after k, and restart the iteration counter at 1. `getp` errors if the
# sine of Q[k] is not negligible.
function deflate(state::RealDoubleShift{T}, k) where {T}
    # make a P matrix
    vals!(state.Q[k], getp(state.Q[k]), zero(T))

    # shift zero counter
    state.ctrs.zero_index = k # points to a matrix Q[k] either RealRotator(-1, 0) or RealRotator(1, 0)
    state.ctrs.start_index = k + 1

    # reset counter
    state.ctrs.it_count = 1
end
## Bulge chasing
# Set up the rotators U (at start_index) and V (at start_index + 1) that start
# a double-shift step.
#
# * it_count == 0: a random shift, U = (cos t, sin t), V = (cos t, -sin t).
#   The angle is formed in T so that the pair is unit-norm to T's precision.
# * otherwise: the shifts are the eigenvalues of the trailing 2x2 block
#   (`diagonal_block` at stop_index + 1, then `eigvals`), and U, V are built
#   from the first three entries of (A - rho1)*(A - rho2)*e_1.
function create_bulge(state::RealDoubleShift{T}) where {T}
    k = state.ctrs.start_index
    if state.ctrs.it_count == 0
        t = 2 * convert(T, pi) * convert(T, rand())
        re1, ie1 = cos(t), sin(t)
        re2, ie2 = re1, -ie1

        vals!(state.U, re1, ie1); idx!(state.U, k)
        vals!(state.V, re2, ie2); idx!(state.V, k + 1)
    else
        # compute (A-rho1) * (A - rho2) * e_1
        Bk = state.Bk

        # find e1, e2
        diagonal_block(state, state.ctrs.stop_index + 1)
        eigvals(state)
        l1r, l1i = state.e1
        l2r, l2i = state.e2

        # find first part of A[1:3, 1:2]
        diagonal_block(state, k + 1)
        Bk[1:2, 1:2] = state.A[1:2, 1:2]
        if k + 2 <= state.N # (Why this condition?)
            diagonal_block(state, k + 2)
            Bk[3,2] = state.A[2, 1]
        end

        # make first three elements of c1,c2,c3 (real parts only)
        # c1 = real(-l1i⋅l2i + ⅈ⋅l1i⋅l2r - ⅈ⋅l1i⋅t₁₁ + ⅈ⋅l1r⋅l2i + l1r⋅l2r - l1r⋅t₁₁ - ⅈ⋅l2i⋅t₁₁ - l2r⋅t₁₁ + t₁₁^2 + t₁₂⋅t₂₁)
        # c2 = real(-ⅈ⋅l1i⋅t₂₁ - l1r⋅t₂₁ - ⅈ⋅l2i⋅t₂₁ - l2r⋅t₂₁ + t₁₁⋅t₂₁ + t₂₁⋅t₂₂)
        # c3 = real(t₂₁⋅t₃₂)
        c1 = -l1i * l2i + l1r*l2r - l1r*Bk[1,1] - l2r * Bk[1,1] + Bk[1,1]^2 + Bk[1,2] * Bk[2,1]
        c2 = -l1r * Bk[2,1] - l2r * Bk[2,1] + Bk[1,1] * Bk[2,1] + Bk[2,1] * Bk[2,2]
        c3 = Bk[2,1] * Bk[3,2]

        c, s, nrm = givensrot(c2, c3, Val{true})
        vals!(state.V, c, -s)
        idx!(state.V, k + 1)

        c, s, tmp = givensrot(c1, nrm)
        vals!(state.U, c, -s)
        idx!(state.U, k)
    end
end
## make W on left side | |
# | |
# initial Q0 | |
# we do turnover U1' Q1 --> U1' --> U1' --> Q1 | |
# V1' Q2 Q1 V1' Q2 Q1 (V1'Q2) W1 Q2 | |
# With this, W will be kept on the left side until the last step, U,V | |
# move through left to right by one step, right to left by unitariness | |
# | |
# Q0 Q0 Q0 Q0 | |
# U1' Q1 U1' Q1* -> U1 --> U1* | |
# V1' Q3 -> V1' Q3 Q1** V1' Q3 W (V1'Q3) | |
# | |
# Q0 is (p,0) rotator, p 1 or -1. We have | |
# Q0 --> Q0 | |
# R (r, pr2) | |
# Bring the transposed bulge rotators U', V' in on the left (see the diagrams
# above): W is seeded from Q[k] (k = start_index), turned over with U', V',
# then V' is fused into Q[k+1] and U' becomes the new Q[k].
# p is the sign carried by the deflated rotator Q[k-1] (taken as 1 when
# k == 1); `pflip` applies it to W before, and to U' after, the turnover.
function prepare_bulge(state::RealDoubleShift{T}) where {T}
    k = state.ctrs.start_index

    Ut = state.U'
    Vt = state.V'

    copy!(state.W, state.Q[k])
    p = k == 1 ? one(T) : state.Q[k-1].xs[1] # zero index implies Q0 = RR(1,0) or RR(-1,0)
    pflip(state.W, p)

    turnover(Ut, Vt, state.W, Val{:right})
    fuse(Vt, state.Q[k+1], Val{:right}) # V' Q3

    pflip(Ut, p)
    vals!(state.Q[k], vals(Ut)...)
end
# Chase the bulge (U, V on the right, W on the left) down the factorization
# until V reaches stop_index. Each pass moves V and U through B, Ct and Q by
# turnovers and then turns W, V, U over so the misfit moves on.
#
# While i <= tr the Ct rotators are not turned over separately: B is updated
# (using throw-away copies of V and U) and Ct[i-1:i+1] is overwritten with the
# transposes of B[i-1:i+1]. This is the speed-up described in "Exploiting
# C_i = B_i in early stages".
function chase_bulge(state::RealDoubleShift{T}, tr) where {T}
    i = idx(state.V)

    while i < state.ctrs.stop_index # loops from start_index to stop_index - 1
        if i <= tr
            turnover(state.B[i], state.B[i+1], copy(state.V))
            turnover(state.B[i-1], state.B[i], copy(state.U))
            for k in -1:1
                a, b = vals(state.B[i+k])
                vals!(state.Ct[i+k], a, -b) # using copy!(Ct, B') is slower
            end
            # make sure U and V still sit at i-1 and i before going through Q
            idx!(state.U, i-1); idx!(state.V, i)
        else
            turnover(state.B[i], state.B[i+1], state.V)
            turnover(state.Ct[i+1], state.Ct[i], state.V)

            j = idx(state.U)
            turnover(state.B[j], state.B[j+1], state.U)
            turnover(state.Ct[j+1], state.Ct[j], state.U)
        end

        turnover(state.Q[i], state.Q[i+1], state.V)
        turnover(state.Q[i-1], state.Q[i], state.U)
        turnover(state.W, state.V, state.U, Val{:left})

        i = idx(state.V)
    end
end
# Finish a bulge step by absorbing V, W and U into Q.
function absorb_bulge(state::RealDoubleShift{T}) where {T}
    # first V goes through B, C then fuses with Q
    i = idx(state.V)
    turnover(state.B[i], state.B[i+1], state.V, Val{:right})
    turnover(state.Ct[i+1], state.Ct[i], state.V)

    ## We may be fusing Q            P      -->   (Q')
    #                     RR(-1,0)                      RR(-1,0)
    p = getp(state.Q[i+1])
    pflip(state.V, p)
    fuse(state.Q[i], state.V, Val{:left}) # fuse Q*V -> Q

    # Then bring U through B, C, and Q to fuse with W
    j = idx(state.U)
    turnover(state.B[j], state.B[j+1], state.U)
    turnover(state.Ct[j+1], state.Ct[j], state.U)
    turnover(state.Q[j], state.Q[j+1], state.U)
    fuse(state.W, state.U, Val{:right})

    # similarity transformation, bring through then fuse with Q
    j = idx(state.U)
    turnover(state.B[j], state.B[j+1], state.U, Val{:right})
    turnover(state.Ct[j+1], state.Ct[j], state.U)
    p = getp(state.Q[j+1])
    pflip(state.U, p)
    fuse(state.Q[j], state.U, Val{:left})
end
# One full double-shift step: create the bulge (U, V), prepare it (make W),
# chase it to the bottom of the active window, and absorb it. `tr` is passed
# through to `chase_bulge`.
function bulge_step(state::RealDoubleShift{T}, tr) where {T}
    create_bulge(state)
    prepare_bulge(state)
    chase_bulge(state, tr)
    absorb_bulge(state)
end
# Load the polynomial state.POLY into the rotator sequences:
# * Q[1:N-1] = (0, 1) and Q[N] = (1, 0), each at its own index;
# * Ct[N], B[N] from givensrot(-ps[N], -1), with a sign depending on the
#   parity of N;
# * Ct[N-ii+1], B[N-ii+1] for ii = 2:N from givensrot(-ps[ii-1], temp), where
#   temp is the norm returned by the previous call; Ct gets (a, -b) and B gets
#   (a, b), so at this stage Ct[m] is the transpose of B[m] for m < N.
function init_state(state::RealDoubleShift{T}) where {T}
    N, ps = state.N, state.POLY
    Q, Ct, B = state.Q, state.Ct, state.B

    for ii in 1:(N-1)
        vals!(Q[ii], zero(T), one(T)); idx!(Q[ii], ii)
    end
    vals!(Q[N], one(T), zero(T)); idx!(Q[N], N)

    # play with signs here.
    s = iseven(N) ? one(T) : -one(T)

    a, b, temp = givensrot(-ps[N], -one(T), Val{true})
    vals!(Ct[N], -s*a, -s*b); idx!(Ct[N], N)
    vals!(B[N], -b, -a); idx!(B[N], N)

    for ii in 2:N
        a, b, temp = givensrot(-ps[ii-1], temp, Val{true})
        m = N - ii + 1
        vals!(Ct[m], a, -b); idx!(Ct[m], m)
        vals!(B[m], a, b); idx!(B[m], m)
    end
end
## Main algorithm of AMV&W
# Repeatedly check for deflation and then, depending on the size of the active
# window (stop_index - zero_index):
#   >= 2 : perform one bulge step;
#   == 1 : a 2x2 block has split off -- store its two eigenvalues, shrink the
#          window by two and restart from the top;
#   == 0 : a 1x1 block has split off -- store it (both diagonal entries when
#          stop_index == 1), shrink the window by one.
# Returns as soon as stop_index <= 0. Eigenvalues go to state.REIGS /
# state.IEIGS. If more than 30*N passes are needed, a message is printed and
# the function returns with the eigenvalue arrays only partially filled.
function AMVW_algorithm(state::RealDoubleShift{T}) where {T}
    ctrs = state.ctrs
    it_max = 30 * state.N
    kk = 0
    tr = state.N - 2 # see `chase_bulge`; shrinks by 2 after every bulge step

    while kk <= it_max
        ## finished up!
        ctrs.stop_index <= 0 && return

        check_deflation(state)
        kk += 1

        k = ctrs.stop_index
        gap = k - ctrs.zero_index

        if gap >= 2
            bulge_step(state, tr)
            ctrs.it_count += 1
            tr -= 2
        elseif gap == 1
            diagonal_block(state, k + 1)
            eigvals(state)

            state.REIGS[k], state.IEIGS[k] = state.e2
            state.REIGS[k+1], state.IEIGS[k+1] = state.e1

            if k == 2
                # only index 1 is left; read it from the leading block
                diagonal_block(state, 2)
                state.REIGS[1] = state.A[1,1]
            end

            ctrs.zero_index = 0
            ctrs.start_index = 1
            ctrs.stop_index = k - 2
        elseif gap == 0
            diagonal_block(state, k + 1)
            e1, e2 = state.A[1,1], state.A[2,2]

            if k == 1
                state.REIGS[k] = e1
                state.REIGS[k+1] = e2
                ctrs.stop_index = 0
            else
                state.REIGS[k+1] = e2
                ctrs.zero_index = 0
                ctrs.start_index = 1
                ctrs.stop_index = k - 1
            end
        end
    end

    println("error -- didn't work!!! Take care of this")
end
"""
    amvw(ps::Vector{T}) where {T <: Real}

Use the AMVW double-shift algorithm to find the roots of the polynomial
`p_0 + p_1 x + p_2 x^2 + ... + p_n x^n`, encoded as `[p_0, p_1, ..., p_n]`
(the same ordering used by `Polynomials`).

Returns an object of type `RealDoubleShift`; the real and imaginary parts of
the roots are in its `REIGS` and `IEIGS` fields. Roots at zero (leading zero
coefficients) are stripped before the iteration and are *not* included in the
result. Throws an error when no root is left to compute, i.e. for the zero
polynomial or when only a single nonzero coefficient remains.

Example: API needs work!

```
using Polynomials
x = variable()
p = poly(x - i/10 for i in 5:10)
state = amvw(p.a)
complex.(state.REIGS, state.IEIGS)
```
"""
function amvw(ps::Vector{T}) where {T <: Real}
    # k is the number of 0 factors (currently not reported back)
    qs, k = reverse_poly(ps)

    n = length(qs)
    n == 0 && error("0 polynomial")

    state = RealDoubleShift(qs)
    init_state(state)
    if n == 1
        # x + q0: the iteration would return immediately (stop_index == 0)
        # and leave REIGS at zero, so store the root directly.
        state.REIGS[1] = -qs[1]
    else
        AMVW_algorithm(state)
    end
    return state
end
end | |
# test | |
using AMVW, Polynomials | |
A = AMVW

## some interface for polynomials
# Roots of `p` as a complex vector. Delegates to `A.amvw` so that the
# coefficient handling (including the error for a zero polynomial) lives in
# one place.
function amvw(p::Poly)
    state = A.amvw(p.a)
    return complex.(state.REIGS, state.IEIGS)
end
## quick hack -- doesn't work with complex, big, ...
# Compare residuals of the amvw roots with those of `roots(p)`, both after
# one Newton correction. Returns the ratios (amvw / roots) of
#   r1 = |P(lambda)/P'(lambda)|
#   r2 = |P(lambda)/P'(lambda)/lambda|
#   r3 = ||Cv-lambda v||/||C||/||v||   (summed over the eigenpairs of C)
# Only the real parts (REIGS) of the amvw result are used.
function residual_check(p::Poly)
    dp = polyder(p)
    newton_step(zs) = p.(zs) ./ dp.(zs)

    state = A.amvw(p.a)
    lambdas = state.REIGS
    lambdas = lambdas - newton_step(lambdas)
    sort!(lambdas) # complex???

    r1 = norm(newton_step(lambdas))
    r2 = norm(newton_step(lambdas) ./ lambdas)

    # rebuild the initial factorization to get the matrix C
    A.init_state(state)
    C = full(state)[1:end-1, 1:end-1]
    es = eigfact(C)
    evals = es.values
    order = sortperm(evals)
    evecs = es.vectors[:, order]

    # sum of eigen-residuals, pairing the i-th entry of `ls` with the i-th
    # (sorted) eigenvector of C
    function eig_residual(ls)
        acc = 0.0
        for i in eachindex(evals)
            v = evecs[:, i]
            acc += norm(C*v - ls[i] * v) / norm(C) / norm(v)
        end
        return acc
    end

    r3 = eig_residual(lambdas)

    ## now for roots
    lambdas = sort(roots(p))
    lambdas = lambdas - newton_step(lambdas)
    rr1 = norm(newton_step(lambdas))
    rr2 = norm(newton_step(lambdas) ./ lambdas)
    rr3 = eig_residual(lambdas)

    return (r1/rr1, r2/rr2, r3/rr3)
end
## Tests | |
## Time compared to `roots` -- slower | |
## | |
# julia> n = 15; p = poly(linspace(1/n, 1, n)); | |
# julia> using BenchmarkTools; @benchmark roots(p) | |
# BenchmarkTools.Trial: | |
# memory estimate: 41.58 KiB | |
# allocs estimate: 51 | |
# -------------- | |
# minimum time: 189.171 μs (0.00% GC) | |
# median time: 191.314 μs (0.00% GC) | |
# mean time: 205.429 μs (2.56% GC) | |
# maximum time: 3.588 ms (88.18% GC) | |
# -------------- | |
# samples: 10000 | |
# evals/sample: 1 | |
# time tolerance: 5.00% | |
# memory tolerance: 1.00% | |
# # | |
# julia> @benchmark amvw(p) | |
# BenchmarkTools.Trial: | |
# memory estimate: 48.06 KiB | |
# allocs estimate: 1068 | |
# -------------- | |
# minimum time: 490.720 μs (0.00% GC) | |
# median time: 797.224 μs (0.00% GC) | |
# mean time: 816.253 μs (1.21% GC) | |
# maximum time: 6.539 ms (81.94% GC) | |
# -------------- | |
# samples: 5915 | |
# evals/sample: 1 | |
# time tolerance: 5.00% | |
# memory tolerance: 1.00% | |
## and for | |
# julia> n = 25; p = poly(linspace(1/n, 1, n)); | |
# julia> @benchmark roots(p) | |
# BenchmarkTools.Trial: | |
# memory estimate: 46.44 KiB | |
# allocs estimate: 60 | |
# -------------- | |
# minimum time: 301.662 μs (0.00% GC) | |
# median time: 310.948 μs (0.00% GC) | |
# mean time: 324.381 μs (1.90% GC) | |
# maximum time: 4.811 ms (90.75% GC) | |
# -------------- | |
# samples: 10000 | |
# evals/sample: 1 | |
# time tolerance: 5.00% | |
# memory tolerance: 1.00% | |
# julia> @benchmark amvw(p) | |
# BenchmarkTools.Trial: | |
# memory estimate: 64.08 KiB | |
# allocs estimate: 1640 | |
# -------------- | |
# minimum time: 1.114 ms (0.00% GC) | |
# median time: 1.859 ms (0.00% GC) | |
# mean time: 1.956 ms (0.94% GC) | |
# maximum time: 8.569 ms (71.05% GC) | |
# -------------- | |
# samples: 2512 | |
# evals/sample: 1 | |
# time tolerance: 5.00% | |
# memory tolerance: 1.00% | |
## scaling in n is not linear, not quadratic
# ns = [4,8,16,32,64,128,256] | |
# ts = zeros(length(ns)) | |
# for i in eachindex(ns) | |
# n = ns[i] | |
# p = poly(rand(n)) | |
# ts[i] = time() | |
# amvw(p) | |
# ts[i] = time() - ts[i] | |
# end | |
# julia> ts[2:end] ./ ts[1:end-1] | |
# 6-element Array{Float64,1}: | |
# 0.81346 | |
# 3.22482 | |
# 3.85639 | |
# 2.87251 | |
# 4.80438 | |
# 4.24259 | |
## Accuracy | |
# for n in 5:10 | |
# p = poly(1.0:n) | |
# print("n=$n: "); println(residual_check(p)) | |
# end | |
# ## ratios of residuals of amvw to roots
# julia> for n in 5:10 | |
# p = poly(1.0:n) | |
# print("n=$n: "); println(residual_check(p)) | |
# end | |
# n=5: (1.2249137731278716, 1.373964400725604, 0.8348154805166964) | |
# n=6: (0.973487711587202, 0.8271239165425018, 0.9888450594647631) | |
# n=7: (0.6176463305637532, 0.5859554825199832, 1.00385484319229) | |
# n=8: (0.9089421995514068, 0.8834377341004336, 0.9318743185837587) | |
# n=9: (1.2426448919951487, 1.3151521779378252, 0.9904264553424764) | |
# n=10: (2.280076313005534, 2.30722853091966, 0.9395012114789933) | |
## testing
# Polynomial with roots 1.0, 2.0, ..., n
function wpoly(n)
    return poly(1.0:n)
end

amvw(wpoly(10))
#residual_check(wpoly(10))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment