Skip to content

Instantly share code, notes, and snippets.

@jiahao
Last active August 29, 2015 13:56
Show Gist options
  • Save jiahao/9240888 to your computer and use it in GitHub Desktop.
Conjugate gradients (Hestenes-Stiefel algorithm) implemented as a Julia iterator
import Base: start, next, done
# Stopping criteria for an iterative solver.
#   maxiter      -- iteration budget
#   resabsthresh -- absolute tolerance on the residual norm
#   resrelthresh -- relative tolerance on the residual norm
# How each field is applied is decided by the solver's done() method.
immutable Terminator
    maxiter::Int
    resabsthresh::Real
    resrelthresh::Real
end

# Budget-only constructor: both residual tolerances default to eps() (Float64 epsilon).
function Terminator(maxiter::Integer)
    tol = eps()
    return Terminator(maxiter, tol, tol)
end

# Zero iteration budget; intended to mean "always terminate".
Terminator() = Terminator(0)
# Operator plus starting vector; presumably describes the Krylov space
# span{v0, A*v0, A^2*v0, ...}.
immutable KrylovSpace
A #Operator, left untyped; cg_hs only needs `A*p` to work on it
v0 :: Vector #Starting vector; cg_hs's start() uses it as the initial residual (solve passes b - A*x)
end
# Supertype of solver objects (e.g. cg_hs) that are iterated to yield solution updates.
abstract IterativeSolver
# Supertype of the per-iteration state such solvers carry (e.g. cg_hs_state).
abstract IterationState
##################################################
# Conjugate gradients (Hestenes-Stiefel variant) #
##################################################
# Iteration state of the Hestenes-Stiefel conjugate-gradient iterator.
# start() seeds rnormsq with Inf as a "no previous residual yet" sentinel;
# next() tests isfinite(rnormsq) to decide whether to form a conjugate direction.
immutable cg_hs_state <: IterationState
r :: Vector #The current residual
p :: Vector #The search direction used in the most recent step (zero placeholder before the first step)
rnormsq :: Float64 #Squared norm of the _previous_ residual (Inf before the first step)
iter :: Int #Iteration counter: start() sets it to 1, each next() adds 1
end
# Conjugate-gradient solver (Hestenes-Stiefel variant), consumed as an iterator via
# start/next/done; each iteration yields an update dx to add to the current iterate.
# Every IterativeSolver could share this exact layout (a KrylovSpace plus a Terminator),
# but each solver needs its own type so start/next/done can dispatch on it.
# See Issue #4935 (presumably JuliaLang/julia#4935 -- confirm).
immutable cg_hs <: IterativeSolver
K :: KrylovSpace #Operator and starting residual
t :: Terminator #Stopping criteria
end
# Initial iterator state:
#   r       = the Krylov starting vector v0 (the initial residual),
#   p       = zero placeholder (next() ignores it on the first step because
#             rnormsq is the Inf sentinel),
#   rnormsq = Inf, iter = 1.
function start(a::cg_hs)
    v0 = a.K.v0
    placeholder = zeros(size(v0, 1))
    return cg_hs_state(v0, placeholder, Inf, 1)
end
# One Hestenes-Stiefel CG step. Returns (dx, newstate), where dx = alpha*p is the
# update to add to the current iterate. With r the current residual:
#   beta  = |r|^2 / |r_prev|^2     (skipped on the first step: rnormsq == Inf)
#   p     = r + beta*p_prev
#   alpha = |r|^2 / dot(p, A*p)
#   r_new = r - alpha*A*p
# If the residual is exactly zero the system is already solved. In that case p
# would be the zero vector and alpha = 0/0 = NaN, which would corrupt the
# iterate, so a zero update is returned and only the counter advances.
function next(a::cg_hs, s::cg_hs_state)
    rnormsq = dot(s.r, s.r)
    if rnormsq == 0
        return zero(s.r), cg_hs_state(s.r, s.p, s.rnormsq, s.iter+1)
    end
    p = isfinite(s.rnormsq) ? s.r + (rnormsq / s.rnormsq) * s.p : s.r
    Ap = a.K.A*p
    alpha = rnormsq / dot(p, Ap)   # step length (was the non-ASCII `µ`)
    return alpha*p, cg_hs_state(s.r - alpha * Ap, p, rnormsq, s.iter+1)
end
# Stop when any of the following holds:
#   (a) the current residual norm is below resabsthresh;
#   (b) the current residual norm is below resrelthresh times the initial residual
#       norm (start() uses a.K.v0 as the first residual);
#   (c) maxiter iterations have been performed (iter starts at 1, so iter-1 steps
#       have run; maxiter == 0 therefore stops before any step).
# The current residual s.r is used rather than s.rnormsq, which holds the
# *previous* residual and would lag the test by one step.
function done(a::cg_hs, s::cg_hs_state)
    rnormsq = dot(s.r, s.r)
    rnormsq < a.t.resabsthresh^2 && return true
    rnormsq < a.t.resrelthresh^2 * dot(a.K.v0, a.K.v0) && return true
    return s.iter > a.t.maxiter
end
# Generic linear solver: approximately solve A*x = b with any IterativeSolver type
# `method` whose constructor takes (KrylovSpace, Terminator) and whose iteration
# yields updates dx for the current iterate.
#   A            -- operator (must support A*x)
#   b            -- right-hand side
#   x            -- initial guess; NOT modified (`x += dx` rebinds a local name)
#   maxiter      -- iteration budget
#   method       -- solver type, e.g. cg_hs
#   resabsthresh -- absolute residual tolerance (default eps(), as Terminator(maxiter))
#   resrelthresh -- relative residual tolerance (default eps(), as Terminator(maxiter))
# Returns the final iterate.
function solve(A, b, x, maxiter, method, resabsthresh=eps(), resrelthresh=eps())
    K = method(KrylovSpace(A, b-A*x), Terminator(maxiter, resabsthresh, resrelthresh))
    for dx in K
        x += dx # allocates a new array; the caller's initial guess is left untouched
    end
    return x
end
# Sample problem: a random 10x10 symmetric positive definite system, started
# from the Jacobi (diagonal) guess and solved with cg_hs.
A = randn(10, 10)
A *= A'                          # A*A' is SPD (with probability 1 for square randn A)
b = randn(10)
x0 = diagm(diag(A)) \ b          # Jacobi initial guess
x = solve(A, b, x0, 1000, cg_hs)
norm(A*x - b)                    # final residual norm (shown when run at the REPL)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment