Created April 1, 2015 23:26
-
-
Save jiahao/4cd83660e2c741125a7b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#function axpy!{T}(a::T, b::StridedVector{T}, c::StridedVector{T}) | |
# @simd for i=1:size(b, 1) | |
# c[i] += a*b[i] | |
# end | |
#end | |
using LinearAlgebra  # dot, norm, axpy!, rmul!, view-based BLAS dispatch

"""
    mgs(A) -> (Q, nothing)

Non-mutating modified Gram-Schmidt: orthonormalize the columns of a copy of
`A`, leaving `A` untouched. The copy has element type `float(eltype(A))`, so
integer or rational input is promoted to floating point before the in-place
scaling in [`mgs!`](@ref).
"""
mgs(A) = mgs!(copyto!(similar(A, float(eltype(A))), A))

"""
    mgs!(Q) -> (Q, nothing)

Overwrite the columns of `Q` in place with an orthonormal basis computed by the
modified Gram-Schmidt process. The triangular factor R is *not* stored, so the
second return value is always `nothing`.

For column `k`, the contribution of every earlier basis vector `i < k` is
removed one at a time from the *partially updated* column, which is what
distinguishes the modified variant from classical Gram-Schmidt. `dot`
conjugates its first argument, so complex input is handled as well.

NOTE(review): no rank check is performed. If a column becomes exactly zero
after projection (exact linear dependence), it is scaled by `inv(0) == Inf`
and ends up filled with `NaN`s.
"""
function mgs!(Q)
    n = size(Q, 2)
    cols = [view(Q, :, k) for k in 1:n]
    @inbounds for k in 1:n
        qk = cols[k]
        for i in 1:k-1
            # Coefficient taken from the current (already partly orthogonalized)
            # qk. The generic `axpy!` dispatches to BLAS for strided BLAS eltypes;
            # a hand-written `@simd` loop was measured to be only slightly slower.
            r = dot(cols[i], qk)
            axpy!(-r, cols[i], qk)
        end
        rmul!(qk, inv(norm(qk)))
    end
    return Q, nothing
end
# Test: mgs! produces orthonormal columns that span the input's column space.
using LinearAlgebra  # dot, norm, eachcol-compatible adjoint products
let
    A = randn(5, 3)
    Q, R = mgs!(copy(A))
    @assert R === nothing  # mgs! does not store the triangular factor
    # Every column has unit 2-norm ...
    @assert all(abs(norm(q) - 1) < 1e-12 for q in eachcol(Q))
    # ... and distinct columns are mutually orthogonal.
    for i in 1:size(Q, 2), j in 1:i-1
        @assert abs(dot(Q[:, i], Q[:, j])) < 1e-12
    end
    # Projecting A onto span(Q) reproduces A, so Q spans A's column space.
    @assert isapprox(Q * (Q' * A), A; atol=1e-10)
end
using LinearAlgebra  # mul!, axpy!, rmul!, norm, UpperTriangular

"""
    cgs(A) -> (Q, R::UpperTriangular)

Non-mutating classical Gram-Schmidt: factor a copy of `A`, leaving `A`
untouched. The copy has element type `float(eltype(A))`, so integer or
rational input is promoted to floating point. See [`cgs!`](@ref).
"""
cgs(A) = cgs!(copyto!(similar(A, float(eltype(A))), A))

"""
    cgs!(Q) -> (Q, R::UpperTriangular)

Overwrite the columns of `Q` in place with an orthonormal basis computed by the
classical Gram-Schmidt process, and return the upper-triangular factor `R`
(element type `eltype(Q)`), so that the original `Q` equals `Q * R` up to
rounding.

For column `k`, all projection coefficients `R[1:k-1, k]` are computed from the
*original* column in a single adjoint matrix-vector product before anything is
subtracted. This is what distinguishes the classical variant from [`mgs!`](@ref).

NOTE(review): no rank check is performed. A column that becomes exactly zero
after projection gets `R[k, k] == 0` and is turned into `NaN`s by the scaling.
"""
function cgs!(Q)
    n = size(Q, 2)
    # Same eltype as Q so complex or Float32 coefficients are stored exactly.
    R = zeros(eltype(Q), n, n)
    cols = [view(Q, :, k) for k in 1:n]
    @inbounds for k in 1:n
        qk = cols[k]
        # Step 1: R[1:k-1, k] = Q[:, 1:k-1]' * qk as one BLAS2 gemv. In the
        # author's timings this matched k-1 separate dot products and was much
        # faster than a scalar triple loop.
        mul!(view(R, 1:k-1, k), view(Q, :, 1:k-1)', qk)
        # Step 2: subtract the projections. Of the variants timed (scalar @simd
        # loop, BLAS1 axpy, BLAS2 gemv, BLAS3 gemm), BLAS1 axpy was fastest and
        # gemm the slowest.
        for i in 1:k-1
            axpy!(-R[i, k], cols[i], qk)
        end
        # Step 3: normalize and record the diagonal of R.
        R[k, k] = norm(qk)
        rmul!(qk, inv(R[k, k]))
    end
    return Q, UpperTriangular(R)
end
# Test: cgs! produces orthonormal columns and a triangular R with Q * R == A.
using LinearAlgebra  # dot, norm, UpperTriangular
let
    A = randn(5, 3)
    Q, R = cgs!(copy(A))
    @assert R isa UpperTriangular
    # Every column has unit 2-norm ...
    @assert all(abs(norm(q) - 1) < 1e-12 for q in eachcol(Q))
    # ... and distinct columns are mutually orthogonal.
    for i in 1:size(Q, 2), j in 1:i-1
        @assert abs(dot(Q[:, i], Q[:, j])) < 1e-12
    end
    # The returned factors reconstruct the input.
    @assert isapprox(Q * R, A; atol=1e-10)
end
# Benchmark: time the built-in `qr` against `mgs` and `cgs` on a seeded
# 1200×1200 random matrix, then profile `mgs!`.
# `using` is only legal at top level, so it cannot live inside the `let`.
using LinearAlgebra, Profile, Random
let
    Random.seed!(1)
    A = randn(1200, 1200)
    # Warm up on tiny inputs so the timings below exclude JIT compilation.
    qr(randn(3, 3))
    mgs(randn(3, 3))
    cgs(randn(3, 3))
    # Disable GC while timing to reduce noise; re-enable it even if a call throws.
    GC.enable(false)
    try
        @time qr(A)
        @time mgs(A)
        @time cgs(A)
    finally
        GC.enable(true)
    end
    Profile.clear()
    @profile mgs!(A)  # overwrites A in place; A is not used afterwards
    Profile.print(C=true)
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.